1use std::{borrow::BorrowMut, sync::Arc};
2
3use openvm_circuit::system::memory::{MemoryAuxColsFactory, OfflineMemory};
4use openvm_circuit_primitives::utils::next_power_of_two_or_zero;
5use openvm_instructions::{instruction::Instruction, LocalOpcode};
6use openvm_native_compiler::Poseidon2Opcode::COMP_POS2;
7use openvm_stark_backend::{
8 config::{StarkGenericConfig, Val},
9 p3_air::BaseAir,
10 p3_field::{Field, PrimeField32},
11 p3_matrix::dense::RowMajorMatrix,
12 p3_maybe_rayon::prelude::*,
13 prover::types::AirProofInput,
14 AirRef, Chip, ChipUsageGetter,
15};
16
17use crate::{
18 chip::{SimplePoseidonRecord, NUM_INITIAL_READS},
19 poseidon2::{
20 chip::{
21 CellRecord, IncorporateRowRecord, IncorporateSiblingRecord, InsideRowRecord,
22 NativePoseidon2Chip, VerifyBatchRecord,
23 },
24 columns::{
25 InsideRowSpecificCols, NativePoseidon2Cols, SimplePoseidonSpecificCols,
26 TopLevelSpecificCols,
27 },
28 CHUNK,
29 },
30};
31impl<F: Field, const SBOX_REGISTERS: usize> ChipUsageGetter
32 for NativePoseidon2Chip<F, SBOX_REGISTERS>
33{
34 fn air_name(&self) -> String {
35 "VerifyBatchAir".to_string()
36 }
37
38 fn current_trace_height(&self) -> usize {
39 self.height
40 }
41
42 fn trace_width(&self) -> usize {
43 NativePoseidon2Cols::<F, SBOX_REGISTERS>::width()
44 }
45}
46
47impl<F: PrimeField32, const SBOX_REGISTERS: usize> NativePoseidon2Chip<F, SBOX_REGISTERS> {
48 fn generate_subair_cols(&self, input: [F; 2 * CHUNK], cols: &mut [F]) {
49 let inner_trace = self.subchip.generate_trace(vec![input]);
50 let inner_width = self.air.subair.width();
51 cols[..inner_width].copy_from_slice(inner_trace.values.as_slice());
52 }
53 #[allow(clippy::too_many_arguments)]
54 fn incorporate_sibling_record_to_row(
55 &self,
56 record: &IncorporateSiblingRecord<F>,
57 aux_cols_factory: &MemoryAuxColsFactory<F>,
58 slice: &mut [F],
59 memory: &OfflineMemory<F>,
60 parent: &VerifyBatchRecord<F>,
61 proof_index: usize,
62 opened_index: usize,
63 log_height: usize,
64 ) {
65 let &IncorporateSiblingRecord {
66 read_sibling_is_on_right,
67 sibling_is_on_right,
68 p2_input,
69 } = record;
70
71 let read_sibling_is_on_right = memory.record_by_id(read_sibling_is_on_right);
72
73 self.generate_subair_cols(p2_input, slice);
74 let cols: &mut NativePoseidon2Cols<F, SBOX_REGISTERS> = slice.borrow_mut();
75 cols.incorporate_row = F::ZERO;
76 cols.incorporate_sibling = F::ONE;
77 cols.inside_row = F::ZERO;
78 cols.simple = F::ZERO;
79 cols.end_inside_row = F::ZERO;
80 cols.end_top_level = F::ZERO;
81 cols.start_top_level = F::ZERO;
82 cols.opened_element_size_inv = parent.opened_element_size_inv();
83 cols.very_first_timestamp = F::from_canonical_u32(parent.from_state.timestamp);
84 cols.start_timestamp =
85 F::from_canonical_u32(read_sibling_is_on_right.timestamp - NUM_INITIAL_READS as u32);
86
87 let specific: &mut TopLevelSpecificCols<F> =
88 cols.specific[..TopLevelSpecificCols::<F>::width()].borrow_mut();
89
90 specific.end_timestamp =
91 F::from_canonical_usize(read_sibling_is_on_right.timestamp as usize + 1);
92 cols.initial_opened_index = F::from_canonical_usize(opened_index);
93 specific.final_opened_index = F::from_canonical_usize(opened_index - 1);
94 specific.log_height = F::from_canonical_usize(log_height);
95 specific.opened_length = F::from_canonical_usize(parent.opened_length);
96 specific.dim_base_pointer = parent.dim_base_pointer;
97 cols.opened_base_pointer = parent.opened_base_pointer;
98 specific.index_base_pointer = parent.index_base_pointer;
99
100 specific.proof_index = F::from_canonical_usize(proof_index);
101 aux_cols_factory.generate_read_aux(
102 read_sibling_is_on_right,
103 &mut specific.read_initial_height_or_sibling_is_on_right,
104 );
105 specific.sibling_is_on_right = F::from_bool(sibling_is_on_right);
106 }
107 fn correct_last_top_level_row(
108 &self,
109 record: &VerifyBatchRecord<F>,
110 aux_cols_factory: &MemoryAuxColsFactory<F>,
111 slice: &mut [F],
112 memory: &OfflineMemory<F>,
113 ) {
114 let &VerifyBatchRecord {
115 from_state,
116 commit_pointer,
117 dim_base_pointer_read,
118 opened_base_pointer_read,
119 opened_length_read,
120 index_base_pointer_read,
121 commit_pointer_read,
122 commit_read,
123 ..
124 } = record;
125 let instruction = &record.instruction;
126 let cols: &mut NativePoseidon2Cols<F, SBOX_REGISTERS> = slice.borrow_mut();
127 cols.end_top_level = F::ONE;
128
129 let specific: &mut TopLevelSpecificCols<F> =
130 cols.specific[..TopLevelSpecificCols::<F>::width()].borrow_mut();
131
132 specific.pc = F::from_canonical_u32(from_state.pc);
133 specific.dim_register = instruction.a;
134 specific.opened_register = instruction.b;
135 specific.opened_length_register = instruction.c;
136 specific.proof_id = instruction.d;
137 specific.index_register = instruction.e;
138 specific.commit_register = instruction.f;
139 specific.commit_pointer = commit_pointer;
140 aux_cols_factory.generate_read_aux(
141 memory.record_by_id(dim_base_pointer_read),
142 &mut specific.dim_base_pointer_read,
143 );
144 aux_cols_factory.generate_read_aux(
145 memory.record_by_id(opened_base_pointer_read),
146 &mut specific.opened_base_pointer_read,
147 );
148 aux_cols_factory.generate_read_aux(
149 memory.record_by_id(opened_length_read),
150 &mut specific.opened_length_read,
151 );
152 aux_cols_factory.generate_read_aux(
153 memory.record_by_id(index_base_pointer_read),
154 &mut specific.index_base_pointer_read,
155 );
156 aux_cols_factory.generate_read_aux(
157 memory.record_by_id(commit_pointer_read),
158 &mut specific.commit_pointer_read,
159 );
160 aux_cols_factory
161 .generate_read_aux(memory.record_by_id(commit_read), &mut specific.commit_read);
162 }
163 #[allow(clippy::too_many_arguments)]
164 fn incorporate_row_record_to_row(
165 &self,
166 record: &IncorporateRowRecord<F>,
167 aux_cols_factory: &MemoryAuxColsFactory<F>,
168 slice: &mut [F],
169 memory: &OfflineMemory<F>,
170 parent: &VerifyBatchRecord<F>,
171 proof_index: usize,
172 log_height: usize,
173 ) {
174 let &IncorporateRowRecord {
175 initial_opened_index,
176 final_opened_index,
177 initial_height_read,
178 final_height_read,
179 p2_input,
180 ..
181 } = record;
182
183 let initial_height_read = memory.record_by_id(initial_height_read);
184 let final_height_read = memory.record_by_id(final_height_read);
185
186 self.generate_subair_cols(p2_input, slice);
187 let cols: &mut NativePoseidon2Cols<F, SBOX_REGISTERS> = slice.borrow_mut();
188 cols.incorporate_row = F::ONE;
189 cols.incorporate_sibling = F::ZERO;
190 cols.inside_row = F::ZERO;
191 cols.simple = F::ZERO;
192 cols.end_inside_row = F::ZERO;
193 cols.end_top_level = F::ZERO;
194 cols.start_top_level = F::from_bool(proof_index == 0);
195 cols.opened_element_size_inv = parent.opened_element_size_inv();
196 cols.very_first_timestamp = F::from_canonical_u32(parent.from_state.timestamp);
197 cols.start_timestamp = F::from_canonical_u32(
198 memory
199 .record_by_id(
200 record.chunks[0].cells[0]
201 .read_row_pointer_and_length
202 .unwrap(),
203 )
204 .timestamp
205 - NUM_INITIAL_READS as u32,
206 );
207 let specific: &mut TopLevelSpecificCols<F> =
208 cols.specific[..TopLevelSpecificCols::<F>::width()].borrow_mut();
209
210 specific.end_timestamp = F::from_canonical_u32(final_height_read.timestamp + 1);
211
212 cols.initial_opened_index = F::from_canonical_usize(initial_opened_index);
213 specific.final_opened_index = F::from_canonical_usize(final_opened_index);
214 specific.log_height = F::from_canonical_usize(log_height);
215 specific.opened_length = F::from_canonical_usize(parent.opened_length);
216 specific.dim_base_pointer = parent.dim_base_pointer;
217 cols.opened_base_pointer = parent.opened_base_pointer;
218 specific.index_base_pointer = parent.index_base_pointer;
219
220 specific.proof_index = F::from_canonical_usize(proof_index);
221 aux_cols_factory.generate_read_aux(
222 initial_height_read,
223 &mut specific.read_initial_height_or_sibling_is_on_right,
224 );
225 aux_cols_factory.generate_read_aux(final_height_read, &mut specific.read_final_height);
226 }
227 #[allow(clippy::too_many_arguments)]
228 fn inside_row_record_to_row(
229 &self,
230 record: &InsideRowRecord<F>,
231 aux_cols_factory: &MemoryAuxColsFactory<F>,
232 slice: &mut [F],
233 memory: &OfflineMemory<F>,
234 parent: &IncorporateRowRecord<F>,
235 grandparent: &VerifyBatchRecord<F>,
236 is_last: bool,
237 ) {
238 let InsideRowRecord { cells, p2_input } = record;
239
240 self.generate_subair_cols(*p2_input, slice);
241 let cols: &mut NativePoseidon2Cols<F, SBOX_REGISTERS> = slice.borrow_mut();
242 cols.incorporate_row = F::ZERO;
243 cols.incorporate_sibling = F::ZERO;
244 cols.inside_row = F::ONE;
245 cols.simple = F::ZERO;
246 cols.end_inside_row = F::from_bool(is_last);
247 cols.end_top_level = F::ZERO;
248 cols.opened_element_size_inv = grandparent.opened_element_size_inv();
249 cols.very_first_timestamp = F::from_canonical_u32(
250 memory
251 .record_by_id(
252 parent.chunks[0].cells[0]
253 .read_row_pointer_and_length
254 .unwrap(),
255 )
256 .timestamp,
257 );
258 cols.start_timestamp =
259 F::from_canonical_u32(memory.record_by_id(cells[0].read).timestamp - 1);
260 let specific: &mut InsideRowSpecificCols<F> =
261 cols.specific[..InsideRowSpecificCols::<F>::width()].borrow_mut();
262
263 for (record, cell) in cells.iter().zip(specific.cells.iter_mut()) {
264 let &CellRecord {
265 read,
266 opened_index,
267 read_row_pointer_and_length,
268 row_pointer,
269 row_end,
270 } = record;
271 aux_cols_factory.generate_read_aux(memory.record_by_id(read), &mut cell.read);
272 cell.opened_index = F::from_canonical_usize(opened_index);
273 if let Some(read_row_pointer_and_length) = read_row_pointer_and_length {
274 aux_cols_factory.generate_read_aux(
275 memory.record_by_id(read_row_pointer_and_length),
276 &mut cell.read_row_pointer_and_length,
277 );
278 }
279 cell.row_pointer = F::from_canonical_usize(row_pointer);
280 cell.row_end = F::from_canonical_usize(row_end);
281 cell.is_first_in_row = F::from_bool(read_row_pointer_and_length.is_some());
282 }
283
284 for cell in specific.cells.iter_mut().skip(cells.len()) {
285 cell.opened_index = F::from_canonical_usize(parent.final_opened_index);
286 }
287
288 cols.is_exhausted = std::array::from_fn(|i| F::from_bool(i + 1 >= cells.len()));
289
290 cols.initial_opened_index = F::from_canonical_usize(parent.initial_opened_index);
291 cols.opened_base_pointer = grandparent.opened_base_pointer;
292 }
293 fn verify_batch_record_to_rows(
295 &self,
296 record: &VerifyBatchRecord<F>,
297 aux_cols_factory: &MemoryAuxColsFactory<F>,
298 slice: &mut [F],
299 memory: &OfflineMemory<F>,
300 ) -> usize {
301 let width = NativePoseidon2Cols::<F, SBOX_REGISTERS>::width();
302 let mut used_cells = 0;
303
304 let mut opened_index = 0;
305 for (proof_index, top_level) in record.top_level.iter().enumerate() {
306 let log_height = record.initial_log_height - proof_index;
307 if let Some(incorporate_row) = &top_level.incorporate_row {
308 self.incorporate_row_record_to_row(
309 incorporate_row,
310 aux_cols_factory,
311 &mut slice[used_cells..used_cells + width],
312 memory,
313 record,
314 proof_index,
315 log_height,
316 );
317 opened_index = incorporate_row.final_opened_index + 1;
318 used_cells += width;
319 }
320 if let Some(incorporate_sibling) = &top_level.incorporate_sibling {
321 self.incorporate_sibling_record_to_row(
322 incorporate_sibling,
323 aux_cols_factory,
324 &mut slice[used_cells..used_cells + width],
325 memory,
326 record,
327 proof_index,
328 opened_index,
329 log_height,
330 );
331 used_cells += width;
332 }
333 }
334 self.correct_last_top_level_row(
335 record,
336 aux_cols_factory,
337 &mut slice[used_cells - width..used_cells],
338 memory,
339 );
340
341 for top_level in record.top_level.iter() {
342 if let Some(incorporate_row) = &top_level.incorporate_row {
343 for (i, chunk) in incorporate_row.chunks.iter().enumerate() {
344 self.inside_row_record_to_row(
345 chunk,
346 aux_cols_factory,
347 &mut slice[used_cells..used_cells + width],
348 memory,
349 incorporate_row,
350 record,
351 i == incorporate_row.chunks.len() - 1,
352 );
353 used_cells += width;
354 }
355 }
356 }
357
358 used_cells
359 }
360 fn simple_record_to_row(
361 &self,
362 record: &SimplePoseidonRecord<F>,
363 aux_cols_factory: &MemoryAuxColsFactory<F>,
364 slice: &mut [F],
365 memory: &OfflineMemory<F>,
366 ) {
367 let &SimplePoseidonRecord {
368 from_state,
369 instruction:
370 Instruction {
371 opcode,
372 a: output_register,
373 b: input_register_1,
374 c: input_register_2,
375 ..
376 },
377 read_input_pointer_1,
378 read_input_pointer_2,
379 read_output_pointer,
380 read_data_1,
381 read_data_2,
382 write_data_1,
383 write_data_2,
384 input_pointer_1,
385 input_pointer_2,
386 output_pointer,
387 p2_input,
388 } = record;
389
390 let read_input_pointer_1 = memory.record_by_id(read_input_pointer_1);
391 let read_output_pointer = memory.record_by_id(read_output_pointer);
392 let read_data_1 = memory.record_by_id(read_data_1);
393 let read_data_2 = memory.record_by_id(read_data_2);
394 let write_data_1 = memory.record_by_id(write_data_1);
395
396 self.generate_subair_cols(p2_input, slice);
397 let cols: &mut NativePoseidon2Cols<F, SBOX_REGISTERS> = slice.borrow_mut();
398 cols.incorporate_row = F::ZERO;
399 cols.incorporate_sibling = F::ZERO;
400 cols.inside_row = F::ZERO;
401 cols.simple = F::ONE;
402 cols.end_inside_row = F::ZERO;
403 cols.end_top_level = F::ZERO;
404 cols.is_exhausted = [F::ZERO; CHUNK - 1];
405
406 cols.start_timestamp = F::from_canonical_u32(from_state.timestamp);
407 let specific: &mut SimplePoseidonSpecificCols<F> =
408 cols.specific[..SimplePoseidonSpecificCols::<F>::width()].borrow_mut();
409
410 specific.pc = F::from_canonical_u32(from_state.pc);
411 specific.is_compress = F::from_bool(opcode == COMP_POS2.global_opcode());
412 specific.output_register = output_register;
413 specific.input_register_1 = input_register_1;
414 specific.input_register_2 = input_register_2;
415 specific.output_pointer = output_pointer;
416 specific.input_pointer_1 = input_pointer_1;
417 specific.input_pointer_2 = input_pointer_2;
418 aux_cols_factory.generate_read_aux(read_output_pointer, &mut specific.read_output_pointer);
419 aux_cols_factory
420 .generate_read_aux(read_input_pointer_1, &mut specific.read_input_pointer_1);
421 aux_cols_factory.generate_read_aux(read_data_1, &mut specific.read_data_1);
422 aux_cols_factory.generate_read_aux(read_data_2, &mut specific.read_data_2);
423 aux_cols_factory.generate_write_aux(write_data_1, &mut specific.write_data_1);
424
425 if opcode == COMP_POS2.global_opcode() {
426 let read_input_pointer_2 = memory.record_by_id(read_input_pointer_2.unwrap());
427 aux_cols_factory
428 .generate_read_aux(read_input_pointer_2, &mut specific.read_input_pointer_2);
429 } else {
430 let write_data_2 = memory.record_by_id(write_data_2.unwrap());
431 aux_cols_factory.generate_write_aux(write_data_2, &mut specific.write_data_2);
432 }
433 }
434
435 fn generate_trace(self) -> RowMajorMatrix<F> {
436 let width = self.trace_width();
437 let height = next_power_of_two_or_zero(self.height);
438 let mut flat_trace = F::zero_vec(width * height);
439
440 let memory = self.offline_memory.lock().unwrap();
441
442 let aux_cols_factory = memory.aux_cols_factory();
443
444 let mut used_cells = 0;
445 for record in self.record_set.verify_batch_records.iter() {
446 used_cells += self.verify_batch_record_to_rows(
447 record,
448 &aux_cols_factory,
449 &mut flat_trace[used_cells..],
450 &memory,
451 );
452 }
453 for record in self.record_set.simple_permute_records.iter() {
454 self.simple_record_to_row(
455 record,
456 &aux_cols_factory,
457 &mut flat_trace[used_cells..used_cells + width],
458 &memory,
459 );
460 used_cells += width;
461 }
462 flat_trace[used_cells..]
465 .par_chunks_mut(width)
466 .for_each(|row| {
467 self.generate_subair_cols([F::ZERO; 2 * CHUNK], row);
468 });
469
470 RowMajorMatrix::new(flat_trace, width)
471 }
472}
473
474impl<SC: StarkGenericConfig, const SBOX_REGISTERS: usize> Chip<SC>
475 for NativePoseidon2Chip<Val<SC>, SBOX_REGISTERS>
476where
477 Val<SC>: PrimeField32,
478{
479 fn air(&self) -> AirRef<SC> {
480 Arc::new(self.air.clone())
481 }
482 fn generate_air_proof_input(self) -> AirProofInput<SC> {
483 AirProofInput::simple_no_pis(self.generate_trace())
484 }
485}