openvm_sha256_air/
trace.rs

1use std::{array, borrow::BorrowMut, ops::Range};
2
3use openvm_circuit_primitives::{
4    bitwise_op_lookup::SharedBitwiseOperationLookupChip, utils::next_power_of_two_or_zero,
5};
6use openvm_stark_backend::{
7    p3_air::BaseAir, p3_field::PrimeField32, p3_matrix::dense::RowMajorMatrix,
8    p3_maybe_rayon::prelude::*,
9};
10use sha2::{compress256, digest::generic_array::GenericArray};
11
12use super::{
13    air::Sha256Air, big_sig0_field, big_sig1_field, ch_field, columns::Sha256RoundCols, compose,
14    get_flag_pt_array, maj_field, small_sig0_field, small_sig1_field, SHA256_BLOCK_WORDS,
15    SHA256_DIGEST_WIDTH, SHA256_HASH_WORDS, SHA256_ROUND_WIDTH,
16};
17use crate::{
18    big_sig0, big_sig1, ch, columns::Sha256DigestCols, limbs_into_u32, maj, small_sig0, small_sig1,
19    u32_into_limbs, SHA256_BLOCK_U8S, SHA256_BUFFER_SIZE, SHA256_H, SHA256_INVALID_CARRY_A,
20    SHA256_INVALID_CARRY_E, SHA256_K, SHA256_ROUNDS_PER_ROW, SHA256_ROWS_PER_BLOCK,
21    SHA256_WORD_BITS, SHA256_WORD_U16S, SHA256_WORD_U8S,
22};
23
/// The trace generation of SHA256 should be done in two passes.
/// The first pass should call `generate_block_trace` for every block and generate the invalid
/// rows through `generate_default_row`. The second pass should go through all the blocks and call
/// `generate_missing_cells`.
28impl Sha256Air {
29    /// This function takes the input_message (padding not handled), the previous hash,
30    /// and returns the new hash after processing the block input
31    pub fn get_block_hash(
32        prev_hash: &[u32; SHA256_HASH_WORDS],
33        input: [u8; SHA256_BLOCK_U8S],
34    ) -> [u32; SHA256_HASH_WORDS] {
35        let mut new_hash = *prev_hash;
36        let input_array = [GenericArray::from(input)];
37        compress256(&mut new_hash, &input_array);
38        new_hash
39    }
40
    /// First-pass trace generation for a single SHA256 block.
    ///
    /// This function takes a 512-bit chunk of the input message (padding not handled), the previous
    /// hash, a flag indicating if it's the last block, the global block index, the local block
    /// index, and the buffer values that will be put in rows 0..4.
    /// Will populate the given `trace` with the trace of the block, where the width of the trace is
    /// `trace_width` and the starting column for the `Sha256Air` is `trace_start_col`.
    ///
    /// Parameters:
    /// - `trace`: the cells of this block's rows, row-major; iterated in chunks of `trace_width`.
    /// - `input`: the message words of the block; they seed the message schedule.
    /// - `bitwise_lookup_chip`: receives a `request_range` call for every work-variable carry pair
    ///   and for every pair of final-hash bytes.
    /// - `prev_hash`: hash state before this block; also the initial working variables.
    /// - `buffer_vals`: values copied verbatim into `carry_or_buffer` on the first four rows.
    ///
    /// In debug builds this asserts that `trace` has exactly `SHA256_ROWS_PER_BLOCK` rows, that the
    /// AIR columns fit inside a row, that the lookup chip uses this AIR's bus, and that the first
    /// block of a message (`local_block_idx == 0`) starts from `SHA256_H`.
    ///
    /// **Note**: this function only generates some of the required trace. Another pass is required,
    /// refer to [`Self::generate_missing_cells`] for details.
    #[allow(clippy::too_many_arguments)]
    pub fn generate_block_trace<F: PrimeField32>(
        &self,
        trace: &mut [F],
        trace_width: usize,
        trace_start_col: usize,
        input: &[u32; SHA256_BLOCK_WORDS],
        bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>,
        prev_hash: &[u32; SHA256_HASH_WORDS],
        is_last_block: bool,
        global_block_idx: u32,
        local_block_idx: u32,
        buffer_vals: &[[F; SHA256_BUFFER_SIZE]; 4],
    ) {
        #[cfg(debug_assertions)]
        {
            assert!(trace.len() == trace_width * SHA256_ROWS_PER_BLOCK);
            assert!(trace_start_col + super::SHA256_WIDTH <= trace_width);
            assert!(self.bitwise_lookup_bus == bitwise_lookup_chip.bus());
            if local_block_idx == 0 {
                assert!(*prev_hash == SHA256_H);
            }
        }
        // Helper: the column range `start..start + len` inside a row.
        let get_range = |start: usize, len: usize| -> Range<usize> { start..start + len };
        // Full 64-word message schedule; the leading words are the block input, the rest are
        // filled in as the round rows are generated.
        let mut message_schedule = [0u32; 64];
        message_schedule[..input.len()].copy_from_slice(input);
        // Working variables a..h, updated after every round.
        let mut work_vars = *prev_hash;
        for (i, row) in trace.chunks_exact_mut(trace_width).enumerate() {
            // doing the 64 rounds in 16 rows
            if i < 16 {
                let cols: &mut Sha256RoundCols<F> =
                    row[get_range(trace_start_col, SHA256_ROUND_WIDTH)].borrow_mut();
                cols.flags.is_round_row = F::ONE;
                cols.flags.is_first_4_rows = if i < 4 { F::ONE } else { F::ZERO };
                cols.flags.is_digest_row = F::ZERO;
                cols.flags.is_last_block = F::from_bool(is_last_block);
                cols.flags.row_idx =
                    get_flag_pt_array(&self.row_idx_encoder, i).map(F::from_canonical_u32);
                cols.flags.global_block_idx = F::from_canonical_u32(global_block_idx);
                cols.flags.local_block_idx = F::from_canonical_u32(local_block_idx);

                // W_idx = M_idx
                // NOTE(review): this bound relies on integer division of the two constants
                // (presumably 17 / 4 == 4, i.e. the rows that hold the input words) — confirm.
                if i < SHA256_ROWS_PER_BLOCK / SHA256_ROUNDS_PER_ROW {
                    for j in 0..SHA256_ROUNDS_PER_ROW {
                        cols.message_schedule.w[j] = u32_into_limbs::<SHA256_WORD_BITS>(
                            input[i * SHA256_ROUNDS_PER_ROW + j],
                        )
                        .map(F::from_canonical_u32);
                        // No carries are needed on these rows, so the cells carry the
                        // caller-provided buffer values instead.
                        cols.message_schedule.carry_or_buffer[j] =
                            array::from_fn(|k| buffer_vals[i][j * SHA256_WORD_U16S * 2 + k]);
                    }
                }
                // W_idx = SIG1(W_{idx-2}) + W_{idx-7} + SIG0(W_{idx-15}) + W_{idx-16}
                else {
                    for j in 0..SHA256_ROUNDS_PER_ROW {
                        let idx = i * SHA256_ROUNDS_PER_ROW + j;
                        let nums: [u32; 4] = [
                            small_sig1(message_schedule[idx - 2]),
                            message_schedule[idx - 7],
                            small_sig0(message_schedule[idx - 15]),
                            message_schedule[idx - 16],
                        ];
                        let w: u32 = nums.iter().fold(0, |acc, &num| acc.wrapping_add(num));
                        cols.message_schedule.w[j] =
                            u32_into_limbs::<SHA256_WORD_BITS>(w).map(F::from_canonical_u32);

                        // The same four summands and the result, split into limbs, so the
                        // per-limb carries of the addition can be recovered.
                        let nums_limbs = nums
                            .iter()
                            .map(|x| u32_into_limbs::<SHA256_WORD_U16S>(*x))
                            .collect::<Vec<_>>();
                        let w_limbs = u32_into_limbs::<SHA256_WORD_U16S>(w);

                        // fill in the carrys
                        for k in 0..SHA256_WORD_U16S {
                            let mut sum = nums_limbs.iter().fold(0, |acc, num| acc + num[k]);
                            if k > 0 {
                                // Add the carry out of the previous limb, which is stored as two
                                // cells: low bit at `2k - 2`, remaining bits at `2k - 1`.
                                sum += (cols.message_schedule.carry_or_buffer[j][k * 2 - 2]
                                    + F::TWO * cols.message_schedule.carry_or_buffer[j][k * 2 - 1])
                                    .as_canonical_u32();
                            }
                            let carry = (sum - w_limbs[k]) >> 16;
                            // Store the carry split into its low bit and the rest.
                            cols.message_schedule.carry_or_buffer[j][k * 2] =
                                F::from_canonical_u32(carry & 1);
                            cols.message_schedule.carry_or_buffer[j][k * 2 + 1] =
                                F::from_canonical_u32(carry >> 1);
                        }
                        // update the message schedule
                        message_schedule[idx] = w;
                    }
                }
                // fill in the work variables
                for j in 0..SHA256_ROUNDS_PER_ROW {
                    // t1 = h + SIG1(e) + ch(e, f, g) + K_idx + W_idx
                    let t1 = [
                        work_vars[7],
                        big_sig1(work_vars[4]),
                        ch(work_vars[4], work_vars[5], work_vars[6]),
                        SHA256_K[i * SHA256_ROUNDS_PER_ROW + j],
                        limbs_into_u32(cols.message_schedule.w[j].map(|f| f.as_canonical_u32())),
                    ];
                    let t1_sum: u32 = t1.iter().fold(0, |acc, &num| acc.wrapping_add(num));

                    // t2 = SIG0(a) + maj(a, b, c)
                    let t2 = [
                        big_sig0(work_vars[0]),
                        maj(work_vars[0], work_vars[1], work_vars[2]),
                    ];

                    let t2_sum: u32 = t2.iter().fold(0, |acc, &num| acc.wrapping_add(num));

                    // e = d + t1
                    let e = work_vars[3].wrapping_add(t1_sum);
                    cols.work_vars.e[j] =
                        u32_into_limbs::<SHA256_WORD_BITS>(e).map(F::from_canonical_u32);
                    let e_limbs = u32_into_limbs::<SHA256_WORD_U16S>(e);
                    // a = t1 + t2
                    let a = t1_sum.wrapping_add(t2_sum);
                    cols.work_vars.a[j] =
                        u32_into_limbs::<SHA256_WORD_BITS>(a).map(F::from_canonical_u32);
                    let a_limbs = u32_into_limbs::<SHA256_WORD_U16S>(a);
                    // fill in the carrys
                    for k in 0..SHA256_WORD_U16S {
                        // Limb `k` of every summand, added without wrapping.
                        let t1_limb = t1.iter().fold(0, |acc, &num| {
                            acc + u32_into_limbs::<SHA256_WORD_U16S>(num)[k]
                        });
                        let t2_limb = t2.iter().fold(0, |acc, &num| {
                            acc + u32_into_limbs::<SHA256_WORD_U16S>(num)[k]
                        });

                        let mut e_limb =
                            t1_limb + u32_into_limbs::<SHA256_WORD_U16S>(work_vars[3])[k];
                        let mut a_limb = t1_limb + t2_limb;
                        if k > 0 {
                            // Propagate the carry out of the previous limb.
                            a_limb += cols.work_vars.carry_a[j][k - 1].as_canonical_u32();
                            e_limb += cols.work_vars.carry_e[j][k - 1].as_canonical_u32();
                        }
                        let carry_a = (a_limb - a_limbs[k]) >> 16;
                        let carry_e = (e_limb - e_limbs[k]) >> 16;
                        cols.work_vars.carry_a[j][k] = F::from_canonical_u32(carry_a);
                        cols.work_vars.carry_e[j][k] = F::from_canonical_u32(carry_e);
                        // Register both carries with the lookup chip (presumably a range check).
                        bitwise_lookup_chip.request_range(carry_a, carry_e);
                    }

                    // update working variables
                    work_vars[7] = work_vars[6];
                    work_vars[6] = work_vars[5];
                    work_vars[5] = work_vars[4];
                    work_vars[4] = e;
                    work_vars[3] = work_vars[2];
                    work_vars[2] = work_vars[1];
                    work_vars[1] = work_vars[0];
                    work_vars[0] = a;
                }

                // filling w_3 and intermed_4 here and the rest later
                // (row 0 is skipped: `idx - 4` and `idx - 3` would fall before the block start)
                if i > 0 {
                    for j in 0..SHA256_ROUNDS_PER_ROW {
                        let idx = i * SHA256_ROUNDS_PER_ROW + j;
                        let w_4 = u32_into_limbs::<SHA256_WORD_U16S>(message_schedule[idx - 4]);
                        let sig_0_w_3 = u32_into_limbs::<SHA256_WORD_U16S>(small_sig0(
                            message_schedule[idx - 3],
                        ));
                        // intermed_4 = W_{idx-4} + SIG0(W_{idx-3}), limb-wise and without
                        // carry propagation.
                        cols.schedule_helper.intermed_4[j] =
                            array::from_fn(|k| F::from_canonical_u32(w_4[k] + sig_0_w_3[k]));
                        // `w_3` has one entry fewer than there are rounds per row.
                        if j < SHA256_ROUNDS_PER_ROW - 1 {
                            let w_3 = message_schedule[idx - 3];
                            cols.schedule_helper.w_3[j] =
                                u32_into_limbs::<SHA256_WORD_U16S>(w_3).map(F::from_canonical_u32);
                        }
                    }
                }
            }
            // generate the digest row
            else {
                let cols: &mut Sha256DigestCols<F> =
                    row[get_range(trace_start_col, SHA256_DIGEST_WIDTH)].borrow_mut();
                // The digest row still carries the `w_3` helper values, computed with the same
                // index formula as on round rows.
                for j in 0..SHA256_ROUNDS_PER_ROW - 1 {
                    let w_3 = message_schedule[i * SHA256_ROUNDS_PER_ROW + j - 3];
                    cols.schedule_helper.w_3[j] =
                        u32_into_limbs::<SHA256_WORD_U16S>(w_3).map(F::from_canonical_u32);
                }
                cols.flags.is_round_row = F::ZERO;
                cols.flags.is_first_4_rows = F::ZERO;
                cols.flags.is_digest_row = F::ONE;
                cols.flags.is_last_block = F::from_bool(is_last_block);
                cols.flags.row_idx =
                    get_flag_pt_array(&self.row_idx_encoder, 16).map(F::from_canonical_u32);
                cols.flags.global_block_idx = F::from_canonical_u32(global_block_idx);

                cols.flags.local_block_idx = F::from_canonical_u32(local_block_idx);
                // final_hash[i] = work_vars[i] + prev_hash[i] (mod 2^32)
                let final_hash: [u32; SHA256_HASH_WORDS] =
                    array::from_fn(|i| work_vars[i].wrapping_add(prev_hash[i]));
                let final_hash_limbs: [[u32; SHA256_WORD_U8S]; SHA256_HASH_WORDS] =
                    array::from_fn(|i| u32_into_limbs::<SHA256_WORD_U8S>(final_hash[i]));
                // need to ensure final hash limbs are bytes, in order for
                //   prev_hash[i] + work_vars[i] == final_hash[i]
                // to be constrained correctly
                for word in final_hash_limbs.iter() {
                    for chunk in word.chunks(2) {
                        bitwise_lookup_chip.request_range(chunk[0], chunk[1]);
                    }
                }
                cols.final_hash = array::from_fn(|i| {
                    array::from_fn(|j| F::from_canonical_u32(final_hash_limbs[i][j]))
                });
                cols.prev_hash = prev_hash
                    .map(|f| u32_into_limbs::<SHA256_WORD_U16S>(f).map(F::from_canonical_u32));
                // Bit decomposition written to the `hash` columns: the initial state `SHA256_H`
                // after a message's last block, this block's final hash otherwise (presumably the
                // state the following block starts from — confirm against the AIR).
                let hash = if is_last_block {
                    SHA256_H.map(u32_into_limbs::<SHA256_WORD_BITS>)
                } else {
                    cols.final_hash
                        .map(|f| limbs_into_u32(f.map(|x| x.as_canonical_u32())))
                        .map(u32_into_limbs::<SHA256_WORD_BITS>)
                }
                .map(|x| x.map(F::from_canonical_u32));

                // Stored in reverse word order: `a` takes the first words of `hash`, `e` the
                // words four positions later.
                for i in 0..SHA256_ROUNDS_PER_ROW {
                    cols.hash.a[i] = hash[SHA256_ROUNDS_PER_ROW - i - 1];
                    cols.hash.e[i] = hash[SHA256_ROUNDS_PER_ROW - i + 3];
                }
            }
        }

        // Pass over adjacent row pairs to fill the cells that depend on a neighbouring row.
        // Both rows are viewed through the round-column layout, including the digest row.
        for i in 0..SHA256_ROWS_PER_BLOCK - 1 {
            let rows = &mut trace[i * trace_width..(i + 2) * trace_width];
            let (local, next) = rows.split_at_mut(trace_width);
            let local_cols: &mut Sha256RoundCols<F> =
                local[get_range(trace_start_col, SHA256_ROUND_WIDTH)].borrow_mut();
            let next_cols: &mut Sha256RoundCols<F> =
                next[get_range(trace_start_col, SHA256_ROUND_WIDTH)].borrow_mut();
            if i > 0 {
                // Shift the helper values down one row: intermed_4 -> intermed_8 -> intermed_12.
                for j in 0..SHA256_ROUNDS_PER_ROW {
                    next_cols.schedule_helper.intermed_8[j] =
                        local_cols.schedule_helper.intermed_4[j];
                    if (2..SHA256_ROWS_PER_BLOCK - 3).contains(&i) {
                        next_cols.schedule_helper.intermed_12[j] =
                            local_cols.schedule_helper.intermed_8[j];
                    }
                }
            }
            if i == SHA256_ROWS_PER_BLOCK - 2 {
                // `next` is a digest row.
                // Fill in `carry_a` and `carry_e` with dummy values so the constraints on `a` and
                // `e` hold.
                Self::generate_carry_ae(local_cols, next_cols);
                // Fill in row 16's `intermed_4` with dummy values so the message schedule
                // constraints holds on that row
                Self::generate_intermed_4(local_cols, next_cols);
            }
            if i <= 2 {
                // i is in 0..3.
                // Fill in `local.intermed_12` with dummy values so the message schedule constraints
                // hold on rows 1..4.
                Self::generate_intermed_12(local_cols, next_cols);
            }
        }
    }
305
306    /// This function will fill in the cells that we couldn't do during the first pass.
307    /// This function should be called only after `generate_block_trace` was called for all blocks
308    /// And [`Self::generate_default_row`] is called for all invalid rows
309    /// Will populate the missing values of `trace`, where the width of the trace is `trace_width`
310    /// and the starting column for the `Sha256Air` is `trace_start_col`.
311    /// Note: `trace` needs to be the rows 1..17 of a block and the first row of the next block
312    pub fn generate_missing_cells<F: PrimeField32>(
313        &self,
314        trace: &mut [F],
315        trace_width: usize,
316        trace_start_col: usize,
317    ) {
318        // Here row_17 = next blocks row 0
319        let rows_15_17 = &mut trace[14 * trace_width..17 * trace_width];
320        let (row_15, row_16_17) = rows_15_17.split_at_mut(trace_width);
321        let (row_16, row_17) = row_16_17.split_at_mut(trace_width);
322        let cols_15: &mut Sha256RoundCols<F> =
323            row_15[trace_start_col..trace_start_col + SHA256_ROUND_WIDTH].borrow_mut();
324        let cols_16: &mut Sha256RoundCols<F> =
325            row_16[trace_start_col..trace_start_col + SHA256_ROUND_WIDTH].borrow_mut();
326        let cols_17: &mut Sha256RoundCols<F> =
327            row_17[trace_start_col..trace_start_col + SHA256_ROUND_WIDTH].borrow_mut();
328        // Fill in row 15's `intermed_12` with dummy values so the message schedule constraints
329        // holds on row 16
330        Self::generate_intermed_12(cols_15, cols_16);
331        // Fill in row 16's `intermed_12` with dummy values so the message schedule constraints
332        // holds on the next block's row 0
333        Self::generate_intermed_12(cols_16, cols_17);
334        // Fill in row 0's `intermed_4` with dummy values so the message schedule constraints holds
335        // on that row
336        Self::generate_intermed_4(cols_16, cols_17);
337    }
338
339    /// Fills the `cols` as a padding row
340    /// Note: we still need to correctly fill in the hash values, carries and intermeds
341    pub fn generate_default_row<F: PrimeField32>(self: &Sha256Air, cols: &mut Sha256RoundCols<F>) {
342        cols.flags.is_round_row = F::ZERO;
343        cols.flags.is_first_4_rows = F::ZERO;
344        cols.flags.is_digest_row = F::ZERO;
345
346        cols.flags.is_last_block = F::ZERO;
347        cols.flags.global_block_idx = F::ZERO;
348        cols.flags.row_idx =
349            get_flag_pt_array(&self.row_idx_encoder, 17).map(F::from_canonical_u32);
350        cols.flags.local_block_idx = F::ZERO;
351
352        cols.message_schedule.w = [[F::ZERO; SHA256_WORD_BITS]; SHA256_ROUNDS_PER_ROW];
353        cols.message_schedule.carry_or_buffer =
354            [[F::ZERO; SHA256_WORD_U16S * 2]; SHA256_ROUNDS_PER_ROW];
355
356        let hash = SHA256_H
357            .map(u32_into_limbs::<SHA256_WORD_BITS>)
358            .map(|x| x.map(F::from_canonical_u32));
359
360        for i in 0..SHA256_ROUNDS_PER_ROW {
361            cols.work_vars.a[i] = hash[SHA256_ROUNDS_PER_ROW - i - 1];
362            cols.work_vars.e[i] = hash[SHA256_ROUNDS_PER_ROW - i + 3];
363        }
364
365        cols.work_vars.carry_a = array::from_fn(|i| {
366            array::from_fn(|j| F::from_canonical_u32(SHA256_INVALID_CARRY_A[i][j]))
367        });
368        cols.work_vars.carry_e = array::from_fn(|i| {
369            array::from_fn(|j| F::from_canonical_u32(SHA256_INVALID_CARRY_E[i][j]))
370        });
371    }
372
373    /// The following functions do the calculations in native field since they will be called on
374    /// padding rows which can overflow and we need to make sure it matches the AIR constraints
375    /// Puts the correct carrys in the `next_row`, the resulting carrys can be out of bound
376    fn generate_carry_ae<F: PrimeField32>(
377        local_cols: &Sha256RoundCols<F>,
378        next_cols: &mut Sha256RoundCols<F>,
379    ) {
380        let a = [local_cols.work_vars.a, next_cols.work_vars.a].concat();
381        let e = [local_cols.work_vars.e, next_cols.work_vars.e].concat();
382        for i in 0..SHA256_ROUNDS_PER_ROW {
383            let cur_a = a[i + 4];
384            let sig_a = big_sig0_field::<F>(&a[i + 3]);
385            let maj_abc = maj_field::<F>(&a[i + 3], &a[i + 2], &a[i + 1]);
386            let d = a[i];
387            let cur_e = e[i + 4];
388            let sig_e = big_sig1_field::<F>(&e[i + 3]);
389            let ch_efg = ch_field::<F>(&e[i + 3], &e[i + 2], &e[i + 1]);
390            let h = e[i];
391
392            let t1 = [h, sig_e, ch_efg];
393            let t2 = [sig_a, maj_abc];
394            for j in 0..SHA256_WORD_U16S {
395                let t1_limb_sum = t1.iter().fold(F::ZERO, |acc, x| {
396                    acc + compose::<F>(&x[j * 16..(j + 1) * 16], 1)
397                });
398                let t2_limb_sum = t2.iter().fold(F::ZERO, |acc, x| {
399                    acc + compose::<F>(&x[j * 16..(j + 1) * 16], 1)
400                });
401                let d_limb = compose::<F>(&d[j * 16..(j + 1) * 16], 1);
402                let cur_a_limb = compose::<F>(&cur_a[j * 16..(j + 1) * 16], 1);
403                let cur_e_limb = compose::<F>(&cur_e[j * 16..(j + 1) * 16], 1);
404                let sum = d_limb
405                    + t1_limb_sum
406                    + if j == 0 {
407                        F::ZERO
408                    } else {
409                        next_cols.work_vars.carry_e[i][j - 1]
410                    }
411                    - cur_e_limb;
412                let carry_e = sum * (F::from_canonical_u32(1 << 16).inverse());
413
414                let sum = t1_limb_sum
415                    + t2_limb_sum
416                    + if j == 0 {
417                        F::ZERO
418                    } else {
419                        next_cols.work_vars.carry_a[i][j - 1]
420                    }
421                    - cur_a_limb;
422                let carry_a = sum * (F::from_canonical_u32(1 << 16).inverse());
423                next_cols.work_vars.carry_e[i][j] = carry_e;
424                next_cols.work_vars.carry_a[i][j] = carry_a;
425            }
426        }
427    }
428
429    /// Puts the correct intermed_4 in the `next_row`
430    fn generate_intermed_4<F: PrimeField32>(
431        local_cols: &Sha256RoundCols<F>,
432        next_cols: &mut Sha256RoundCols<F>,
433    ) {
434        let w = [local_cols.message_schedule.w, next_cols.message_schedule.w].concat();
435        let w_limbs: Vec<[F; SHA256_WORD_U16S]> = w
436            .iter()
437            .map(|x| array::from_fn(|i| compose::<F>(&x[i * 16..(i + 1) * 16], 1)))
438            .collect();
439        for i in 0..SHA256_ROUNDS_PER_ROW {
440            let sig_w = small_sig0_field::<F>(&w[i + 1]);
441            let sig_w_limbs: [F; SHA256_WORD_U16S] =
442                array::from_fn(|j| compose::<F>(&sig_w[j * 16..(j + 1) * 16], 1));
443            for (j, sig_w_limb) in sig_w_limbs.iter().enumerate() {
444                next_cols.schedule_helper.intermed_4[i][j] = w_limbs[i][j] + *sig_w_limb;
445            }
446        }
447    }
448
449    /// Puts the needed intermed_12 in the `local_row`
450    fn generate_intermed_12<F: PrimeField32>(
451        local_cols: &mut Sha256RoundCols<F>,
452        next_cols: &Sha256RoundCols<F>,
453    ) {
454        let w = [local_cols.message_schedule.w, next_cols.message_schedule.w].concat();
455        let w_limbs: Vec<[F; SHA256_WORD_U16S]> = w
456            .iter()
457            .map(|x| array::from_fn(|i| compose::<F>(&x[i * 16..(i + 1) * 16], 1)))
458            .collect();
459        for i in 0..SHA256_ROUNDS_PER_ROW {
460            // sig_1(w_{t-2})
461            let sig_w_2: [F; SHA256_WORD_U16S] = array::from_fn(|j| {
462                compose::<F>(&small_sig1_field::<F>(&w[i + 2])[j * 16..(j + 1) * 16], 1)
463            });
464            // w_{t-7}
465            let w_7 = if i < 3 {
466                local_cols.schedule_helper.w_3[i]
467            } else {
468                w_limbs[i - 3]
469            };
470            // w_t
471            let w_cur = w_limbs[i + 4];
472            for j in 0..SHA256_WORD_U16S {
473                let carry = next_cols.message_schedule.carry_or_buffer[i][j * 2]
474                    + F::TWO * next_cols.message_schedule.carry_or_buffer[i][j * 2 + 1];
475                let sum = sig_w_2[j] + w_7[j] - carry * F::from_canonical_u32(1 << 16) - w_cur[j]
476                    + if j > 0 {
477                        next_cols.message_schedule.carry_or_buffer[i][j * 2 - 2]
478                            + F::from_canonical_u32(2)
479                                * next_cols.message_schedule.carry_or_buffer[i][j * 2 - 1]
480                    } else {
481                        F::ZERO
482                    };
483                local_cols.schedule_helper.intermed_12[i][j] = -sum;
484            }
485        }
486    }
487}
488
489/// `records` consists of pairs of `(input_block, is_last_block)`.
490pub fn generate_trace<F: PrimeField32>(
491    sub_air: &Sha256Air,
492    bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>,
493    records: Vec<([u8; SHA256_BLOCK_U8S], bool)>,
494) -> RowMajorMatrix<F> {
495    let non_padded_height = records.len() * SHA256_ROWS_PER_BLOCK;
496    let height = next_power_of_two_or_zero(non_padded_height);
497    let width = <Sha256Air as BaseAir<F>>::width(sub_air);
498    let mut values = F::zero_vec(height * width);
499
500    struct BlockContext {
501        prev_hash: [u32; 8],
502        local_block_idx: u32,
503        global_block_idx: u32,
504        input: [u8; SHA256_BLOCK_U8S],
505        is_last_block: bool,
506    }
507    let mut block_ctx: Vec<BlockContext> = Vec::with_capacity(records.len());
508    let mut prev_hash = SHA256_H;
509    let mut local_block_idx = 0;
510    let mut global_block_idx = 1;
511    for (input, is_last_block) in records {
512        block_ctx.push(BlockContext {
513            prev_hash,
514            local_block_idx,
515            global_block_idx,
516            input,
517            is_last_block,
518        });
519        global_block_idx += 1;
520        if is_last_block {
521            local_block_idx = 0;
522            prev_hash = SHA256_H;
523        } else {
524            local_block_idx += 1;
525            prev_hash = Sha256Air::get_block_hash(&prev_hash, input);
526        }
527    }
528    // first pass
529    values
530        .par_chunks_exact_mut(width * SHA256_ROWS_PER_BLOCK)
531        .zip(block_ctx)
532        .for_each(|(block, ctx)| {
533            let BlockContext {
534                prev_hash,
535                local_block_idx,
536                global_block_idx,
537                input,
538                is_last_block,
539            } = ctx;
540            let input_words = array::from_fn(|i| {
541                limbs_into_u32::<SHA256_WORD_U8S>(array::from_fn(|j| {
542                    input[(i + 1) * SHA256_WORD_U8S - j - 1] as u32
543                }))
544            });
545            sub_air.generate_block_trace(
546                block,
547                width,
548                0,
549                &input_words,
550                bitwise_lookup_chip.clone(),
551                &prev_hash,
552                is_last_block,
553                global_block_idx,
554                local_block_idx,
555                &[[F::ZERO; 16]; 4],
556            );
557        });
558    // second pass: padding rows
559    values[width * non_padded_height..]
560        .par_chunks_mut(width)
561        .for_each(|row| {
562            let cols: &mut Sha256RoundCols<F> = row.borrow_mut();
563            sub_air.generate_default_row(cols);
564        });
565    // second pass: non-padding rows
566    values[width..]
567        .par_chunks_mut(width * SHA256_ROWS_PER_BLOCK)
568        .take(non_padded_height / SHA256_ROWS_PER_BLOCK)
569        .for_each(|chunk| {
570            sub_air.generate_missing_cells(chunk, width, 0);
571        });
572    RowMajorMatrix::new(values, width)
573}