openvm_native_circuit/poseidon2/trace.rs

1use std::{borrow::BorrowMut, sync::Arc};
2
3use openvm_circuit::system::memory::{MemoryAuxColsFactory, OfflineMemory};
4use openvm_circuit_primitives::utils::next_power_of_two_or_zero;
5use openvm_instructions::{instruction::Instruction, LocalOpcode};
6use openvm_native_compiler::Poseidon2Opcode::COMP_POS2;
7use openvm_stark_backend::{
8    config::{StarkGenericConfig, Val},
9    p3_air::BaseAir,
10    p3_field::{Field, PrimeField32},
11    p3_matrix::dense::RowMajorMatrix,
12    p3_maybe_rayon::prelude::*,
13    prover::types::AirProofInput,
14    AirRef, Chip, ChipUsageGetter,
15};
16
17use crate::{
18    chip::{SimplePoseidonRecord, NUM_INITIAL_READS},
19    poseidon2::{
20        chip::{
21            CellRecord, IncorporateRowRecord, IncorporateSiblingRecord, InsideRowRecord,
22            NativePoseidon2Chip, VerifyBatchRecord,
23        },
24        columns::{
25            InsideRowSpecificCols, NativePoseidon2Cols, SimplePoseidonSpecificCols,
26            TopLevelSpecificCols,
27        },
28        CHUNK,
29    },
30};
31impl<F: Field, const SBOX_REGISTERS: usize> ChipUsageGetter
32    for NativePoseidon2Chip<F, SBOX_REGISTERS>
33{
34    fn air_name(&self) -> String {
35        "VerifyBatchAir".to_string()
36    }
37
38    fn current_trace_height(&self) -> usize {
39        self.height
40    }
41
42    fn trace_width(&self) -> usize {
43        NativePoseidon2Cols::<F, SBOX_REGISTERS>::width()
44    }
45}
46
47impl<F: PrimeField32, const SBOX_REGISTERS: usize> NativePoseidon2Chip<F, SBOX_REGISTERS> {
48    fn generate_subair_cols(&self, input: [F; 2 * CHUNK], cols: &mut [F]) {
49        let inner_trace = self.subchip.generate_trace(vec![input]);
50        let inner_width = self.air.subair.width();
51        cols[..inner_width].copy_from_slice(inner_trace.values.as_slice());
52    }
53    #[allow(clippy::too_many_arguments)]
54    fn incorporate_sibling_record_to_row(
55        &self,
56        record: &IncorporateSiblingRecord<F>,
57        aux_cols_factory: &MemoryAuxColsFactory<F>,
58        slice: &mut [F],
59        memory: &OfflineMemory<F>,
60        parent: &VerifyBatchRecord<F>,
61        proof_index: usize,
62        opened_index: usize,
63        log_height: usize,
64    ) {
65        let &IncorporateSiblingRecord {
66            read_sibling_is_on_right,
67            sibling_is_on_right,
68            p2_input,
69        } = record;
70
71        let read_sibling_is_on_right = memory.record_by_id(read_sibling_is_on_right);
72
73        self.generate_subair_cols(p2_input, slice);
74        let cols: &mut NativePoseidon2Cols<F, SBOX_REGISTERS> = slice.borrow_mut();
75        cols.incorporate_row = F::ZERO;
76        cols.incorporate_sibling = F::ONE;
77        cols.inside_row = F::ZERO;
78        cols.simple = F::ZERO;
79        cols.end_inside_row = F::ZERO;
80        cols.end_top_level = F::ZERO;
81        cols.start_top_level = F::ZERO;
82        cols.opened_element_size_inv = parent.opened_element_size_inv();
83        cols.very_first_timestamp = F::from_canonical_u32(parent.from_state.timestamp);
84        cols.start_timestamp =
85            F::from_canonical_u32(read_sibling_is_on_right.timestamp - NUM_INITIAL_READS as u32);
86
87        let specific: &mut TopLevelSpecificCols<F> =
88            cols.specific[..TopLevelSpecificCols::<F>::width()].borrow_mut();
89
90        specific.end_timestamp =
91            F::from_canonical_usize(read_sibling_is_on_right.timestamp as usize + 1);
92        cols.initial_opened_index = F::from_canonical_usize(opened_index);
93        specific.final_opened_index = F::from_canonical_usize(opened_index - 1);
94        specific.log_height = F::from_canonical_usize(log_height);
95        specific.opened_length = F::from_canonical_usize(parent.opened_length);
96        specific.dim_base_pointer = parent.dim_base_pointer;
97        cols.opened_base_pointer = parent.opened_base_pointer;
98        specific.index_base_pointer = parent.index_base_pointer;
99
100        specific.proof_index = F::from_canonical_usize(proof_index);
101        aux_cols_factory.generate_read_aux(
102            read_sibling_is_on_right,
103            &mut specific.read_initial_height_or_sibling_is_on_right,
104        );
105        specific.sibling_is_on_right = F::from_bool(sibling_is_on_right);
106    }
107    fn correct_last_top_level_row(
108        &self,
109        record: &VerifyBatchRecord<F>,
110        aux_cols_factory: &MemoryAuxColsFactory<F>,
111        slice: &mut [F],
112        memory: &OfflineMemory<F>,
113    ) {
114        let &VerifyBatchRecord {
115            from_state,
116            commit_pointer,
117            dim_base_pointer_read,
118            opened_base_pointer_read,
119            opened_length_read,
120            index_base_pointer_read,
121            commit_pointer_read,
122            commit_read,
123            ..
124        } = record;
125        let instruction = &record.instruction;
126        let cols: &mut NativePoseidon2Cols<F, SBOX_REGISTERS> = slice.borrow_mut();
127        cols.end_top_level = F::ONE;
128
129        let specific: &mut TopLevelSpecificCols<F> =
130            cols.specific[..TopLevelSpecificCols::<F>::width()].borrow_mut();
131
132        specific.pc = F::from_canonical_u32(from_state.pc);
133        specific.dim_register = instruction.a;
134        specific.opened_register = instruction.b;
135        specific.opened_length_register = instruction.c;
136        specific.proof_id = instruction.d;
137        specific.index_register = instruction.e;
138        specific.commit_register = instruction.f;
139        specific.commit_pointer = commit_pointer;
140        aux_cols_factory.generate_read_aux(
141            memory.record_by_id(dim_base_pointer_read),
142            &mut specific.dim_base_pointer_read,
143        );
144        aux_cols_factory.generate_read_aux(
145            memory.record_by_id(opened_base_pointer_read),
146            &mut specific.opened_base_pointer_read,
147        );
148        aux_cols_factory.generate_read_aux(
149            memory.record_by_id(opened_length_read),
150            &mut specific.opened_length_read,
151        );
152        aux_cols_factory.generate_read_aux(
153            memory.record_by_id(index_base_pointer_read),
154            &mut specific.index_base_pointer_read,
155        );
156        aux_cols_factory.generate_read_aux(
157            memory.record_by_id(commit_pointer_read),
158            &mut specific.commit_pointer_read,
159        );
160        aux_cols_factory
161            .generate_read_aux(memory.record_by_id(commit_read), &mut specific.commit_read);
162    }
163    #[allow(clippy::too_many_arguments)]
164    fn incorporate_row_record_to_row(
165        &self,
166        record: &IncorporateRowRecord<F>,
167        aux_cols_factory: &MemoryAuxColsFactory<F>,
168        slice: &mut [F],
169        memory: &OfflineMemory<F>,
170        parent: &VerifyBatchRecord<F>,
171        proof_index: usize,
172        log_height: usize,
173    ) {
174        let &IncorporateRowRecord {
175            initial_opened_index,
176            final_opened_index,
177            initial_height_read,
178            final_height_read,
179            p2_input,
180            ..
181        } = record;
182
183        let initial_height_read = memory.record_by_id(initial_height_read);
184        let final_height_read = memory.record_by_id(final_height_read);
185
186        self.generate_subair_cols(p2_input, slice);
187        let cols: &mut NativePoseidon2Cols<F, SBOX_REGISTERS> = slice.borrow_mut();
188        cols.incorporate_row = F::ONE;
189        cols.incorporate_sibling = F::ZERO;
190        cols.inside_row = F::ZERO;
191        cols.simple = F::ZERO;
192        cols.end_inside_row = F::ZERO;
193        cols.end_top_level = F::ZERO;
194        cols.start_top_level = F::from_bool(proof_index == 0);
195        cols.opened_element_size_inv = parent.opened_element_size_inv();
196        cols.very_first_timestamp = F::from_canonical_u32(parent.from_state.timestamp);
197        cols.start_timestamp = F::from_canonical_u32(
198            memory
199                .record_by_id(
200                    record.chunks[0].cells[0]
201                        .read_row_pointer_and_length
202                        .unwrap(),
203                )
204                .timestamp
205                - NUM_INITIAL_READS as u32,
206        );
207        let specific: &mut TopLevelSpecificCols<F> =
208            cols.specific[..TopLevelSpecificCols::<F>::width()].borrow_mut();
209
210        specific.end_timestamp = F::from_canonical_u32(final_height_read.timestamp + 1);
211
212        cols.initial_opened_index = F::from_canonical_usize(initial_opened_index);
213        specific.final_opened_index = F::from_canonical_usize(final_opened_index);
214        specific.log_height = F::from_canonical_usize(log_height);
215        specific.opened_length = F::from_canonical_usize(parent.opened_length);
216        specific.dim_base_pointer = parent.dim_base_pointer;
217        cols.opened_base_pointer = parent.opened_base_pointer;
218        specific.index_base_pointer = parent.index_base_pointer;
219
220        specific.proof_index = F::from_canonical_usize(proof_index);
221        aux_cols_factory.generate_read_aux(
222            initial_height_read,
223            &mut specific.read_initial_height_or_sibling_is_on_right,
224        );
225        aux_cols_factory.generate_read_aux(final_height_read, &mut specific.read_final_height);
226    }
227    #[allow(clippy::too_many_arguments)]
228    fn inside_row_record_to_row(
229        &self,
230        record: &InsideRowRecord<F>,
231        aux_cols_factory: &MemoryAuxColsFactory<F>,
232        slice: &mut [F],
233        memory: &OfflineMemory<F>,
234        parent: &IncorporateRowRecord<F>,
235        grandparent: &VerifyBatchRecord<F>,
236        is_last: bool,
237    ) {
238        let InsideRowRecord { cells, p2_input } = record;
239
240        self.generate_subair_cols(*p2_input, slice);
241        let cols: &mut NativePoseidon2Cols<F, SBOX_REGISTERS> = slice.borrow_mut();
242        cols.incorporate_row = F::ZERO;
243        cols.incorporate_sibling = F::ZERO;
244        cols.inside_row = F::ONE;
245        cols.simple = F::ZERO;
246        cols.end_inside_row = F::from_bool(is_last);
247        cols.end_top_level = F::ZERO;
248        cols.opened_element_size_inv = grandparent.opened_element_size_inv();
249        cols.very_first_timestamp = F::from_canonical_u32(
250            memory
251                .record_by_id(
252                    parent.chunks[0].cells[0]
253                        .read_row_pointer_and_length
254                        .unwrap(),
255                )
256                .timestamp,
257        );
258        cols.start_timestamp =
259            F::from_canonical_u32(memory.record_by_id(cells[0].read).timestamp - 1);
260        let specific: &mut InsideRowSpecificCols<F> =
261            cols.specific[..InsideRowSpecificCols::<F>::width()].borrow_mut();
262
263        for (record, cell) in cells.iter().zip(specific.cells.iter_mut()) {
264            let &CellRecord {
265                read,
266                opened_index,
267                read_row_pointer_and_length,
268                row_pointer,
269                row_end,
270            } = record;
271            aux_cols_factory.generate_read_aux(memory.record_by_id(read), &mut cell.read);
272            cell.opened_index = F::from_canonical_usize(opened_index);
273            if let Some(read_row_pointer_and_length) = read_row_pointer_and_length {
274                aux_cols_factory.generate_read_aux(
275                    memory.record_by_id(read_row_pointer_and_length),
276                    &mut cell.read_row_pointer_and_length,
277                );
278            }
279            cell.row_pointer = F::from_canonical_usize(row_pointer);
280            cell.row_end = F::from_canonical_usize(row_end);
281            cell.is_first_in_row = F::from_bool(read_row_pointer_and_length.is_some());
282        }
283
284        for cell in specific.cells.iter_mut().skip(cells.len()) {
285            cell.opened_index = F::from_canonical_usize(parent.final_opened_index);
286        }
287
288        cols.is_exhausted = std::array::from_fn(|i| F::from_bool(i + 1 >= cells.len()));
289
290        cols.initial_opened_index = F::from_canonical_usize(parent.initial_opened_index);
291        cols.opened_base_pointer = grandparent.opened_base_pointer;
292    }
293    // returns number of used cells
294    fn verify_batch_record_to_rows(
295        &self,
296        record: &VerifyBatchRecord<F>,
297        aux_cols_factory: &MemoryAuxColsFactory<F>,
298        slice: &mut [F],
299        memory: &OfflineMemory<F>,
300    ) -> usize {
301        let width = NativePoseidon2Cols::<F, SBOX_REGISTERS>::width();
302        let mut used_cells = 0;
303
304        let mut opened_index = 0;
305        for (proof_index, top_level) in record.top_level.iter().enumerate() {
306            let log_height = record.initial_log_height - proof_index;
307            if let Some(incorporate_row) = &top_level.incorporate_row {
308                self.incorporate_row_record_to_row(
309                    incorporate_row,
310                    aux_cols_factory,
311                    &mut slice[used_cells..used_cells + width],
312                    memory,
313                    record,
314                    proof_index,
315                    log_height,
316                );
317                opened_index = incorporate_row.final_opened_index + 1;
318                used_cells += width;
319            }
320            if let Some(incorporate_sibling) = &top_level.incorporate_sibling {
321                self.incorporate_sibling_record_to_row(
322                    incorporate_sibling,
323                    aux_cols_factory,
324                    &mut slice[used_cells..used_cells + width],
325                    memory,
326                    record,
327                    proof_index,
328                    opened_index,
329                    log_height,
330                );
331                used_cells += width;
332            }
333        }
334        self.correct_last_top_level_row(
335            record,
336            aux_cols_factory,
337            &mut slice[used_cells - width..used_cells],
338            memory,
339        );
340
341        for top_level in record.top_level.iter() {
342            if let Some(incorporate_row) = &top_level.incorporate_row {
343                for (i, chunk) in incorporate_row.chunks.iter().enumerate() {
344                    self.inside_row_record_to_row(
345                        chunk,
346                        aux_cols_factory,
347                        &mut slice[used_cells..used_cells + width],
348                        memory,
349                        incorporate_row,
350                        record,
351                        i == incorporate_row.chunks.len() - 1,
352                    );
353                    used_cells += width;
354                }
355            }
356        }
357
358        used_cells
359    }
360    fn simple_record_to_row(
361        &self,
362        record: &SimplePoseidonRecord<F>,
363        aux_cols_factory: &MemoryAuxColsFactory<F>,
364        slice: &mut [F],
365        memory: &OfflineMemory<F>,
366    ) {
367        let &SimplePoseidonRecord {
368            from_state,
369            instruction:
370                Instruction {
371                    opcode,
372                    a: output_register,
373                    b: input_register_1,
374                    c: input_register_2,
375                    ..
376                },
377            read_input_pointer_1,
378            read_input_pointer_2,
379            read_output_pointer,
380            read_data_1,
381            read_data_2,
382            write_data_1,
383            write_data_2,
384            input_pointer_1,
385            input_pointer_2,
386            output_pointer,
387            p2_input,
388        } = record;
389
390        let read_input_pointer_1 = memory.record_by_id(read_input_pointer_1);
391        let read_output_pointer = memory.record_by_id(read_output_pointer);
392        let read_data_1 = memory.record_by_id(read_data_1);
393        let read_data_2 = memory.record_by_id(read_data_2);
394        let write_data_1 = memory.record_by_id(write_data_1);
395
396        self.generate_subair_cols(p2_input, slice);
397        let cols: &mut NativePoseidon2Cols<F, SBOX_REGISTERS> = slice.borrow_mut();
398        cols.incorporate_row = F::ZERO;
399        cols.incorporate_sibling = F::ZERO;
400        cols.inside_row = F::ZERO;
401        cols.simple = F::ONE;
402        cols.end_inside_row = F::ZERO;
403        cols.end_top_level = F::ZERO;
404        cols.is_exhausted = [F::ZERO; CHUNK - 1];
405
406        cols.start_timestamp = F::from_canonical_u32(from_state.timestamp);
407        let specific: &mut SimplePoseidonSpecificCols<F> =
408            cols.specific[..SimplePoseidonSpecificCols::<F>::width()].borrow_mut();
409
410        specific.pc = F::from_canonical_u32(from_state.pc);
411        specific.is_compress = F::from_bool(opcode == COMP_POS2.global_opcode());
412        specific.output_register = output_register;
413        specific.input_register_1 = input_register_1;
414        specific.input_register_2 = input_register_2;
415        specific.output_pointer = output_pointer;
416        specific.input_pointer_1 = input_pointer_1;
417        specific.input_pointer_2 = input_pointer_2;
418        aux_cols_factory.generate_read_aux(read_output_pointer, &mut specific.read_output_pointer);
419        aux_cols_factory
420            .generate_read_aux(read_input_pointer_1, &mut specific.read_input_pointer_1);
421        aux_cols_factory.generate_read_aux(read_data_1, &mut specific.read_data_1);
422        aux_cols_factory.generate_read_aux(read_data_2, &mut specific.read_data_2);
423        aux_cols_factory.generate_write_aux(write_data_1, &mut specific.write_data_1);
424
425        if opcode == COMP_POS2.global_opcode() {
426            let read_input_pointer_2 = memory.record_by_id(read_input_pointer_2.unwrap());
427            aux_cols_factory
428                .generate_read_aux(read_input_pointer_2, &mut specific.read_input_pointer_2);
429        } else {
430            let write_data_2 = memory.record_by_id(write_data_2.unwrap());
431            aux_cols_factory.generate_write_aux(write_data_2, &mut specific.write_data_2);
432        }
433    }
434
435    fn generate_trace(self) -> RowMajorMatrix<F> {
436        let width = self.trace_width();
437        let height = next_power_of_two_or_zero(self.height);
438        let mut flat_trace = F::zero_vec(width * height);
439
440        let memory = self.offline_memory.lock().unwrap();
441
442        let aux_cols_factory = memory.aux_cols_factory();
443
444        let mut used_cells = 0;
445        for record in self.record_set.verify_batch_records.iter() {
446            used_cells += self.verify_batch_record_to_rows(
447                record,
448                &aux_cols_factory,
449                &mut flat_trace[used_cells..],
450                &memory,
451            );
452        }
453        for record in self.record_set.simple_permute_records.iter() {
454            self.simple_record_to_row(
455                record,
456                &aux_cols_factory,
457                &mut flat_trace[used_cells..used_cells + width],
458                &memory,
459            );
460            used_cells += width;
461        }
462        // poseidon2 constraints are always checked
463        // following can be optimized to only hash [0; _] once
464        flat_trace[used_cells..]
465            .par_chunks_mut(width)
466            .for_each(|row| {
467                self.generate_subair_cols([F::ZERO; 2 * CHUNK], row);
468            });
469
470        RowMajorMatrix::new(flat_trace, width)
471    }
472}
473
474impl<SC: StarkGenericConfig, const SBOX_REGISTERS: usize> Chip<SC>
475    for NativePoseidon2Chip<Val<SC>, SBOX_REGISTERS>
476where
477    Val<SC>: PrimeField32,
478{
479    fn air(&self) -> AirRef<SC> {
480        Arc::new(self.air.clone())
481    }
482    fn generate_air_proof_input(self) -> AirProofInput<SC> {
483        AirProofInput::simple_no_pis(self.generate_trace())
484    }
485}