Implementing the Sum-Check protocol in Rust
This post goes through the Sum-Check protocol as described in the book and discusses a naive implementation of it in Rust. As in the previous post, the main goal is, so to speak, to "glue" together the math text and the implementation in software.
The Sum-Check protocol
Let's take a look at what the Sum-Check protocol is and how it can be implemented. From the book:
Suppose we are given a $\nu$-variate polynomial $g$ defined over a finite field $\mathbb{F}$. The purpose of the sum-check protocol is for the prover to provide the verifier with the following sum:
$$ H := \sum_{b_1 \in \lbrace 0, 1 \rbrace } \sum_{b_2 \in \lbrace 0,1 \rbrace } \cdots \sum_{b_{\nu} \in \lbrace 0, 1 \rbrace } g(b_1,\dots,b_{\nu}). $$
Both the Verifier and the Prover could compute $H$ directly by evaluating $g$ at all $2^\nu$ inputs (namely, all inputs in $\lbrace 0, 1 \rbrace^\nu$). Using the Sum-Check protocol the verifier's runtime will instead be
$$ \mathcal{O}(\nu + [\text{the cost to evaluate } g \text{ at a single input in } \mathbb{F}^\nu]) $$
For the full protocol description (actually multiple descriptions: in recursive and in iterative forms) it is best to check the book itself, so here it will be condensed to the points relevant to the implementation.
The protocol proceeds in rounds, and in each round the Prover sends a univariate polynomial of the form
$$ \sum_{(x_{j+1},...,x_{\nu}) \in \lbrace 0,1 \rbrace^{\nu - j}} g(r_1,\dots,r_{j-1},X_j,x_{j+1},\dots,x_{\nu}) $$
As you can see, this turns the multivariate polynomial $g$ into a univariate polynomial in the variable $X_j$ by
- fixing variables $X_1,\dots,X_{j-1}$ to constant values $r_1,\dots,r_{j - 1}$ and
- summing the resulting multivariate polynomial over the Boolean hypercube of values of $x_{j+1},\dots,x_{\nu}$ (a small example follows right below)
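As a tiny example (mine, not from the book), take $g(X_1, X_2) = X_1 X_2 + X_2$ over $\mathbb{F}_5$. Then

$$ H = g(0,0) + g(0,1) + g(1,0) + g(1,1) = 0 + 1 + 0 + 2 = 3, $$

and the polynomial sent in the first round is

$$ g_1(X_1) = g(X_1, 0) + g(X_1, 1) = X_1 + 1, $$

which indeed satisfies $g_1(0) + g_1(1) = 1 + 2 = 3 = H$.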
Implementing the building blocks in Rust
The above rough description of the Sum-Check protocol, as well as the original description from the book, gives several ideas about which building blocks may go into the final implementation. As before, the arkworks-rs framework is going to be used. It is worth noting that this framework already has an implementation of the Sum-Check protocol.
Boolean hypercube of $\mathbb{F}^n$ in Rust
The evaluation of $H$, as well as the evaluations of the polynomials at each step of the protocol, involves iterating over a Boolean hypercube (all $x \in \lbrace 0, 1 \rbrace^n$) of some size $n$. It would be convenient to have a utility iterator that does exactly that.
For any software developer a quick example of a hypercube immediately comes to mind: all possible values of some fixed-size integer variable in binary form. For a `u8` such values start from 0 and end with 255, or `0..0` and `1..1` in binary respectively. Everyone knows the asymptotics of the number of such values depending on the size: at the moment common operating systems use 64-bit (`u64`) pointers, which can address $2^{64}$ bytes of memory, more than enough for modern machines. If we could iterate over the values of a fixed-size variable and then `.map()` these values from bits to $\lbrace 0, 1 \rbrace \in \mathbb{F}$, that would be the easiest thing to do. The crate bitvec has a primitive that allows us to do just that.
Let's create a struct that knows the dimensionality of the hypercube it needs to iterate over as well as the current position within this cube:
pub struct BooleanHypercube<F: Field> {
n: u32,
current: u64,
__f: PhantomData<F>,
}
Note that this struct is generic over the field type $\mathbb{F}$.
Now to the `Iterator` implementation:
impl<F: Field> Iterator for BooleanHypercube<F> {
type Item = Vec<F>;
fn next(&mut self) -> Option<Self::Item> {
if self.current == 2u64.pow(self.n) {
None
} else {
let vec = self.current.to_le_bytes();
let s: &BitSlice<u8> = BitSlice::try_from_slice(&vec).unwrap();
self.current += 1;
Some(
s.iter()
.take(self.n as usize)
.map(|f| match *f {
false => F::zero(),
true => F::one(),
})
.collect(),
)
}
}
}
As you can see, the `current` counter is turned into a little-endian `[u8]` slice of bytes, and bitvec's `BitSlice` is used to iterate over the bits of the value. This iterator always yields a fixed number of bits, since there is a fixed number of bytes in a `u64`, so only the first `n` have to be `.take()`-n. Then the values are mapped from `bool`s to $\lbrace 0, 1 \rbrace \in \mathbb{F}$ and collected into the final vector that is returned to the caller.
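To see how the iterator might be used, here is a minimal sketch (not part of the post's code) of brute-forcing $H$ over the whole cube. It assumes a `BooleanHypercube::new(n)` constructor (the one used later in this post) and the `Polynomial`/`DenseMVPolynomial` traits being in scope:

```rust
use ark_ff::Field;
use ark_poly::{multivariate, DenseMVPolynomial, Polynomial};

/// Brute-force H = sum of g over {0,1}^n. A sketch for illustration only:
/// this is exactly the work the Sum-Check protocol lets the Verifier avoid.
fn compute_h<F: Field>(
    g: &multivariate::SparsePolynomial<F, multivariate::SparseTerm>,
) -> F {
    BooleanHypercube::<F>::new(g.num_vars() as u32)
        .fold(F::zero(), |acc, point| acc + g.evaluate(&point))
}
```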
Partially evaluating multivariate polynomials into univariate ones
Another part of the protocol is the reduction of a multivariate polynomial into a univariate one by fixing all but one variable. In other words, by fixing $r_1,\dots,r_{j-1}$ and $x_{j+1},\dots,x_\nu$ a multivariate polynomial $g(X_1,\dots,X_\nu)$ is reduced to a univariate one $g(r_1,\dots,r_{j-1},X_j,x_{j+1},\dots,x_\nu)$. This can be done by substituting the variables with their fixed values in every term of the original multivariate polynomial.
The trait `DenseMVPolynomial<F: Field>` describes the interface of multivariate polynomials in `ark_poly`. It allows getting the terms of the underlying polynomial by calling `.terms()` on it, which returns a sequence of pairs of type `(F, Self::Term)`:
pub trait DenseMVPolynomial<F: Field>: Polynomial<F> {
    type Term: multivariate::Term;
    fn terms(&self) -> &[(F, Self::Term)];
    // ...
}
The associated type `Term` on this trait is bound by the `multivariate::Term` trait, which requires it to have an `evaluate` method for evaluating the term at some point:
pub trait Term {
    fn evaluate<F: Field>(&self, p: &[F]) -> F;
    // ...
}
So this gives us the idea for the implementation: evaluate each term of our polynomial $g$ at the point $(r_1,\dots,r_{j-1},1,x_{j+1},\dots,x_\nu)$, multiply this evaluation by $X_j^t$ where $t$ is the degree of $X_j$ in this term, and add the results together.
For example consider this polynomial $g$ over $\mathbb{F}_5$:
$$ g(X_1, X_2, X_3) := X_1 X_2^2 + X_3 $$
to turn it into a univariate polynomial over $X_2$ with fixed $X_1 = 2$, $X_3 = 1$:
$$ g(2, X_2, 1) = \underbrace{ 2 }_{X_1 X_2^2 \text { evaluated at } \lbrace 2, 1, 1 \rbrace } X_2^2 + \underbrace{1}_{X_3 \text{ evaluated at } \lbrace 2, 1, 1 \rbrace } \equiv 2X_2^2 + 1 $$
With that, the code doing this is the following:
fn to_univariate_polynomial_at_x_j<F: Field, P: DenseMVPolynomial<F>>(
p: &P,
i: usize,
at: &[F],
) -> univariate::SparsePolynomial<F> {
let mut res = univariate::SparsePolynomial::zero();
let mut at_temp = at.to_vec();
at_temp[i] = F::one();
for (coeff, term) in p.terms() {
let eval = term.evaluate(&at_temp);
let power = match term
.vars()
.iter()
.zip(term.powers().iter())
.find(|(&v, _)| v == i)
{
Some((_, p)) => *p,
None => 0,
};
let new_coeff = *coeff * eval;
res += &univariate::SparsePolynomial::from_coefficients_slice(&[(power, new_coeff)]);
}
res
}
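As a sanity check, here is a hypothetical test (not part of the post's code) that reproduces the worked example above. The small field `F5` is defined here only for the example, using arkworks' `MontConfig` derive:

```rust
use ark_ff::fields::{Fp64, MontBackend, MontConfig};
use ark_poly::{multivariate, univariate, DenseMVPolynomial};

// A hypothetical F_5, defined only for these examples.
#[derive(MontConfig)]
#[modulus = "5"]
#[generator = "2"]
pub struct F5Config;
pub type F5 = Fp64<MontBackend<F5Config, 1>>;

#[test]
fn partial_evaluation_matches_worked_example() {
    // g(X_1, X_2, X_3) = X_1 * X_2^2 + X_3; note that variables are 0-indexed in code.
    let g = multivariate::SparsePolynomial::from_coefficients_vec(
        3,
        vec![
            (F5::from(1u64), multivariate::SparseTerm::new(vec![(0, 1), (1, 2)])),
            (F5::from(1u64), multivariate::SparseTerm::new(vec![(2, 1)])),
        ],
    );
    // Fix X_1 = 2 and X_3 = 1; the value at index 1 is ignored
    // (it is overwritten with 1 inside the helper).
    let point = vec![F5::from(2u64), F5::from(0u64), F5::from(1u64)];
    let g_x2 = to_univariate_polynomial_at_x_j(&g, 1, &point);
    // Expected: 2 * X_2^2 + 1.
    let expected = univariate::SparsePolynomial::from_coefficients_slice(&[
        (0, F5::from(1u64)),
        (2, F5::from(2u64)),
    ]);
    assert_eq!(g_x2, expected);
}
```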
Computing Prover's $j$-th round univariate polynomials
Recall that in each round the Prover has to send the Verifier a univariate polynomial that is claimed to be
$$ \sum_{(x_{j+1},...,x_{\nu}) \in \lbrace 0,1 \rbrace^{\nu - j}} g(r_1,\dots,r_{j-1},X_j,x_{j+1},\dots,x_{\nu}) $$
for some fixed values $r_1,\dots,r_{j-1}$ that the Prover has previously received from the Verifier. Using the building block from the previous section, the implementation is straightforward:
fn multivariate_to_univariate_with_fixed_vars<F: Field, P: DenseMVPolynomial<F>>(
g: &P,
r: &[F],
j: usize,
) -> univariate::SparsePolynomial<F> {
let mut res = univariate::SparsePolynomial::<F>::zero();
// A Boolean hypercube over variables X_{j+1}...X_{n}.
for x_point in BooleanHypercube::new((g.num_vars() - j - 1) as u32) {
// [r_1,...,r_{j-1},1,X_{j+1},...,X_n]
let mut point = r.to_vec();
point.push(F::one());
point.extend(x_point.into_iter());
let r = to_univariate_polynomial_at_x_j(g, j, &point);
res += &r;
}
res
}
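Again as a quick sanity check, here is a sketch (not the post's code, reusing the hypothetical `F5`, `compute_h`, and imports from the sketches above): for the earlier example the round-0 polynomial must satisfy $g_1(0) + g_1(1) = H$.

```rust
#[test]
fn round_zero_polynomial_is_consistent_with_h() {
    // Same example as before: g(X_1, X_2, X_3) = X_1 * X_2^2 + X_3 over F_5.
    let g = multivariate::SparsePolynomial::from_coefficients_vec(
        3,
        vec![
            (F5::from(1u64), multivariate::SparseTerm::new(vec![(0, 1), (1, 2)])),
            (F5::from(1u64), multivariate::SparseTerm::new(vec![(2, 1)])),
        ],
    );
    // Round 0: no challenges fixed yet.
    let g_1 = multivariate_to_univariate_with_fixed_vars(&g, &[], 0);
    // H = 6 = 1 (mod 5) for this polynomial.
    let h = compute_h(&g);
    assert_eq!(g_1.evaluate(&F5::from(0u64)) + g_1.evaluate(&F5::from(1u64)), h);
}
```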
Implementing the Prover
The Prover works with a polynomial $g$, the value $C_1$ that it claims to be the true answer, and a set of random values $r_j$ that it receives from the Verifier at each round of the protocol. It would make sense to unite these into the Prover's state:
pub struct Prover<F: Field> {
g: multivariate::SparsePolynomial<F, multivariate::SparseTerm>,
c_1: F,
r: Vec<F>,
}
The description of the $j$-th round of the Prover given above can now be implemented:
impl<F: Field> Prover<F> {
pub fn round(&mut self, r_prev: F, j: usize) -> univariate::SparsePolynomial<F> {
if j != 0 {
self.r.push(r_prev);
}
multivariate_to_univariate_with_fixed_vars(&self.g, &self.r, j)
}
}
Implementing the Verifier
The Verifier on each round but the final one outputs a random value $r_j$ that is sent to the Prover. In the final round the Verifier outputs the result of the verification process. This can be described with an `enum`:
pub enum VerifierRoundResult<F: Field> {
JthRound(F),
FinalRound(bool),
}
The Verifier has to know:
- the number of variables $n$ of the polynomial
- the value $C_1$ claimed to be the true answer by the Prover
- polynomials $g_1,\dots,g_j$ sent by the Prover at each round
- random values that were picked by the Verifier at each round
- a polynomial $g$ for the oracle access
Uniting these into one struct:
pub struct Verifier<F: Field> {
n: usize,
c_1: F,
g_part: Vec<univariate::SparsePolynomial<F>>,
r: Vec<F>,
g: multivariate::SparsePolynomial<F, multivariate::SparseTerm>,
}
Finally, we have to implement a single round of the Verifier. The book describes the Verifier's rounds as follows:
First round:
In the first round the Verifier checks that $C_1 = g_1(0) + g_1(1)$, i.e. that $g_1$ and the claimed answer are consistent with the definition of $H$ above.
$j$-th round:
The Verifier compares the two most recent polynomials by checking $g_{j-1}(r_{j-1}) = g_j(0) + g_j(1)$ and rejecting otherwise.
Final round:
The Prover has sent $g_n(X_n)$, which is claimed to be $g(r_1,\dots,r_{n-1},X_n)$. The Verifier now checks that $g_n(r_n) = g(r_1,\dots,r_n)$. If this check succeeds, as well as all the previous checks, then the Verifier is convinced that $H = g_1(0) + g_1(1)$.
Since the Verifier does custom things on rounds $1$ and $n$, the code will be less compact:
impl<F: Field> Verifier<F> {
    pub fn round<R: Rng>(
        &mut self,
        g_j: univariate::SparsePolynomial<F>,
        rng: &mut R,
    ) -> Result<VerifierRoundResult<F>, ()> {
        let r_j = F::rand(rng);
        if self.r.is_empty() {
            // First round: check that g_1 is consistent with the claimed answer C_1.
            if self.c_1 != g_j.evaluate(&F::zero()) + g_j.evaluate(&F::one()) {
                Err(())
            } else {
                self.g_part.push(g_j);
                self.r.push(r_j);
                Ok(VerifierRoundResult::JthRound(r_j))
            }
        } else if self.r.len() == (self.n - 1) {
            // Last round: the consistency check against the previous round's polynomial
            // is still required here, otherwise the chain of checks starting at C_1
            // never gets connected to the oracle query of g.
            let g_jprev = self.g_part.last().unwrap();
            let r_jprev = self.r.last().unwrap();
            if g_jprev.evaluate(r_jprev) != (g_j.evaluate(&F::zero()) + g_j.evaluate(&F::one())) {
                return Err(());
            }
            self.r.push(r_j);
            Ok(VerifierRoundResult::FinalRound(
                g_j.evaluate(&r_j) == self.g.evaluate(&self.r),
            ))
        } else {
            // j-th round: check consistency with the previous round's polynomial.
            let g_jprev = self.g_part.last().unwrap();
            let r_jprev = self.r.last().unwrap();
            if g_jprev.evaluate(r_jprev) != (g_j.evaluate(&F::zero()) + g_j.evaluate(&F::one())) {
                return Err(());
            }
            self.g_part.push(g_j);
            self.r.push(r_j);
            Ok(VerifierRoundResult::JthRound(r_j))
        }
    }
}
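Finally, to make the flow concrete, here is a rough end-to-end sketch (not from the post; the real wiring lives in the linked repository). It reuses the hypothetical `F5` and `compute_h` from the earlier sketches and builds the structs directly, which only works from within the same module since their fields are private:

```rust
use ark_std::test_rng;

#[test]
fn sumcheck_accepts_an_honest_prover() {
    let mut rng = test_rng();
    // g(X_1, X_2, X_3) = X_1 * X_2^2 + X_3 over F_5, as in the earlier examples.
    let g = multivariate::SparsePolynomial::from_coefficients_vec(
        3,
        vec![
            (F5::from(1u64), multivariate::SparseTerm::new(vec![(0, 1), (1, 2)])),
            (F5::from(1u64), multivariate::SparseTerm::new(vec![(2, 1)])),
        ],
    );
    let n = 3;
    let c_1 = compute_h(&g); // an honest Prover claims the true sum
    let mut prover = Prover { g: g.clone(), c_1, r: vec![] };
    let mut verifier = Verifier { n, c_1, g_part: vec![], r: vec![], g: g.clone() };

    let mut r_prev = F5::from(0u64); // ignored by the Prover in round 0
    for j in 0..n {
        let g_j = prover.round(r_prev, j);
        match verifier.round(g_j, &mut rng).expect("a round check failed") {
            VerifierRoundResult::JthRound(r_j) => r_prev = r_j,
            VerifierRoundResult::FinalRound(accepted) => assert!(accepted),
        }
    }
}
```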
That would be it for today
The code for this post is published here.
Changelog
26-08-2022
- Fixed typos and things pointed out in review by Thor.