Implementing Sum Check protocol in Rust

This post walks through the Sum-Check protocol as described in the book and discusses a naive implementation of it in Rust. As in the previous post, the main goal is to "glue" together the mathematical text and its implementation in software.

The Sum-Check protocol

Let's take a look at what the Sum-Check protocol is and how it can be implemented. From the book:

Suppose we are given a $\nu$-variate polynomial $g$ defined over a finite field $\mathbb{F}$. The purpose of the sum-check protocol is for the prover to provide the verifier with the following sum:

$$ H := \sum_{b_1 \in \lbrace 0, 1 \rbrace } \sum_{b_2 \in \lbrace 0,1 \rbrace } \cdots \sum_{b_{\nu} \in \lbrace 0, 1 \rbrace } g(b_1,\dots,b_{\nu}). $$

Both the Verifier and the Prover can compute $H$ directly by evaluating $g$ at $2^\nu$ inputs (namely, all inputs in $\lbrace 0, 1 \rbrace^\nu$). Using the Sum-Check protocol, the Verifier's runtime drops to

$$ \mathcal{O}(\nu + [\text{the cost to evaluate } g \text{ at a single input in } \mathbb{F}^\nu]) $$
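For comparison, the brute-force computation of $H$ can be sketched without any dependencies. This is only an illustration under stated assumptions: the field is the toy field $\mathbb{F}_5$ with elements stored as `u64` residues, and the names `brute_force_sum` and `example_g` are made up for this sketch (the post itself uses arkworks types). The polynomial is $X_1 X_2^2 + X_3$, the same one used as an example later in the post.

```rust
/// Toy modulus: residues in `0..P` stand in for elements of F_5
/// (an assumption for illustration; the post uses arkworks field types).
const P: u64 = 5;

/// Sums `g` over all 2^nu points of the Boolean hypercube {0,1}^nu, modulo P.
/// Requires nu < 64 so that `1 << nu` does not overflow.
fn brute_force_sum(nu: u32, g: impl Fn(&[u64]) -> u64) -> u64 {
    let mut h = 0;
    for index in 0..(1u64 << nu) {
        // Bit k of `index` is the value assigned to the (k+1)-th variable.
        let point: Vec<u64> = (0..nu).map(|k| (index >> k) & 1).collect();
        h = (h + g(&point)) % P;
    }
    h
}

/// Example polynomial g(X_1, X_2, X_3) = X_1 * X_2^2 + X_3 over F_5.
fn example_g(x: &[u64]) -> u64 {
    (x[0] * x[1] * x[1] + x[2]) % P
}
```

The loop body runs $2^\nu$ times, which is exactly the cost the protocol lets the Verifier avoid. For the example polynomial the eight evaluations add up to 6, which is 1 in $\mathbb{F}_5$.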

For the full protocol description (the book actually gives several, in recursive and in iterative form) it is best to consult the book text; here it is condensed to the points relevant to the implementation.

The protocol proceeds in rounds, and in round $j$ the Prover sends a univariate polynomial of the form

$$ \sum_{(x_{j+1},...,x_{\nu}) \in \lbrace 0,1 \rbrace^{\nu - j}} g(r_1,\dots,r_{j-1},X_j,x_{j+1},\dots,x_{\nu}) $$

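As you can see, this turns the multivariate polynomial $g$ into a univariate polynomial in the variable $X_j$ by fixing the first $j-1$ variables to the values $r_1,\dots,r_{j-1}$ chosen by the Verifier in earlier rounds and summing over all Boolean assignments of the remaining variables $x_{j+1},\dots,x_{\nu}$.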

Implementing the building blocks in Rust

The rough description above, together with the original description in the book, suggests several building blocks for the final implementation. As before, the arkworks-rs framework is going to be used. It is worth noting that this framework already has an implementation of the Sum-Check protocol.

Boolean hypercube of $\mathbb{F}^n$ in Rust

The evaluation of $H$, as well as the evaluation of the polynomial in each round of the protocol, involves iterating over a Boolean hypercube (all $x \in \lbrace 0, 1 \rbrace^n$) of some dimension $n$. It would be convenient to have a utility iterator that does exactly that.

Any software developer has a ready example of a hypercube at hand: all possible values of a fixed-size integer variable in binary form. For u8 these values run from 0 to 255, or from 0..0 to 1..1 in binary. The number of such values grows exponentially with the size: $n$ bits give $2^n$ values, so the u64 addresses used by common operating systems today can address $2^{64}$ bytes of memory, which is more than enough for modern machines.

If we could iterate over the values of a fixed-size variable and then .map() their bits to $\lbrace 0, 1 \rbrace \subset \mathbb{F}$, that would be the easiest thing to do. The bitvec crate has a primitive that allows us to do just that.

Let's create a struct that knows the dimensionality of the hypercube it needs to iterate over and also the current position within this cube:

pub struct BooleanHypercube<F: Field> {
    n: u32,
    current: u64,
    __f: PhantomData<F>,
}

Note that this struct is generic over the type of the field $\mathbb{F}$.

Now to the Iterator implementation:

impl<F: Field> Iterator for BooleanHypercube<F> {
    type Item = Vec<F>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.current == 2u64.pow(self.n) {
            None
        } else {
            let vec = self.current.to_le_bytes();
            let s: &BitSlice<u8> = BitSlice::try_from_slice(&vec).unwrap();
            self.current += 1;

            Some(
                s.iter()
                    .take(self.n as usize)
                    .map(|f| match *f {
                        false => F::zero(),
                        true => F::one(),
                    })
                    .collect(),
            )
        }
    }
}

As you can see, the current counter is turned into a little-endian [u8] slice of bytes and bitvec's BitSlice is used to iterate over the bits of the value. The slice always contains 64 bits, since current is a u64, so only the first n bits are .take()-n. The bits are then mapped from bool values to $\lbrace 0, 1 \rbrace \subset \mathbb{F}$ and collected into the vector that is returned to the caller.
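The same iteration order can also be produced with plain bit shifts. The sketch below is a dependency-free variant, not the published code: it models field elements as `u64` values 0 and 1 instead of a generic `F: Field`, and both the type name `ToyHypercube` and its `new` constructor are hypothetical names chosen for this illustration.

```rust
/// Iterates over all points of {0,1}^n, least significant bit first.
/// Field elements are modelled as plain `u64` values 0 and 1
/// (an assumption for this sketch; the post maps bits to F::zero()/F::one()).
pub struct ToyHypercube {
    n: u32,
    current: u64,
}

impl ToyHypercube {
    /// Hypothetical constructor; `n` must be below 64 so `1 << n` fits in a u64.
    pub fn new(n: u32) -> Self {
        assert!(n < 64, "dimension too large for a u64 counter");
        Self { n, current: 0 }
    }
}

impl Iterator for ToyHypercube {
    type Item = Vec<u64>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.current == 1u64 << self.n {
            return None;
        }
        // Bit k of the counter becomes coordinate k of the point.
        let point = (0..self.n).map(|k| (self.current >> k) & 1).collect();
        self.current += 1;
        Some(point)
    }
}
```

For `n = 2` this yields `[0, 0]`, `[1, 0]`, `[0, 1]`, `[1, 1]`, and for `n = 0` it yields a single empty point, which is what the last round of the protocol needs when no variables remain to be summed over.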

Partially evaluating multivariate polynomials into univariate ones

Another part of the protocol is the reduction of a multivariate polynomial to a univariate one by fixing all but one variable. For instance, by fixing $r_1,\dots,r_{j-1}$ and $x_{j+1},\dots,x_\nu$, a multivariate polynomial $g(X_1,\dots,X_\nu)$ is reduced to the univariate polynomial $g(r_1,\dots,r_{j-1},X_j,x_{j+1},\dots,x_\nu)$. This can be done by substituting the fixed values for the variables in every term of the original multivariate polynomial.

The trait DenseMVPolynomial<F: Field> describes the interface of multivariate polynomials in ark_poly. It allows us to get the terms of the underlying polynomial by calling .terms() on it, which returns a slice of (F, Self::Term) pairs.

pub trait DenseMVPolynomial<F: Field>: Polynomial<F> {
    type Term: multivariate::Term;

    fn terms(&self) -> &[(F, Self::Term)];
    // ... remaining items omitted
}

The associated type Term on this trait is bound by the multivariate::Term trait that requires it to have an evaluate method for evaluating at some point.

pub trait Term {
    fn evaluate<F: Field>(&self, p: &[F]) -> F;
    // ... remaining items omitted
}

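This gives us the idea for the implementation: evaluate each term of our polynomial $g$ at the point $(r_1,\dots,r_{j-1},1,x_{j+1},\dots,x_\nu)$, multiply the result by the term's coefficient and by $X_j^t$, where $t$ is the degree of $X_j$ in this term, and add the results together. Setting $X_j = 1$ makes the factor $X_j^t$ evaluate to 1, so the evaluation only picks up the fixed variables.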

For example consider this polynomial $g$ over $\mathbb{F}_5$:

$$ g(X_1, X_2, X_3) := X_1 X_2^2 + X_3 $$

to turn it into the univariate over $X_2$ with fixed $X_1 = 2$, $X_3 = 1$:

$$ g(2, X_2, 1) = \underbrace{ 2 }_{X_1 X_2^2 \text { evaluated at } \lbrace 2, 1, 1 \rbrace } X_2^2 + \underbrace{1}_{X_3 \text{ evaluated at } \lbrace 2, 1, 1 \rbrace } \equiv 2X_2^2 + 1 $$

The code that does this is as follows:

fn to_univariate_polynomial_at_x_j<F: Field, P: DenseMVPolynomial<F>>(
    p: &P,
    i: usize,
    at: &[F],
) -> univariate::SparsePolynomial<F> {
    let mut res = univariate::SparsePolynomial::zero();
    let mut at_temp = at.to_vec();
    at_temp[i] = F::one();

    for (coeff, term) in p.terms() {
        let eval = term.evaluate(&at_temp);
        let power = match term
            .vars()
            .iter()
            .zip(term.powers().iter())
            .find(|(&v, _)| v == i)
        {
            Some((_, p)) => *p,
            None => 0,
        };
        let new_coeff = *coeff * eval;
        res += &univariate::SparsePolynomial::from_coefficients_slice(&[(power, new_coeff)]);
    }
    res
}
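The worked example $g(2, X_2, 1) = 2X_2^2 + 1$ can be checked with a small dependency-free sketch. It is an illustration under assumptions, not the arkworks-based code: elements of $\mathbb{F}_5$ are `u64` residues, a term is a hypothetical `(coefficient, [(variable, power)])` pair, and the result is a dense coefficient vector. Instead of setting $X_i$ to 1 before evaluating, it skips $X_i$ while multiplying, which gives the same result.

```rust
const P: u64 = 5; // toy modulus: arithmetic in F_5

/// One monomial: a coefficient and a list of (variable index, power) pairs.
type Term = (u64, Vec<(usize, u32)>);

/// Modular exponentiation by repeated multiplication (fine for tiny powers).
fn pow_mod(base: u64, exp: u32) -> u64 {
    (0..exp).fold(1, |acc, _| acc * base % P)
}

/// Fixes every variable except `X_i` to the value in `at` and returns the
/// coefficients of the resulting univariate polynomial, lowest degree first.
/// The entry `at[i]` is ignored.
fn to_univariate(terms: &[Term], i: usize, at: &[u64]) -> Vec<u64> {
    let mut coeffs = vec![0u64];
    for (coeff, vars) in terms {
        let mut value = *coeff % P;
        let mut power = 0usize;
        for &(var, exp) in vars {
            if var == i {
                power += exp as usize; // X_i stays symbolic
            } else {
                value = value * pow_mod(at[var], exp) % P;
            }
        }
        if coeffs.len() <= power {
            coeffs.resize(power + 1, 0);
        }
        coeffs[power] = (coeffs[power] + value) % P;
    }
    coeffs
}
```

With $g = X_1 X_2^2 + X_3$ encoded as `vec![(1, vec![(0, 1), (1, 2)]), (1, vec![(2, 1)])]`, calling `to_univariate(&g, 1, &[2, 0, 1])` returns `[1, 0, 2]`, i.e. $2X_2^2 + 1$.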

Computing Prover's $j$-th round univariate polynomials

Recall that at each round Prover has to send Verifier a univariate polynomial that is claimed to be

$$ \sum_{(x_{j+1},...,x_{\nu}) \in \lbrace 0,1 \rbrace^{\nu - j}} g(r_1,\dots,r_{j-1},X_j,x_{j+1},\dots,x_{\nu}) $$

for some fixed values $r_1,\dots,r_{j-1}$ that the Prover has previously received from the Verifier. Using the building blocks from the previous sections, the implementation is straightforward:

fn multivariate_to_univariate_with_fixed_vars<F: Field, P: DenseMVPolynomial<F>>(
    g: &P,
    r: &[F],
    j: usize,
) -> univariate::SparsePolynomial<F> {
    let mut res = univariate::SparsePolynomial::<F>::zero();

    // A Boolean hypercube over variables X_{j+1}...X_{n}.
    for x_point in BooleanHypercube::new((g.num_vars() - j - 1) as u32) {
        // [r_1,...,r_{j-1},1,X_{j+1},...,X_n]
        let mut point = r.to_vec();
        point.push(F::one());
        point.extend(x_point.into_iter());

        let r = to_univariate_polynomial_at_x_j(g, j, &point);
        res += &r;
    }

    res
}
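To see concrete numbers, the round polynomials of the example $g = X_1 X_2^2 + X_3$ over $\mathbb{F}_5$ can be computed with a dependency-free sketch. As before, this is an illustration under assumptions: field elements are `u64` residues, the term encoding and the name `round_polynomial` are hypothetical, and `j` is 0-indexed as in the function above.

```rust
const P: u64 = 5; // toy modulus: residues 0..5 stand in for F_5

/// One monomial: a coefficient and a list of (variable index, power) pairs.
type Term = (u64, Vec<(usize, u32)>);

fn pow_mod(base: u64, exp: u32) -> u64 {
    (0..exp).fold(1, |acc, _| acc * base % P)
}

/// Coefficients (lowest degree first) of the honest round polynomial
/// for the 0-indexed variable `j`; `r` holds the `j` earlier challenges.
fn round_polynomial(g: &[Term], n: usize, r: &[u64], j: usize) -> Vec<u64> {
    let free = n - j - 1; // number of variables still summed over {0,1}
    let mut coeffs = vec![0u64];
    for index in 0..(1u64 << free) {
        // point = [r_0, .., r_{j-1}, <slot for X_j>, x_{j+1}, .., x_{n-1}]
        let mut point = r[..j].to_vec();
        point.push(0); // never read: X_j stays symbolic
        point.extend((0..free).map(|k| (index >> k) & 1));
        for (coeff, vars) in g {
            let mut value = *coeff % P;
            let mut power = 0usize;
            for &(var, exp) in vars {
                if var == j {
                    power += exp as usize;
                } else {
                    value = value * pow_mod(point[var], exp) % P;
                }
            }
            if coeffs.len() <= power {
                coeffs.resize(power + 1, 0);
            }
            coeffs[power] = (coeffs[power] + value) % P;
        }
    }
    coeffs
}
```

With challenges $r_1 = 2$ and $r_2 = 3$ this gives $g_1(X_1) = 2X_1 + 2$, $g_2(X_2) = 4X_2^2 + 1$ and $g_3(X_3) = X_3 + 3$. The rounds are consistent with each other: $g_1(2) = 1 = g_2(0) + g_2(1)$ and $g_2(3) = 2 = g_3(0) + g_3(1)$ in $\mathbb{F}_5$.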

Implementing the Prover

The Prover works with a polynomial $g$, the value $C_1$ that it claims to be the true answer and a set of random values $r_j$ that it receives from the Verifier at each round of the protocol. It would make sense to unite these in a state of the Prover:

pub struct Prover<F: Field> {
    g: multivariate::SparsePolynomial<F, multivariate::SparseTerm>,
    c_1: F,
    r: Vec<F>,
}

The description of the $j$-th round of the Prover given above can now be implemented:

impl<F: Field> Prover<F> {
    pub fn round(&mut self, r_prev: F, j: usize) -> univariate::SparsePolynomial<F> {
        if j != 0 {
            self.r.push(r_prev);
        }
        multivariate_to_univariate_with_fixed_vars(&self.g, &self.r, j)
    }
}

Implementing the Verifier

The Verifier on each round but the final one outputs a random value $r_j$ that is sent to the Prover. In the final round Verifier outputs the result of the verification process. This can be described as a type with an enum:

pub enum VerifierRoundResult<F: Field> {
    JthRound(F),
    FinalRound(bool),
}

Verifier has to know

  1. the number of variables $n$ of the polynomial
  2. the value $C_1$ claimed to be the true answer by the Prover
  3. polynomials $g_1,\dots,g_j$ sent by the Prover at each round
  4. random values that were picked by the Verifier at each round
  5. a polynomial $g$ for the oracle access

Uniting these into one struct:

pub struct Verifier<F: Field> {
    n: usize,
    c_1: F,
    g_part: Vec<univariate::SparsePolynomial<F>>,
    r: Vec<F>,
    g: multivariate::SparsePolynomial<F, multivariate::SparseTerm>,
}

Finally, we have to implement a single round of the Verifier. The book describes the Verifier's rounds as follows:

First round:

In the first round the Verifier checks that $C_1 = g_1(0) + g_1(1)$, i.e. that $g_1$ and the claimed answer are consistent with the definition of $H$.

$j$-th round:

The Verifier compares the two most recent polynomials by checking $g_{j-1}(r_{j-1}) = g_j(0) + g_j(1)$ and rejecting otherwise.

Final round:

The Prover has sent $g_n(X_n)$, which is claimed to be $g(r_1,\dots,r_{n-1},X_n)$. The Verifier now checks that $g_n(r_n) = g(r_1,\dots,r_n)$. If this check succeeds, as well as all the previous checks, then the Verifier is convinced that $H = g_1(0) + g_1(1)$.

Since the Verifier treats the first and the last round specially, the code is less compact:

impl<F: Field> Verifier<F> {
    pub fn round<R: Rng>(
        &mut self,
        g_j: univariate::SparsePolynomial<F>,
        rng: &mut R,
    ) -> Result<VerifierRoundResult<F>, ()> {
        let r_j = F::rand(rng);
        if self.r.is_empty() {
            // First Round
            if self.c_1 != g_j.evaluate(&F::zero()) + g_j.evaluate(&F::one()) {
                Err(())
            } else {
                self.g_part.push(g_j);
                self.r.push(r_j);

                Ok(VerifierRoundResult::JthRound(r_j))
            }
        } else if self.r.len() == (self.n - 1) {
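            // Last round: g_n must still be consistent with g_{n-1}
            // before the final oracle check is made.
            let g_jprev = self.g_part.last().unwrap();
            let r_jprev = self.r.last().unwrap();
            if g_jprev.evaluate(r_jprev) != (g_j.evaluate(&F::zero()) + g_j.evaluate(&F::one())) {
                return Err(());
            }
            self.r.push(r_j);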
            Ok(VerifierRoundResult::FinalRound(
                g_j.evaluate(&r_j) == self.g.evaluate(&self.r),
            ))
        } else {
            // j-th round
            let g_jprev = self.g_part.last().unwrap();
            let r_jprev = self.r.last().unwrap();

            if g_jprev.evaluate(r_jprev) != (g_j.evaluate(&F::zero()) + g_j.evaluate(&F::one())) {
                return Err(());
            }

            self.g_part.push(g_j);
            self.r.push(r_j);

            Ok(VerifierRoundResult::JthRound(r_j))
        }
    }
}
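The chain of checks can also be replayed over a complete transcript in a few lines of dependency-free Rust. This is a sketch under assumptions rather than the code above: field elements are `u64` residues modulo 5, univariate polynomials are dense coefficient vectors, the challenges are passed in instead of being sampled from an RNG, and `verify_transcript` and `example_g` are hypothetical names. It also skips the degree check on $g_j$ that the full protocol requires.

```rust
const P: u64 = 5; // toy modulus standing in for the field F_5

/// Evaluates a univariate polynomial given by its coefficients
/// (lowest degree first) using Horner's rule.
fn eval_uni(coeffs: &[u64], x: u64) -> u64 {
    coeffs.iter().rev().fold(0, |acc, c| (acc * x + c) % P)
}

/// Replays the verifier's checks over a complete transcript.
/// `messages[j]` is the prover's polynomial for round j + 1 and
/// `challenges[j]` is the verifier's reply to it.
fn verify_transcript(
    claim: u64,
    messages: &[Vec<u64>],
    challenges: &[u64],
    g: impl Fn(&[u64]) -> u64,
) -> bool {
    if messages.is_empty() || messages.len() != challenges.len() {
        return false;
    }
    // What g_j(0) + g_j(1) has to equal: C_1 first, g_{j-1}(r_{j-1}) afterwards.
    let mut expected = claim % P;
    for (g_j, &r_j) in messages.iter().zip(challenges) {
        if (eval_uni(g_j, 0) + eval_uni(g_j, 1)) % P != expected {
            return false;
        }
        expected = eval_uni(g_j, r_j % P);
    }
    // Final oracle query: g_n(r_n) must equal g(r_1, ..., r_n).
    expected == g(challenges) % P
}

/// Example polynomial g(X_1, X_2, X_3) = X_1 * X_2^2 + X_3 over F_5.
fn example_g(x: &[u64]) -> u64 {
    (x[0] * x[1] * x[1] + x[2]) % P
}
```

For the example polynomial with challenges $(2, 3, 4)$, the honest messages $2X_1 + 2$, $4X_2^2 + 1$ and $X_3 + 3$ (encoded as `[2, 2]`, `[1, 0, 4]`, `[3, 1]`) are accepted for the claim $C_1 = 1$. A wrong claim fails the first check, and replacing the last message with the constant polynomial 1 passes the consistency check but fails the final oracle query.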

That would be it for today

The code for this post is published here.


Changelog

26-08-2022