openvm_sha256_air/
air.rs

1use std::{array, borrow::Borrow, cmp::max, iter::once};
2
3use openvm_circuit_primitives::{
4    bitwise_op_lookup::BitwiseOperationLookupBus,
5    encoder::Encoder,
6    utils::{not, select},
7    SubAir,
8};
9use openvm_stark_backend::{
10    interaction::{BusIndex, InteractionBuilder, PermutationCheckBus},
11    p3_air::{AirBuilder, BaseAir},
12    p3_field::{Field, FieldAlgebra},
13    p3_matrix::Matrix,
14};
15
16use super::{
17    big_sig0_field, big_sig1_field, ch_field, compose, maj_field, small_sig0_field,
18    small_sig1_field, u32_into_limbs, Sha256DigestCols, Sha256RoundCols, SHA256_DIGEST_WIDTH,
19    SHA256_H, SHA256_HASH_WORDS, SHA256_K, SHA256_ROUNDS_PER_ROW, SHA256_ROUND_WIDTH,
20    SHA256_WORD_BITS, SHA256_WORD_U16S, SHA256_WORD_U8S,
21};
22use crate::constraint_word_addition;
23
/// AIR constraining SHA-256 block processing: message schedule, rounds and per-block digest.
///
/// Expects the message to be padded to a multiple of 512 bits
#[derive(Clone, Debug)]
pub struct Sha256Air {
    /// Lookup bus used for the `send_range` range checks issued by this AIR
    /// (work-variable carries and `final_hash` limbs).
    pub bitwise_lookup_bus: BitwiseOperationLookupBus,
    /// Encoder for the per-row `row_idx` flag columns (row indices `0..=17`).
    pub row_idx_encoder: Encoder,
    /// Internal bus for self-interactions in this AIR.
    bus: PermutationCheckBus,
}
32
33impl Sha256Air {
34    pub fn new(bitwise_lookup_bus: BitwiseOperationLookupBus, self_bus_idx: BusIndex) -> Self {
35        Self {
36            bitwise_lookup_bus,
37            row_idx_encoder: Encoder::new(18, 2, false),
38            bus: PermutationCheckBus::new(self_bus_idx),
39        }
40    }
41}
42
43impl<F> BaseAir<F> for Sha256Air {
44    fn width(&self) -> usize {
45        max(
46            Sha256RoundCols::<F>::width(),
47            Sha256DigestCols::<F>::width(),
48        )
49    }
50}
51
impl<AB: InteractionBuilder> SubAir<AB> for Sha256Air {
    /// The start column for the sub-air to use: this AIR reads its columns from
    /// `start_col..start_col + width` of each row.
    type AirContext<'a>
        = usize
    where
        Self: 'a,
        AB: 'a,
        <AB as AirBuilder>::Var: 'a,
        <AB as AirBuilder>::Expr: 'a;

