Evaluating multilinear extensions with Rust

Hi there! This is the first post in a series of notes I am taking on Proofs, Arguments and Zero-Knowledge as I read it. I write Rust for a living, so I am going to use it to implement the algorithms and protocols from the book.

Multilinear Extensions

Preliminaries

Chapter 3 of the book introduces polynomial extensions for multivariate functions:

Let $\mathbb{F}$ be any finite field, and let $f : \lbrace 0, 1 \rbrace ^\nu \rightarrow \mathbb{F}$ be any function mapping the $\nu$-dimensional Boolean hypercube to $\mathbb{F}$. A $\nu$-variate polynomial $g$ over $\mathbb{F}$ is said to be an extension of $f$ if $g$ agrees with $f$ at all Boolean-valued inputs, i.e. $g(x) = f(x) \space \forall x \in \lbrace 0, 1 \rbrace ^\nu$.

Then the multilinear polynomials are defined:

Definition 3.4. A multivariate polynomial $g$ is multilinear if the degree of the polynomial at each variable is at most one.

The following fact is then stated:

Fact 3.5. Any function $f : \lbrace 0, 1 \rbrace ^\nu \rightarrow \mathbb{F}$ has a unique multilinear extension (MLE) over $\mathbb{F}$, and we reserve the notation $\tilde{f}$ for this special extension of $f$.

And a lemma:

Lemma 3.6. (Lagrange interpolation of multilinear polynomials). Let $f : \lbrace 0, 1 \rbrace ^\nu \rightarrow \mathbb{F}$ be any function. Then the following multilinear polynomial $\tilde{f}$ extends $f$:

$$ \begin{equation} \tilde{f}(x_1,\dots,x_\nu) = \sum_{w \in \lbrace 0, 1 \rbrace ^\nu} f(w) \cdot \chi_w(x_1,\dots,x_\nu) \tag{3.1} \end{equation} $$

where, for any $w = (w_1,\dots,w_\nu)$:

$$ \begin{equation} \chi_w(x_1,\dots,x_\nu) := \prod_{i=1}^{\nu}(x_i w_i + (1 - x_i)(1 - w_i)). \tag{3.2} \end{equation} $$

which gives a way to evaluate $\tilde{f}$.
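
As a quick sanity check of (3.2), here is a small worked case (chosen for illustration, not taken from the book) with $\nu = 2$ and $w = (0, 1)$:

$$ \chi_{(0,1)}(x_1, x_2) = (x_1 \cdot 0 + (1 - x_1)(1 - 0))(x_2 \cdot 1 + (1 - x_2)(1 - 1)) = (1 - x_1) x_2 $$

This product equals $1$ at $(x_1, x_2) = (0, 1)$ and $0$ at the other three Boolean points. In general $\chi_w$ acts as an indicator of $w$ on the hypercube, so on a Boolean input $y$ the sum in (3.1) collapses to the single term $f(y)$.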

Algorithms for evaluating the multilinear extension of $f$

There are two ways to efficiently compute $\tilde{f}$ at any point $r \in \mathbb{F}^\nu$ if the values $f(w)$ are given for all $n = 2^\nu$ Boolean vectors $w \in \lbrace 0, 1 \rbrace ^\nu$:

Lemma 3.7 gives a way to compute the right-hand side of equation (3.1) incrementally from a stream, by initializing $\tilde{f}(r) \leftarrow 0$ and processing each update $(w, f(w))$ via:

$$ \tilde{f}(r) \leftarrow \tilde{f}(r) + f(w) \cdot \chi_w(r). $$

in $\mathcal{O}(n \log{} n)$ time and $\mathcal{O}(\log{} n)$ space.

Lemma 3.8 gives a staged way of computing $\tilde{f}(r)$ in $\nu = \log n$ stages, where Stage $j$ constructs a memoization table $A^{(j)}$ of size $2^j$. This takes $\mathcal{O}(n)$ time and $\mathcal{O}(n)$ space.

Example

Let's take a look at the example from the book: a function $f$ mapping $\lbrace 0, 1 \rbrace ^2$ to $\mathbb{F}_5$, so $\nu = 2$ and $n = 4$.

All evaluations of a function $f$ mapping $ \lbrace 0, 1 \rbrace ^2$ to the field $\mathbb{F}_5$:

$$ \begin{array}{c|c|c|} & 0 & 1 \\ \hline 0 & 1 & 2 \\ \hline 1 & 1 & 4 \\ \hline \end{array} $$

All evaluations of the multilinear extension, $\tilde{f}$ of $f$ over $\mathbb{F}_5$:

$$ \begin{array}{c|c|c|c|c|c|} & 0 & 1 & 2 & 3 & 4 \\ \hline 0 & 1 & 2 & 3 & 4 & 0 \\ \hline 1 & 1 & 4 & 2 & 0 & 3 \\ \hline 2 & 1 & 1 & 1 & 1 & 1 \\ \hline 3 & 1 & 3 & 0 & 2 & 4 \\ \hline 4 & 1 & 0 & 4 & 3 & 2 \\ \hline \end{array} $$

Via Lagrange interpolation: $\tilde{f}(x_1, x_2) = (1 - x_1)(1 - x_2) + 2(1 - x_1)x_2 + x_1(1 - x_2) + 4x_1x_2$
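
Expanding the four products and collecting coefficients gives a more compact form. This simplification is worked out here as an illustration. The constant term is $1$, the $x_1$ coefficient is $-1 + 1 = 0$, the $x_2$ coefficient is $-1 + 2 = 1$, and the $x_1 x_2$ coefficient is $1 - 2 - 1 + 4 = 2$:

$$ \tilde{f}(x_1, x_2) = 1 + x_2 + 2 x_1 x_2 $$

For example, $\tilde{f}(3, 1) = 1 + 1 + 6 = 8 = 3$ in $\mathbb{F}_5$, which matches the entry in row $3$, column $1$ of the table of evaluations.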

To understand the evaluation algorithms better, let's do some by-hand evaluations of $\tilde{f}$.

Evaluating $\tilde{f}(r)$ with Lemma 3.7

Let $r = \lbrace 3, 1 \rbrace $. Initialize $\tilde{f}(r) \leftarrow 0$.

Step 1.

$w = \lbrace 0, 0 \rbrace$.

$$ \chi_w(r) = (r_1 w_1 + (1 - r_1)(1 - w_1))(r_2 w_2 + (1 - r_2)(1 - w_2)) = \newline (3 \cdot 0 + (1 - 3)(1 - 0))(1 \cdot 0 + (1 - 1)(1 - 0)) = 0. $$

Update

$$ \tilde{f}(r) \leftarrow \underbrace{\tilde{f}(r)}_{0} + \underbrace{f(w)}_{1} \cdot 0 = 0. $$

Step 2.

$w = \lbrace 0, 1 \rbrace $.

$$ \chi_w(r) = (3 \cdot 0 + (1 - 3)(1 - 0))(1 \cdot 1 + (1 - 1)(1 - 1)) = -2 \cdot 1 = 3. $$

Update

$$ \tilde{f}(r) \leftarrow \underbrace{\tilde{f}(r)}_{0} + \underbrace{f(w)}_{2} \cdot 3 = 1. $$

Step 3.

$w = \lbrace 1, 0 \rbrace $

$$ \chi_w(r) = (3 \cdot 1 + (1 - 3)(1 - 1))(1 \cdot 0 + (1 - 1)(1 - 0)) = 0. $$

Since $\chi_w(r) = 0$, no update is needed at this step.

Step 4.

$w = \lbrace 1, 1 \rbrace $

$$ \chi_w(r) = (3 \cdot 1 + (1 - 3)(1 - 1))(1 \cdot 1 + (1 - 1)(1 - 1)) = 3. $$

Update

$$ \tilde{f}(r) \leftarrow \underbrace{\tilde{f}(r)}_{1} + \underbrace{f(w)}_{4} \cdot 3 = 1 + 2 = 3. $$
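
The four steps above can be replayed mechanically. The following is a minimal sketch that assumes plain `u64` residues modulo 5 instead of a real field type. The names `P`, `chi` and `stream_eval` are hypothetical and exist only for this illustration.

```rust
/// Prime modulus of the toy field F_5 (an assumption of this sketch).
const P: u64 = 5;

/// chi_w(r) from equation (3.2); every element is a residue in 0..P.
fn chi(w: &[u64], r: &[u64]) -> u64 {
    w.iter().zip(r).fold(1, |acc, (&w_i, &r_i)| {
        // (1 - a) mod P is written as (1 + P - a) to avoid unsigned underflow.
        let term = (r_i * w_i + (1 + P - r_i) * (1 + P - w_i)) % P;
        acc * term % P
    })
}

/// Streaming evaluation: one accumulator update per pair (w, f(w)).
fn stream_eval(updates: &[([u64; 2], u64)], r: &[u64; 2]) -> u64 {
    let mut acc = 0;
    for (w, f_w) in updates {
        acc = (acc + f_w * chi(w, r)) % P;
    }
    acc
}
```

Feeding it the updates `([0, 0], 1)`, `([0, 1], 2)`, `([1, 0], 1)`, `([1, 1], 4)` with `r = [3, 1]` goes through the same intermediate values as the hand computation and ends at `3`.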

Evaluating $\tilde{f}(r)$ with Lemma 3.8

Step 1.

$r = \lbrace 3, 1 \rbrace$

At this step a table $A^{(1)}$ of size $2$ is constructed:

$A^{(1)}[(w_1)] = (w_1 r_1 + (1 - w_1)(1 - r_1))$

$A^{(1)}[(0)] = (1 - r_1) = (1 - 3) = 3$

$A^{(1)}[(1)] = r_1 = 3$

Step 2.

At this step a table $A^{(2)}$ of size $4$ is constructed:

$A^{(2)}[(w_1, w_2)] = A^{(1)}[w_1] \cdot (w_2 r_2 + (1 - w_2)(1 - r_2))$

$A^{(2)}[(0, 0)] = 3 \cdot (0 \cdot 1 + (1 - 0)(1 - 1)) = 0$

$A^{(2)}[(0, 1)] = 3 \cdot (1 \cdot 1 + (1 - 1)(1 - 1)) = 3$

$A^{(2)}[(1, 0)] = 3 \cdot (0 \cdot 1 + (1 - 0)(1 - 1)) = 0$

$A^{(2)}[(1, 1)] = 3 \cdot (1 \cdot 1 + (1 - 1)(1 - 1)) = 3$

Finally, the entries of $A^{(2)}$ are multiplied by the corresponding values of $f$ and summed:

$\tilde{f}(3, 1) = 0 \cdot 1 + 3 \cdot 2 + 0 \cdot 1 + 3 \cdot 4 = 6 + 12 = 1 + 2 = 3$
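
The staged construction can be sketched the same way. This is an illustrative version over `u64` residues modulo 5, not the generic implementation. The names `P`, `chi_table` and `table_eval` are hypothetical.

```rust
/// Prime modulus of the toy field F_5 (an assumption of this sketch).
const P: u64 = 5;

/// Builds the final table A^(nu): entry w holds chi_w(r),
/// with r_1 corresponding to the most significant bit of the index.
fn chi_table(r: &[u64]) -> Vec<u64> {
    let mut table = vec![1];
    for &r_j in r {
        // Each stage doubles the table: every entry splits in two.
        let mut next = Vec::with_capacity(table.len() * 2);
        for a in table {
            next.push(a * (1 + P - r_j) % P); // w_j = 0
            next.push(a * r_j % P); // w_j = 1
        }
        table = next;
    }
    table
}

/// Inner product of the table with the evaluations of f.
fn table_eval(evals: &[u64], r: &[u64]) -> u64 {
    chi_table(r)
        .iter()
        .zip(evals)
        .fold(0, |acc, (a, f_w)| (acc + a * f_w) % P)
}
```

With `r = [3, 1]` the table comes out as `[0, 3, 0, 3]`, the same as $A^{(2)}$ above, and `table_eval(&[1, 2, 1, 4], &[3, 1])` gives `3`.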

Implementing multilinear extension evaluations in Rust

To implement these algorithms in Rust, we are going to use the arkworks-rs framework, specifically its finite field and polynomial crates. At the time of writing, the newer API is only available from git, so git dependencies are used:

[dependencies]
ark-ff = { git = "https://github.com/arkworks-rs/algebra" }
ark-poly = { git = "https://github.com/arkworks-rs/algebra" }

First, let's take a look at the Field trait that ark-ff provides. Field is a subtrait of a number of traits that make sense for any field type, such as Add and Mul, as well as Zero and One, which require the type to have the $0$ and $1$ elements. It therefore makes sense to make our implementations generic over any type that implements Field.

The first building block underlying both algorithms is the computation of the Lagrange basis polynomial $\chi_w(x_1,\dots,x_\nu)$. Equation (3.2) gives us a straightforward way to implement it:

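use ark_ff::Field;

/// Evaluates the Lagrange basis polynomial chi_w at the point x, following (3.2).
/// Returns None if x and w have different lengths.
fn lagrange_basis_poly_at<F: Field>(x: &[F], w: &[F]) -> Option<F> {
    if x.len() != w.len() {
        None
    } else {
        let res = x.iter().zip(w.iter()).fold(F::one(), |acc, (&x_i, &w_i)| {
            acc * (x_i * w_i + (F::one() - x_i) * (F::one() - w_i))
        });

        Some(res)
    }
}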

This code calculates $\chi_w(x_1,\dots,x_\nu)$ in one streaming pass. There is also a sanity check that $x$ and $w$ are of the same length.

Now let's move on to implementing the evaluation algorithms themselves. Both algorithms have the same inputs: the evaluations $f(w)$ on all $2^\nu$ Boolean vectors $w$, and a fixed point $r$ at which the algorithm needs to compute $\tilde{f}(r)$.

To save space, the evaluations can be passed to the functions without the Boolean vectors themselves: the value $f(w)$ is stored at the index $j$ whose binary form is $w$. For instance, in the vector $\lbrace 1, 2, 3, 4 \rbrace$ the value $3$ is at index $2$, which has the binary form $[1, 0]$ and so corresponds to the vector $w = \lbrace 1, 0 \rbrace$.
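
The index-to-vector mapping can be sketched on its own. The helper name `index_to_bits` is hypothetical and is not part of the post's code. It only illustrates the most-significant-bit-first convention described above.

```rust
/// Expands `index` into `nu` bits, most significant bit first.
fn index_to_bits(index: usize, nu: usize) -> Vec<u8> {
    (0..nu)
        .rev()
        .map(|j| ((index >> j) & 1) as u8)
        .collect()
}
```

Under this convention `index_to_bits(2, 2)` yields `[1, 0]`, so the first coordinate of $w$ is the highest bit of the index.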

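/// Lemma 3.7: streaming evaluation of the multilinear extension at r.
/// evals[i] holds f(w), where w is the binary form of i, most significant bit first.
fn cti_multilinear_from_evaluations<F: Field>(evals: &[F], r: &[F]) -> F {
    let mut res = F::zero();
    let len = r.len();

    for (i, eval) in evals.iter().enumerate() {
        // Rebuild the Boolean vector w from the bits of the index i.
        let mut w = Vec::with_capacity(len);

        for j in (0..len).rev() {
            let bit = 1_usize << j;

            let w_j = if i & bit == 0 { F::zero() } else { F::one() };
            w.push(w_j);
        }

        // w and r have the same length by construction, so unwrap cannot fail.
        res += *eval * lagrange_basis_poly_at(r, &w).unwrap();
    }

    res
}

/// Lemma 3.8: staged evaluation of the multilinear extension at r.
/// After processing r_1..r_j the table holds chi_w(r_1..r_j) for all 2^j prefixes w.
/// Note that zip stops at the shorter input, so evals is expected to have 2^r.len() entries.
fn vsbw_multilinear_from_evaluations<F: Field>(evals: &[F], r: &[F]) -> F {
    let mut eval_table = vec![F::one()];

    for r_j in r {
        let mut eval_table_new = Vec::with_capacity(eval_table.len() * 2);

        // Each entry splits in two: one for w_j = 0 and one for w_j = 1.
        for eval in eval_table.into_iter() {
            eval_table_new.push(eval * (F::one() - r_j));
            eval_table_new.push(eval * r_j);
        }

        eval_table = eval_table_new;
    }

    // Inner product of the chi table with the evaluations of f.
    eval_table
        .into_iter()
        .zip(evals.iter())
        .fold(F::zero(), |acc, (w_j, p_j)| acc + w_j * p_j)
}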

Testing on $\mathbb{F}_5$

Without going too deep into the details of the ark-ff interface, the field $\mathbb{F}_5$ may be defined as follows:

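use ark_ff::fields::{Fp64, MontBackend, MontConfig};
// PrimeField brings from_bigint and into_bigint into scope for the snippets below.
use ark_ff::PrimeField;

#[derive(MontConfig)]
#[modulus = "5"]
#[generator = "2"]
struct FrConfig;

type Fr = Fp64<MontBackend<FrConfig, 1>>;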

Working with this type would require explicit conversion to and from u32 types. The evaluations of $f$ from the above example:

let evals: Vec<_> = [1u32, 2, 1, 4]
    .iter()
    .map(|&f| Fr::from_bigint(f.into()).unwrap())
    .collect();

Finally, let's construct the matrix of evaluations of $\tilde{f}(r)$ on $\mathbb{F}_5^2$:

for i in 0u32..5 {
    let mut line = Vec::with_capacity(5);
    for j in 0u32..5 {
        let f_r = cti_multilinear_from_evaluations(
            &evals,
            &[
                Fr::from_bigint(i.into()).unwrap(),
                Fr::from_bigint(j.into()).unwrap(),
            ],
        );
        line.push(f_r.into_bigint().as_ref()[0]);
    }
    println!("{:?}", line);
}

(The same piece of code is called for the second algorithm).

When run for both algorithms, this code prints the following two matrices, one per algorithm:

[1, 2, 3, 4, 0]
[1, 4, 2, 0, 3]
[1, 1, 1, 1, 1]
[1, 3, 0, 2, 4]
[1, 0, 4, 3, 2]

[1, 2, 3, 4, 0]
[1, 4, 2, 0, 3]
[1, 1, 1, 1, 1]
[1, 3, 0, 2, 4]
[1, 0, 4, 3, 2]

Both match the table from the example above. The full code for this post is available in the GitHub repo.