    /// Evaluates every SHA-256 constraint on the columns beginning at `start_col`.
    ///
    /// Single-row constraints are emitted first, then the constraints (and bus interactions)
    /// that relate each row to the next one.
    fn eval<'a>(&'a self, builder: &'a mut AB, start_col: Self::AirContext<'a>)
    where
        <AB as AirBuilder>::Var: 'a,
        <AB as AirBuilder>::Expr: 'a,
    {
        // Constraints that only read the local row
        self.eval_row(builder, start_col);
        // Constraints between the local row and the next row
        self.eval_transitions(builder, start_col);
    }
}
71
72impl Sha256Air {
73    /// Implements the single row constraints (i.e. imposes constraints only on local)
74    /// Implements some sanity constraints on the row index, flags, and work variables
75    fn eval_row<AB: InteractionBuilder>(&self, builder: &mut AB, start_col: usize) {
76        let main = builder.main();
77        let local = main.row_slice(0);
78
79        // Doesn't matter which column struct we use here as we are only interested in the common
80        // columns
81        let local_cols: &Sha256DigestCols<AB::Var> =
82            local[start_col..start_col + SHA256_DIGEST_WIDTH].borrow();
83        let flags = &local_cols.flags;
84        builder.assert_bool(flags.is_round_row);
85        builder.assert_bool(flags.is_first_4_rows);
86        builder.assert_bool(flags.is_digest_row);
87        builder.assert_bool(flags.is_round_row + flags.is_digest_row);
88        builder.assert_bool(flags.is_last_block);
89
90        self.row_idx_encoder
91            .eval(builder, &local_cols.flags.row_idx);
92        builder.assert_one(
93            self.row_idx_encoder
94                .contains_flag_range::<AB>(&local_cols.flags.row_idx, 0..=17),
95        );
96        builder.assert_eq(
97            self.row_idx_encoder
98                .contains_flag_range::<AB>(&local_cols.flags.row_idx, 0..=3),
99            flags.is_first_4_rows,
100        );
101        builder.assert_eq(
102            self.row_idx_encoder
103                .contains_flag_range::<AB>(&local_cols.flags.row_idx, 0..=15),
104            flags.is_round_row,
105        );
106        builder.assert_eq(
107            self.row_idx_encoder
108                .contains_flag::<AB>(&local_cols.flags.row_idx, &[16]),
109            flags.is_digest_row,
110        );
111        // If padding row we want the row_idx to be 17
112        builder.assert_eq(
113            self.row_idx_encoder
114                .contains_flag::<AB>(&local_cols.flags.row_idx, &[17]),
115            flags.is_padding_row(),
116        );
117
118        // Constrain a, e, being composed of bits: we make sure a and e are always in the same place
119        // in the trace matrix Note: this has to be true for every row, even padding rows
120        for i in 0..SHA256_ROUNDS_PER_ROW {
121            for j in 0..SHA256_WORD_BITS {
122                builder.assert_bool(local_cols.hash.a[i][j]);
123                builder.assert_bool(local_cols.hash.e[i][j]);
124            }
125        }
126    }
127
128    /// Implements constraints for a digest row that ensure proper state transitions between blocks
129    /// This validates that:
130    /// The work variables are correctly initialized for the next message block
131    /// For the last message block, the initial state matches SHA256_H constants
132    fn eval_digest_row<AB: InteractionBuilder>(
133        &self,
134        builder: &mut AB,
135        local: &Sha256RoundCols<AB::Var>,
136        next: &Sha256DigestCols<AB::Var>,
137    ) {
138        // Check that if this is the last row of a message or an inpadding row, the hash should be
139        // the [SHA256_H]
140        for i in 0..SHA256_ROUNDS_PER_ROW {
141            let a = next.hash.a[i].map(|x| x.into());
142            let e = next.hash.e[i].map(|x| x.into());
143            for j in 0..SHA256_WORD_U16S {
144                let a_limb = compose::<AB::Expr>(&a[j * 16..(j + 1) * 16], 1);
145                let e_limb = compose::<AB::Expr>(&e[j * 16..(j + 1) * 16], 1);
146
147                // If it is a padding row or the last row of a message, the `hash` should be the
148                // [SHA256_H]
149                builder
150                    .when(
151                        next.flags.is_padding_row()
152                            + next.flags.is_last_block * next.flags.is_digest_row,
153                    )
154                    .assert_eq(
155                        a_limb,
156                        AB::Expr::from_canonical_u32(
157                            u32_into_limbs::<2>(SHA256_H[SHA256_ROUNDS_PER_ROW - i - 1])[j],
158                        ),
159                    );
160
161                builder
162                    .when(
163                        next.flags.is_padding_row()
164                            + next.flags.is_last_block * next.flags.is_digest_row,
165                    )
166                    .assert_eq(
167                        e_limb,
168                        AB::Expr::from_canonical_u32(
169                            u32_into_limbs::<2>(SHA256_H[SHA256_ROUNDS_PER_ROW - i + 3])[j],
170                        ),
171                    );
172            }
173        }
174
175        // Check if last row of a non-last block, the `hash` should be equal to the final hash of
176        // the current block
177        for i in 0..SHA256_ROUNDS_PER_ROW {
178            let prev_a = next.hash.a[i].map(|x| x.into());
179            let prev_e = next.hash.e[i].map(|x| x.into());
180            let cur_a = next.final_hash[SHA256_ROUNDS_PER_ROW - i - 1].map(|x| x.into());
181
182            let cur_e = next.final_hash[SHA256_ROUNDS_PER_ROW - i + 3].map(|x| x.into());
183            for j in 0..SHA256_WORD_U8S {
184                let prev_a_limb = compose::<AB::Expr>(&prev_a[j * 8..(j + 1) * 8], 1);
185                let prev_e_limb = compose::<AB::Expr>(&prev_e[j * 8..(j + 1) * 8], 1);
186
187                builder
188                    .when(not(next.flags.is_last_block) * next.flags.is_digest_row)
189                    .assert_eq(prev_a_limb, cur_a[j].clone());
190
191                builder
192                    .when(not(next.flags.is_last_block) * next.flags.is_digest_row)
193                    .assert_eq(prev_e_limb, cur_e[j].clone());
194            }
195        }
196
197        // Assert that the previous hash + work vars == final hash.
198        // That is, `next.prev_hash[i] + local.work_vars[i] == next.final_hash[i]`
199        // where addition is done modulo 2^32
200        for i in 0..SHA256_HASH_WORDS {
201            let mut carry = AB::Expr::ZERO;
202            for j in 0..SHA256_WORD_U16S {
203                let work_var_limb = if i < SHA256_ROUNDS_PER_ROW {
204                    compose::<AB::Expr>(
205                        &local.work_vars.a[SHA256_ROUNDS_PER_ROW - 1 - i][j * 16..(j + 1) * 16],
206                        1,
207                    )
208                } else {
209                    compose::<AB::Expr>(
210                        &local.work_vars.e[SHA256_ROUNDS_PER_ROW + 3 - i][j * 16..(j + 1) * 16],
211                        1,
212                    )
213                };
214                let final_hash_limb =
215                    compose::<AB::Expr>(&next.final_hash[i][j * 2..(j + 1) * 2], 8);
216
217                carry = AB::Expr::from(AB::F::from_canonical_u32(1 << 16).inverse())
218                    * (next.prev_hash[i][j] + work_var_limb + carry - final_hash_limb);
219                builder
220                    .when(next.flags.is_digest_row)
221                    .assert_bool(carry.clone());
222            }
223            // constrain the final hash limbs two at a time since we can do two checks per
224            // interaction
225            for chunk in next.final_hash[i].chunks(2) {
226                self.bitwise_lookup_bus
227                    .send_range(chunk[0], chunk[1])
228                    .eval(builder, next.flags.is_digest_row);
229            }
230        }
231    }
232
233    fn eval_transitions<AB: InteractionBuilder>(&self, builder: &mut AB, start_col: usize) {
234        let main = builder.main();
235        let local = main.row_slice(0);
236        let next = main.row_slice(1);
237
238        // Doesn't matter what column structs we use here
239        let local_cols: &Sha256RoundCols<AB::Var> =
240            local[start_col..start_col + SHA256_ROUND_WIDTH].borrow();
241        let next_cols: &Sha256RoundCols<AB::Var> =
242            next[start_col..start_col + SHA256_ROUND_WIDTH].borrow();
243
244        let local_is_padding_row = local_cols.flags.is_padding_row();
245        // Note that there will always be a padding row in the trace since the unpadded height is a
246        // multiple of 17. So the next row is padding iff the current block is the last
247        // block in the trace.
248        let next_is_padding_row = next_cols.flags.is_padding_row();
249
250        // We check that the very last block has `is_last_block` set to true, which guarantees that
251        // there is at least one complete message. If other digest rows have `is_last_block` set to
252        // true, then the trace will be interpreted as containing multiple messages.
253        builder
254            .when(next_is_padding_row.clone())
255            .when(local_cols.flags.is_digest_row)
256            .assert_one(local_cols.flags.is_last_block);
257        // If we are in a round row, the next row cannot be a padding row
258        builder
259            .when(local_cols.flags.is_round_row)
260            .assert_zero(next_is_padding_row.clone());
261        // The first row must be a round row
262        builder
263            .when_first_row()
264            .assert_one(local_cols.flags.is_round_row);
265        // If we are in a padding row, the next row must also be a padding row
266        builder
267            .when_transition()
268            .when(local_is_padding_row.clone())
269            .assert_one(next_is_padding_row.clone());
270        // If we are in a digest row, the next row cannot be a digest row
271        builder
272            .when(local_cols.flags.is_digest_row)
273            .assert_zero(next_cols.flags.is_digest_row);
274        // Constrain how much the row index changes by
275        // round->round: 1
276        // round->digest: 1
277        // digest->round: -16
278        // digest->padding: 1
279        // padding->padding: 0
280        // Other transitions are not allowed by the above constraints
281        let delta = local_cols.flags.is_round_row * AB::Expr::ONE
282            + local_cols.flags.is_digest_row
283                * next_cols.flags.is_round_row
284                * AB::Expr::from_canonical_u32(16)
285                * AB::Expr::NEG_ONE
286            + local_cols.flags.is_digest_row * next_is_padding_row.clone() * AB::Expr::ONE;
287
288        let local_row_idx = self.row_idx_encoder.flag_with_val::<AB>(
289            &local_cols.flags.row_idx,
290            &(0..18).map(|i| (i, i)).collect::<Vec<_>>(),
291        );
292        let next_row_idx = self.row_idx_encoder.flag_with_val::<AB>(
293            &next_cols.flags.row_idx,
294            &(0..18).map(|i| (i, i)).collect::<Vec<_>>(),
295        );
296
297        builder
298            .when_transition()
299            .assert_eq(local_row_idx.clone() + delta, next_row_idx.clone());
300        builder.when_first_row().assert_zero(local_row_idx);
301
302        // Constrain the global block index
303        // We set the global block index to 0 for padding rows
304        // Starting with 1 so it is not the same as the padding rows
305
306        // Global block index is 1 on first row
307        builder
308            .when_first_row()
309            .assert_one(local_cols.flags.global_block_idx);
310
311        // Global block index is constant on all rows in a block
312        builder.when(local_cols.flags.is_round_row).assert_eq(
313            local_cols.flags.global_block_idx,
314            next_cols.flags.global_block_idx,
315        );
316        // Global block index increases by 1 between blocks
317        builder
318            .when_transition()
319            .when(local_cols.flags.is_digest_row)
320            .when(next_cols.flags.is_round_row)
321            .assert_eq(
322                local_cols.flags.global_block_idx + AB::Expr::ONE,
323                next_cols.flags.global_block_idx,
324            );
325        // Global block index is 0 on padding rows
326        builder
327            .when(local_is_padding_row.clone())
328            .assert_zero(local_cols.flags.global_block_idx);
329
330        // Constrain the local block index
331        // We set the local block index to 0 for padding rows
332
333        // Local block index is constant on all rows in a block
334        // and its value on padding rows is equal to its value on the first block
335        builder.when(not(local_cols.flags.is_digest_row)).assert_eq(
336            local_cols.flags.local_block_idx,
337            next_cols.flags.local_block_idx,
338        );
339        // Local block index increases by 1 between blocks in the same message
340        builder
341            .when(local_cols.flags.is_digest_row)
342            .when(not(local_cols.flags.is_last_block))
343            .assert_eq(
344                local_cols.flags.local_block_idx + AB::Expr::ONE,
345                next_cols.flags.local_block_idx,
346            );
347        // Local block index is 0 on padding rows
348        // Combined with the above, this means that the local block index is 0 in the first block
349        builder
350            .when(local_cols.flags.is_digest_row)
351            .when(local_cols.flags.is_last_block)
352            .assert_zero(next_cols.flags.local_block_idx);
353
354        self.eval_message_schedule::<AB>(builder, local_cols, next_cols);
355        self.eval_work_vars::<AB>(builder, local_cols, next_cols);
356        let next_cols: &Sha256DigestCols<AB::Var> =
357            next[start_col..start_col + SHA256_DIGEST_WIDTH].borrow();
358        self.eval_digest_row(builder, local_cols, next_cols);
359        let local_cols: &Sha256DigestCols<AB::Var> =
360            local[start_col..start_col + SHA256_DIGEST_WIDTH].borrow();
361        self.eval_prev_hash::<AB>(builder, local_cols, next_is_padding_row);
362    }
363
364    /// Constrains that the next block's `prev_hash` is equal to the current block's `hash`
365    /// Note: the constraining is done by interactions with the chip itself on every digest row
366    fn eval_prev_hash<AB: InteractionBuilder>(
367        &self,
368        builder: &mut AB,
369        local: &Sha256DigestCols<AB::Var>,
370        is_last_block_of_trace: AB::Expr, /* note this indicates the last block of the trace,
371                                           * not the last block of the message */
372    ) {
373        // Constrain that next block's `prev_hash` is equal to the current block's `hash`
374        let composed_hash: [[<AB as AirBuilder>::Expr; SHA256_WORD_U16S]; SHA256_HASH_WORDS] =
375            array::from_fn(|i| {
376                let hash_bits = if i < SHA256_ROUNDS_PER_ROW {
377                    local.hash.a[SHA256_ROUNDS_PER_ROW - 1 - i].map(|x| x.into())
378                } else {
379                    local.hash.e[SHA256_ROUNDS_PER_ROW + 3 - i].map(|x| x.into())
380                };
381                array::from_fn(|j| compose::<AB::Expr>(&hash_bits[j * 16..(j + 1) * 16], 1))
382            });
383        // Need to handle the case if this is the very last block of the trace matrix
384        let next_global_block_idx = select(
385            is_last_block_of_trace,
386            AB::Expr::ONE,
387            local.flags.global_block_idx + AB::Expr::ONE,
388        );
389        // The following interactions constrain certain values from block to block
390        self.bus.send(
391            builder,
392            composed_hash
393                .into_iter()
394                .flatten()
395                .chain(once(next_global_block_idx)),
396            local.flags.is_digest_row,
397        );
398
399        self.bus.receive(
400            builder,
401            local
402                .prev_hash
403                .into_iter()
404                .flatten()
405                .map(|x| x.into())
406                .chain(once(local.flags.global_block_idx.into())),
407            local.flags.is_digest_row,
408        );
409    }
410
411    /// Constrain the message schedule additions for `next` row
412    /// Note: For every addition we need to constrain the following for each of [SHA256_WORD_U16S]
413    /// limbs sig_1(w_{t-2})[i] + w_{t-7}[i] + sig_0(w_{t-15})[i] + w_{t-16}[i] +
414    /// carry_w[t][i-1] - carry_w[t][i] * 2^16 - w_t[i] == 0 Refer to [https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf]
415    fn eval_message_schedule<AB: InteractionBuilder>(
416        &self,
417        builder: &mut AB,
418        local: &Sha256RoundCols<AB::Var>,
419        next: &Sha256RoundCols<AB::Var>,
420    ) {
421        // This `w` array contains 8 message schedule words - w_{idx}, ..., w_{idx+7} for some idx
422        let w = [local.message_schedule.w, next.message_schedule.w].concat();
423
424        // Constrain `w_3` for `next` row
425        for i in 0..SHA256_ROUNDS_PER_ROW - 1 {
426            // here we constrain the w_3 of the i_th word of the next row
427            // w_3 of next is w[i+4-3] = w[i+1]
428            let w_3 = w[i + 1].map(|x| x.into());
429            let expected_w_3 = next.schedule_helper.w_3[i];
430            for j in 0..SHA256_WORD_U16S {
431                let w_3_limb = compose::<AB::Expr>(&w_3[j * 16..(j + 1) * 16], 1);
432                builder
433                    .when(local.flags.is_round_row)
434                    .assert_eq(w_3_limb, expected_w_3[j].into());
435            }
436        }
437
438        // Constrain intermed for `next` row
439        // We will only constrain intermed_12 for rows [3, 14], and let it be unconstrained for
440        // other rows Other rows should put the needed value in intermed_12 to make the
441        // below summation constraint hold
442        let is_row_3_14 = self
443            .row_idx_encoder
444            .contains_flag_range::<AB>(&next.flags.row_idx, 3..=14);
445        // We will only constrain intermed_8 for rows [2, 13], and let it unconstrained for other
446        // rows
447        let is_row_2_13 = self
448            .row_idx_encoder
449            .contains_flag_range::<AB>(&next.flags.row_idx, 2..=13);
450        for i in 0..SHA256_ROUNDS_PER_ROW {
451            // w_idx
452            let w_idx = w[i].map(|x| x.into());
453            // sig_0(w_{idx+1})
454            let sig_w = small_sig0_field::<AB::Expr>(&w[i + 1]);
455            for j in 0..SHA256_WORD_U16S {
456                let w_idx_limb = compose::<AB::Expr>(&w_idx[j * 16..(j + 1) * 16], 1);
457                let sig_w_limb = compose::<AB::Expr>(&sig_w[j * 16..(j + 1) * 16], 1);
458
459                // We would like to constrain this only on rows 0..16, but we can't do a conditional
460                // check because the degree is already 3. So we must fill in
461                // `intermed_4` with dummy values on rows 0 and 16 to ensure the constraint holds on
462                // these rows.
463                builder.when_transition().assert_eq(
464                    next.schedule_helper.intermed_4[i][j],
465                    w_idx_limb + sig_w_limb,
466                );
467
468                builder.when(is_row_2_13.clone()).assert_eq(
469                    next.schedule_helper.intermed_8[i][j],
470                    local.schedule_helper.intermed_4[i][j],
471                );
472
473                builder.when(is_row_3_14.clone()).assert_eq(
474                    next.schedule_helper.intermed_12[i][j],
475                    local.schedule_helper.intermed_8[i][j],
476                );
477            }
478        }
479
480        // Constrain the message schedule additions for `next` row
481        for i in 0..SHA256_ROUNDS_PER_ROW {
482            // Note, here by w_{t} we mean the i_th word of the `next` row
483            // w_{t-7}
484            let w_7 = if i < 3 {
485                local.schedule_helper.w_3[i].map(|x| x.into())
486            } else {
487                let w_3 = w[i - 3].map(|x| x.into());
488                array::from_fn(|j| compose::<AB::Expr>(&w_3[j * 16..(j + 1) * 16], 1))
489            };
490            // sig_0(w_{t-15}) + w_{t-16}
491            let intermed_16 = local.schedule_helper.intermed_12[i].map(|x| x.into());
492
493            let carries = array::from_fn(|j| {
494                next.message_schedule.carry_or_buffer[i][j * 2]
495                    + AB::Expr::TWO * next.message_schedule.carry_or_buffer[i][j * 2 + 1]
496            });
497
498            // Constrain `W_{idx} = sig_1(W_{idx-2}) + W_{idx-7} + sig_0(W_{idx-15}) + W_{idx-16}`
499            // We would like to constrain this only on rows 4..16, but we can't do a conditional
500            // check because the degree of sum is already 3 So we must fill in
501            // `intermed_12` with dummy values on rows 0..3 and 15 and 16 to ensure the constraint
502            // holds on rows 0..4 and 16. Note that the dummy value goes in the previous
503            // row to make the current row's constraint hold.
504            constraint_word_addition(
505                // Note: here we can't do a conditional check because the degree of sum is already
506                // 3
507                &mut builder.when_transition(),
508                &[&small_sig1_field::<AB::Expr>(&w[i + 2])],
509                &[&w_7, &intermed_16],
510                &w[i + 4],
511                &carries,
512            );
513
514            for j in 0..SHA256_WORD_U16S {
515                // When on rows 4..16 message schedule carries should be 0 or 1
516                let is_row_4_15 = next.flags.is_round_row - next.flags.is_first_4_rows;
517                builder
518                    .when(is_row_4_15.clone())
519                    .assert_bool(next.message_schedule.carry_or_buffer[i][j * 2]);
520                builder
521                    .when(is_row_4_15)
522                    .assert_bool(next.message_schedule.carry_or_buffer[i][j * 2 + 1]);
523            }
524            // Constrain w being composed of bits
525            for j in 0..SHA256_WORD_BITS {
526                builder
527                    .when(next.flags.is_round_row)
528                    .assert_bool(next.message_schedule.w[i][j]);
529            }
530        }
531    }
532
533    /// Constrain the work vars on `next` row according to the sha256 documentation
534    /// Refer to [https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf]
535    fn eval_work_vars<AB: InteractionBuilder>(
536        &self,
537        builder: &mut AB,
538        local: &Sha256RoundCols<AB::Var>,
539        next: &Sha256RoundCols<AB::Var>,
540    ) {
541        let a = [local.work_vars.a, next.work_vars.a].concat();
542        let e = [local.work_vars.e, next.work_vars.e].concat();
543        for i in 0..SHA256_ROUNDS_PER_ROW {
544            for j in 0..SHA256_WORD_U16S {
545                // Although we need carry_a <= 6 and carry_e <= 5, constraining carry_a, carry_e in
546                // [0, 2^8) is enough to prevent overflow and ensure the soundness
547                // of the addition we want to check
548                self.bitwise_lookup_bus
549                    .send_range(local.work_vars.carry_a[i][j], local.work_vars.carry_e[i][j])
550                    .eval(builder, local.flags.is_round_row);
551            }
552
553            let w_limbs = array::from_fn(|j| {
554                compose::<AB::Expr>(&next.message_schedule.w[i][j * 16..(j + 1) * 16], 1)
555                    * next.flags.is_round_row
556            });
557            let k_limbs = array::from_fn(|j| {
558                self.row_idx_encoder.flag_with_val::<AB>(
559                    &next.flags.row_idx,
560                    &(0..16)
561                        .map(|rw_idx| {
562                            (
563                                rw_idx,
564                                u32_into_limbs::<SHA256_WORD_U16S>(
565                                    SHA256_K[rw_idx * SHA256_ROUNDS_PER_ROW + i],
566                                )[j] as usize,
567                            )
568                        })
569                        .collect::<Vec<_>>(),
570                )
571            });
572
573            // Constrain `a = h + sig_1(e) + ch(e, f, g) + K + W + sig_0(a) + Maj(a, b, c)`
574            // We have to enforce this constraint on all rows since the degree of the constraint is
575            // already 3. So, we must fill in `carry_a` with dummy values on digest rows
576            // to ensure the constraint holds.
577            constraint_word_addition(
578                builder,
579                &[
580                    &e[i].map(|x| x.into()),                // previous `h`
581                    &big_sig1_field::<AB::Expr>(&e[i + 3]), // sig_1 of previous `e`
582                    &ch_field::<AB::Expr>(&e[i + 3], &e[i + 2], &e[i + 1]), /* Ch of previous
583                                                             * `e`, `f`, `g` */
584                    &big_sig0_field::<AB::Expr>(&a[i + 3]), // sig_0 of previous `a`
585                    &maj_field::<AB::Expr>(&a[i + 3], &a[i + 2], &a[i + 1]), /* Maj of previous
586                                                             * a, b, c */
587                ],
588                &[&w_limbs, &k_limbs],      // K and W
589                &a[i + 4],                  // new `a`
590                &next.work_vars.carry_a[i], // carries of addition
591            );
592
593            // Constrain `e = d + h + sig_1(e) + ch(e, f, g) + K + W`
594            // We have to enforce this constraint on all rows since the degree of the constraint is
595            // already 3. So, we must fill in `carry_e` with dummy values on digest rows
596            // to ensure the constraint holds.
597            constraint_word_addition(
598                builder,
599                &[
600                    &a[i].map(|x| x.into()), // previous `d`
601                    &e[i].map(|x| x.into()), // previous `h`
602                    &big_sig1_field::<AB::Expr>(&e[i + 3]), /* sig_1 of previous
603                                              * `e` */
604                    &ch_field::<AB::Expr>(&e[i + 3], &e[i + 2], &e[i + 1]), /* Ch of previous
605                                                                             * `e`, `f`, `g` */
606                ],
607                &[&w_limbs, &k_limbs],      // K and W
608                &e[i + 4],                  // new `e`
609                &next.work_vars.carry_e[i], // carries of addition
610            );
611        }
612    }
613}