1use std::sync::{Arc, Mutex};
2
3use openvm_circuit::{
4 arch::{
5 ExecutionBridge, ExecutionError, ExecutionState, InstructionExecutor, Streams, SystemPort,
6 },
7 system::memory::{MemoryController, OfflineMemory, RecordId},
8};
9use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode};
10use openvm_native_compiler::{
11 conversion::AS,
12 Poseidon2Opcode::{COMP_POS2, PERM_POS2},
13 VerifyBatchOpcode::VERIFY_BATCH,
14};
15use openvm_poseidon2_air::{Poseidon2Config, Poseidon2SubAir, Poseidon2SubChip};
16use openvm_stark_backend::{
17 p3_field::{Field, PrimeField32},
18 p3_maybe_rayon::prelude::{ParallelIterator, ParallelSlice},
19};
20use serde::{Deserialize, Serialize};
21
22use crate::poseidon2::{
23 air::{NativePoseidon2Air, VerifyBatchBus},
24 CHUNK,
25};
26
/// Execution record for one `VERIFY_BATCH` instruction.
///
/// Holds the operand values and memory-access records produced by `execute`, plus one
/// [`TopLevelRecord`] per height level visited while recomputing the root.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound = "F: Field")]
pub struct VerifyBatchRecord<F: Field> {
    /// Execution state (pc and timestamp) at the start of the instruction.
    pub from_state: ExecutionState<u32>,
    /// The executed instruction; its `g` operand is exposed via `opened_element_size_inv`.
    pub instruction: Instruction<F>,

    /// Value read from operand `a`: base address of the per-row log heights
    /// (entry `i` is compared with the current log height for opened row `i`).
    pub dim_base_pointer: F,
    /// Value read from operand `b`: base address of the `(row pointer, row length)` pairs,
    /// two cells per opened row.
    pub opened_base_pointer: F,
    /// Value read from operand `c`: number of opened rows.
    pub opened_length: usize,
    /// Value read from operand `e`: base address of the per-level sibling direction flags.
    pub index_base_pointer: F,
    /// Value read from operand `f`: address of the `CHUNK`-element expected root.
    pub commit_pointer: F,

    /// Memory record of reading operand `a`.
    pub dim_base_pointer_read: RecordId,
    /// Memory record of reading operand `b`.
    pub opened_base_pointer_read: RecordId,
    /// Memory record of reading operand `c`.
    pub opened_length_read: RecordId,
    /// Memory record of reading operand `e`.
    pub index_base_pointer_read: RecordId,
    /// Memory record of reading operand `f`.
    pub commit_pointer_read: RecordId,

    /// Memory record of reading the expected root at `commit_pointer`.
    pub commit_read: RecordId,
    /// Log height stored at `dim_base_pointer` (the first row's height); the walk starts here.
    pub initial_log_height: usize,
    /// One entry per height level, from `initial_log_height` down to 0.
    pub top_level: Vec<TopLevelRecord<F>>,
}
49
50impl<F: PrimeField32> VerifyBatchRecord<F> {
51 pub fn opened_element_size_inv(&self) -> F {
52 self.instruction.g
53 }
54}
55
/// Work done by `VERIFY_BATCH` at a single log-height level.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound = "F: Field")]
pub struct TopLevelRecord<F: Field> {
    /// Present when the next not-yet-consumed opened row has this log height: all
    /// consecutive rows of this height are absorbed into a rolling hash and folded into the
    /// running root.
    pub incorporate_row: Option<IncorporateRowRecord<F>>,
    /// Present at every level except log height 0: the running root is compressed with the
    /// sibling hash taken from the proof.
    pub incorporate_sibling: Option<IncorporateSiblingRecord<F>>,
}
64
/// Record of compressing the running root with one sibling hash during `VERIFY_BATCH`.
#[repr(C)]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound = "F: Field")]
pub struct IncorporateSiblingRecord<F: Field> {
    /// Memory record of reading the direction flag at `index_base_pointer + sibling_index`.
    pub read_sibling_is_on_right: RecordId,
    /// `true` when that flag equals one. In that case `execute` places the sibling in the
    /// LEFT half of the compression input and the running root in the right half.
    /// NOTE(review): the name suggests the opposite placement (it presumably means "the
    /// current node is the right child") — confirm against the AIR before relying on it.
    pub sibling_is_on_right: bool,
    /// Full `2 * CHUNK` permutation input used for the compression.
    pub p2_input: [F; 2 * CHUNK],
}
73
/// Record of absorbing every opened row of one log height and folding the result into the
/// running root during `VERIFY_BATCH`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound = "F: Field")]
pub struct IncorporateRowRecord<F: Field> {
    /// One entry per rolling-hash permutation, each absorbing up to `CHUNK` cells.
    pub chunks: Vec<InsideRowRecord<F>>,
    /// Index of the first opened row at this height.
    pub initial_opened_index: usize,
    /// Index of the last opened row at this height (inclusive).
    pub final_opened_index: usize,
    /// Memory record of re-reading the log height stored for `initial_opened_index`.
    pub initial_height_read: RecordId,
    /// Memory record of re-reading the log height stored for `final_opened_index`.
    pub final_height_read: RecordId,
    /// Permutation input of the fold step. At the initial height this is the input of the
    /// last rolling-hash permutation (whose output becomes the root); at lower heights it is
    /// the `root || row_hash` compression input.
    pub p2_input: [F; 2 * CHUNK],
}
84
/// One rolling-hash permutation performed while absorbing opened rows.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound = "F: Field")]
pub struct InsideRowRecord<F: Field> {
    /// Cells absorbed by this permutation: exactly `CHUNK`, except possibly for the last
    /// permutation of a height level.
    pub cells: Vec<CellRecord>,
    /// Rolling-hash state fed to the permutation, with the absorbed cells written into its
    /// first `cells.len()` positions.
    pub p2_input: [F; 2 * CHUNK],
}
91
/// One opened-row cell absorbed into the rolling hash.
#[repr(C)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CellRecord {
    /// Memory record of reading the cell at `row_pointer`.
    pub read: RecordId,
    /// Index of the opened row this cell belongs to.
    pub opened_index: usize,
    /// `Some` when this cell starts a row: the record of reading that row's
    /// `(pointer, length)` pair. `None` when that timestamp slot was skipped instead.
    pub read_row_pointer_and_length: Option<RecordId>,
    /// Address the cell was read from.
    pub row_pointer: usize,
    /// One past the last address of the current row
    /// (`row start + opened_element_size * row length`).
    pub row_end: usize,
}
101
/// Execution record for one `PERM_POS2` or `COMP_POS2` instruction.
#[repr(C)]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound = "F: Field")]
pub struct SimplePoseidonRecord<F: Field> {
    /// Execution state (pc and timestamp) at the start of the instruction.
    pub from_state: ExecutionState<u32>,
    /// The executed instruction.
    pub instruction: Instruction<F>,

    /// Memory record of reading operand `b` (first input pointer).
    pub read_input_pointer_1: RecordId,
    /// Memory record of reading operand `c` (second input pointer); `None` for `PERM_POS2`,
    /// which skips this timestamp slot instead.
    pub read_input_pointer_2: Option<RecordId>,
    /// Memory record of reading operand `a` (output pointer).
    pub read_output_pointer: RecordId,
    /// Memory record of reading the `CHUNK` cells at `input_pointer_1`.
    pub read_data_1: RecordId,
    /// Memory record of reading the `CHUNK` cells at `input_pointer_2`.
    pub read_data_2: RecordId,
    /// Memory record of writing the first `CHUNK` outputs at `output_pointer`.
    pub write_data_1: RecordId,
    /// Memory record of writing the second `CHUNK` outputs at `output_pointer + CHUNK`;
    /// only `PERM_POS2` performs this write (`COMP_POS2` skips the timestamp slot).
    pub write_data_2: Option<RecordId>,

    /// First input address.
    pub input_pointer_1: F,
    /// Second input address; for `PERM_POS2` this is `input_pointer_1 + CHUNK`.
    pub input_pointer_2: F,
    /// Output address.
    pub output_pointer: F,
    /// Concatenation of the two `CHUNK`-cell inputs fed to the permutation.
    pub p2_input: [F; 2 * CHUNK],
}
122
/// All execution records accumulated by [`NativePoseidon2Chip`].
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(bound = "F: Field")]
pub struct NativePoseidon2RecordSet<F: Field> {
    /// One record per executed `VERIFY_BATCH`.
    pub verify_batch_records: Vec<VerifyBatchRecord<F>>,
    /// One record per executed `PERM_POS2` / `COMP_POS2`.
    pub simple_permute_records: Vec<SimplePoseidonRecord<F>>,
}
129
/// Chip that executes the native `PERM_POS2`, `COMP_POS2` and `VERIFY_BATCH` instructions.
pub struct NativePoseidon2Chip<F: Field, const SBOX_REGISTERS: usize> {
    /// The chip's AIR; `execute` reads its `address_space` for `VERIFY_BATCH`.
    pub(super) air: NativePoseidon2Air<F, SBOX_REGISTERS>,
    /// Records collected by `execute`.
    pub record_set: NativePoseidon2RecordSet<F>,
    /// Running row count: `execute` adds one per simple permutation, per absorbed
    /// rolling-hash chunk, per row incorporation and per sibling incorporation.
    pub height: usize,
    /// Shared offline memory. Not used by the code visible here; presumably consumed during
    /// trace generation — confirm.
    pub(super) offline_memory: Arc<Mutex<OfflineMemory<F>>>,
    /// Poseidon2 permutation used by `execute` and `compress`.
    pub(super) subchip: Poseidon2SubChip<F, SBOX_REGISTERS>,
    /// Shared input streams; `VERIFY_BATCH` reads its sibling hashes from `hint_space`.
    pub(super) streams: Arc<Mutex<Streams<F>>>,
}
138
139impl<F: PrimeField32, const SBOX_REGISTERS: usize> NativePoseidon2Chip<F, SBOX_REGISTERS> {
140 pub fn new(
141 port: SystemPort,
142 offline_memory: Arc<Mutex<OfflineMemory<F>>>,
143 poseidon2_config: Poseidon2Config<F>,
144 verify_batch_bus: VerifyBatchBus,
145 streams: Arc<Mutex<Streams<F>>>,
146 ) -> Self {
147 let air = NativePoseidon2Air {
148 execution_bridge: ExecutionBridge::new(port.execution_bus, port.program_bus),
149 memory_bridge: port.memory_bridge,
150 internal_bus: verify_batch_bus,
151 subair: Arc::new(Poseidon2SubAir::new(poseidon2_config.constants.into())),
152 address_space: F::from_canonical_u32(AS::Native as u32),
153 };
154 Self {
155 record_set: Default::default(),
156 air,
157 height: 0,
158 offline_memory,
159 subchip: Poseidon2SubChip::new(poseidon2_config.constants),
160 streams,
161 }
162 }
163
164 fn compress(&self, left: [F; CHUNK], right: [F; CHUNK]) -> ([F; 2 * CHUNK], [F; CHUNK]) {
165 let concatenated =
166 std::array::from_fn(|i| if i < CHUNK { left[i] } else { right[i - CHUNK] });
167 let permuted = self.subchip.permute(concatenated);
168 (concatenated, std::array::from_fn(|i| permuted[i]))
169 }
170}
171
/// Number of timestamp slots `execute` advances, without recording any access, at the start
/// of each row-incorporation and each sibling-incorporation step of `VERIFY_BATCH`
/// (presumably mirroring the AIR's per-step layout — confirm against the AIR).
pub(super) const NUM_INITIAL_READS: usize = 6;
/// Number of timestamp slots every `PERM_POS2` / `COMP_POS2` consumes.
///
/// The slots are three pointer reads, two data reads and two data writes; slots an opcode
/// does not use are skipped. `execute` asserts this count.
pub(super) const NUM_SIMPLE_ACCESSES: u32 = 7;
174
175impl<F: PrimeField32, const SBOX_REGISTERS: usize> InstructionExecutor<F>
176 for NativePoseidon2Chip<F, SBOX_REGISTERS>
177{
178 fn execute(
179 &mut self,
180 memory: &mut MemoryController<F>,
181 instruction: &Instruction<F>,
182 from_state: ExecutionState<u32>,
183 ) -> Result<ExecutionState<u32>, ExecutionError> {
184 if instruction.opcode == PERM_POS2.global_opcode()
185 || instruction.opcode == COMP_POS2.global_opcode()
186 {
187 let &Instruction {
188 a: output_register,
189 b: input_register_1,
190 c: input_register_2,
191 d: register_address_space,
192 e: data_address_space,
193 ..
194 } = instruction;
195
196 let (read_output_pointer, output_pointer) =
197 memory.read_cell(register_address_space, output_register);
198 let (read_input_pointer_1, input_pointer_1) =
199 memory.read_cell(register_address_space, input_register_1);
200 let (read_input_pointer_2, input_pointer_2) =
201 if instruction.opcode == PERM_POS2.global_opcode() {
202 memory.increment_timestamp();
203 (None, input_pointer_1 + F::from_canonical_usize(CHUNK))
204 } else {
205 let (read_input_pointer_2, input_pointer_2) =
206 memory.read_cell(register_address_space, input_register_2);
207 (Some(read_input_pointer_2), input_pointer_2)
208 };
209 let (read_data_1, data_1) = memory.read::<CHUNK>(data_address_space, input_pointer_1);
210 let (read_data_2, data_2) = memory.read::<CHUNK>(data_address_space, input_pointer_2);
211 let p2_input = std::array::from_fn(|i| {
212 if i < CHUNK {
213 data_1[i]
214 } else {
215 data_2[i - CHUNK]
216 }
217 });
218 let output = self.subchip.permute(p2_input);
219 let (write_data_1, _) = memory.write::<CHUNK>(
220 data_address_space,
221 output_pointer,
222 std::array::from_fn(|i| output[i]),
223 );
224 let write_data_2 = if instruction.opcode == PERM_POS2.global_opcode() {
225 Some(
226 memory
227 .write::<CHUNK>(
228 data_address_space,
229 output_pointer + F::from_canonical_usize(CHUNK),
230 std::array::from_fn(|i| output[CHUNK + i]),
231 )
232 .0,
233 )
234 } else {
235 memory.increment_timestamp();
236 None
237 };
238
239 assert_eq!(
240 memory.timestamp(),
241 from_state.timestamp + NUM_SIMPLE_ACCESSES
242 );
243
244 self.record_set
245 .simple_permute_records
246 .push(SimplePoseidonRecord {
247 from_state,
248 instruction: instruction.clone(),
249 read_input_pointer_1,
250 read_input_pointer_2,
251 read_output_pointer,
252 read_data_1,
253 read_data_2,
254 write_data_1,
255 write_data_2,
256 input_pointer_1,
257 input_pointer_2,
258 output_pointer,
259 p2_input,
260 });
261 self.height += 1;
262 } else if instruction.opcode == VERIFY_BATCH.global_opcode() {
263 let &Instruction {
264 a: dim_register,
265 b: opened_register,
266 c: opened_length_register,
267 d: proof_id_ptr,
268 e: index_register,
269 f: commit_register,
270 g: opened_element_size_inv,
271 ..
272 } = instruction;
273 let address_space = self.air.address_space;
274 let mut opened_element_size = F::ONE;
276 while opened_element_size * opened_element_size_inv != F::ONE {
277 opened_element_size += F::ONE;
278 }
279
280 let proof_id = memory.unsafe_read_cell(address_space, proof_id_ptr);
281 let (dim_base_pointer_read, dim_base_pointer) =
282 memory.read_cell(address_space, dim_register);
283 let (opened_base_pointer_read, opened_base_pointer) =
284 memory.read_cell(address_space, opened_register);
285 let (opened_length_read, opened_length) =
286 memory.read_cell(address_space, opened_length_register);
287 let (index_base_pointer_read, index_base_pointer) =
288 memory.read_cell(address_space, index_register);
289 let (commit_pointer_read, commit_pointer) =
290 memory.read_cell(address_space, commit_register);
291 let (commit_read, commit) = memory.read(address_space, commit_pointer);
292
293 let opened_length = opened_length.as_canonical_u32() as usize;
294
295 let initial_log_height = memory
296 .unsafe_read_cell(address_space, dim_base_pointer)
297 .as_canonical_u32();
298 let mut log_height = initial_log_height as i32;
299 let mut sibling_index = 0;
300 let mut opened_index = 0;
301 let mut top_level = vec![];
302
303 let mut root = [F::ZERO; CHUNK];
304 let sibling_proof: Vec<[F; CHUNK]> = {
305 let streams = self.streams.lock().unwrap();
306 let proof_idx = proof_id.as_canonical_u32() as usize;
307 streams.hint_space[proof_idx]
308 .par_chunks(CHUNK)
309 .map(|c| c.try_into().unwrap())
310 .collect()
311 };
312
313 while log_height >= 0 {
314 let incorporate_row = if opened_index < opened_length
315 && memory.unsafe_read_cell(
316 address_space,
317 dim_base_pointer + F::from_canonical_usize(opened_index),
318 ) == F::from_canonical_u32(log_height as u32)
319 {
320 let initial_opened_index = opened_index;
321 for _ in 0..NUM_INITIAL_READS {
322 memory.increment_timestamp();
323 }
324 let mut chunks = vec![];
325
326 let mut row_pointer = 0;
327 let mut row_end = 0;
328
329 let mut prev_rolling_hash: Option<[F; 2 * CHUNK]> = None;
330 let mut rolling_hash = [F::ZERO; 2 * CHUNK];
331
332 let mut is_first_in_segment = true;
333
334 loop {
335 let mut cells = vec![];
336 for chunk_elem in rolling_hash.iter_mut().take(CHUNK) {
337 let read_row_pointer_and_length = if is_first_in_segment
338 || row_pointer == row_end
339 {
340 if is_first_in_segment {
341 is_first_in_segment = false;
342 } else {
343 opened_index += 1;
344 if opened_index == opened_length
345 || memory.unsafe_read_cell(
346 address_space,
347 dim_base_pointer
348 + F::from_canonical_usize(opened_index),
349 ) != F::from_canonical_u32(log_height as u32)
350 {
351 break;
352 }
353 }
354 let (result, [new_row_pointer, row_len]) = memory.read(
355 address_space,
356 opened_base_pointer + F::from_canonical_usize(2 * opened_index),
357 );
358 row_pointer = new_row_pointer.as_canonical_u32() as usize;
359 row_end = row_pointer
360 + (opened_element_size * row_len).as_canonical_u32() as usize;
361 Some(result)
362 } else {
363 memory.increment_timestamp();
364 None
365 };
366 let (read, value) = memory
367 .read_cell(address_space, F::from_canonical_usize(row_pointer));
368 cells.push(CellRecord {
369 read,
370 opened_index,
371 read_row_pointer_and_length,
372 row_pointer,
373 row_end,
374 });
375 *chunk_elem = value;
376 row_pointer += 1;
377 }
378 if cells.is_empty() {
379 break;
380 }
381 let cells_len = cells.len();
382 chunks.push(InsideRowRecord {
383 cells,
384 p2_input: rolling_hash,
385 });
386 self.height += 1;
387 prev_rolling_hash = Some(rolling_hash);
388 self.subchip.permute_mut(&mut rolling_hash);
389 if cells_len < CHUNK {
390 for _ in 0..CHUNK - cells_len {
391 memory.increment_timestamp();
392 memory.increment_timestamp();
393 }
394 break;
395 }
396 }
397 let final_opened_index = opened_index - 1;
398 let (initial_height_read, height_check) = memory.read_cell(
399 address_space,
400 dim_base_pointer + F::from_canonical_usize(initial_opened_index),
401 );
402 assert_eq!(height_check, F::from_canonical_u32(log_height as u32));
403 let (final_height_read, height_check) = memory.read_cell(
404 address_space,
405 dim_base_pointer + F::from_canonical_usize(final_opened_index),
406 );
407 assert_eq!(height_check, F::from_canonical_u32(log_height as u32));
408
409 let hash: [F; CHUNK] = std::array::from_fn(|i| rolling_hash[i]);
410
411 let (p2_input, new_root) = if log_height as u32 == initial_log_height {
412 (prev_rolling_hash.unwrap(), hash)
413 } else {
414 self.compress(root, hash)
415 };
416 root = new_root;
417
418 self.height += 1;
419 Some(IncorporateRowRecord {
420 chunks,
421 initial_opened_index,
422 final_opened_index,
423 initial_height_read,
424 final_height_read,
425 p2_input,
426 })
427 } else {
428 None
429 };
430
431 let incorporate_sibling = if log_height == 0 {
432 None
433 } else {
434 for _ in 0..NUM_INITIAL_READS {
435 memory.increment_timestamp();
436 }
437
438 let (read_sibling_is_on_right, sibling_is_on_right) = memory.read_cell(
439 address_space,
440 index_base_pointer + F::from_canonical_usize(sibling_index),
441 );
442 let sibling_is_on_right = sibling_is_on_right == F::ONE;
443 let sibling = sibling_proof[sibling_index];
444 let (p2_input, new_root) = if sibling_is_on_right {
445 self.compress(sibling, root)
446 } else {
447 self.compress(root, sibling)
448 };
449 root = new_root;
450
451 self.height += 1;
452 Some(IncorporateSiblingRecord {
453 read_sibling_is_on_right,
454 sibling_is_on_right,
455 p2_input,
456 })
457 };
458
459 top_level.push(TopLevelRecord {
460 incorporate_row,
461 incorporate_sibling,
462 });
463
464 log_height -= 1;
465 sibling_index += 1;
466 }
467
468 assert_eq!(commit, root);
469 self.record_set
470 .verify_batch_records
471 .push(VerifyBatchRecord {
472 from_state,
473 instruction: instruction.clone(),
474 dim_base_pointer,
475 opened_base_pointer,
476 opened_length,
477 index_base_pointer,
478 commit_pointer,
479 dim_base_pointer_read,
480 opened_base_pointer_read,
481 opened_length_read,
482 index_base_pointer_read,
483 commit_pointer_read,
484 commit_read,
485 initial_log_height: initial_log_height as usize,
486 top_level,
487 });
488 } else {
489 unreachable!()
490 }
491 Ok(ExecutionState {
492 pc: from_state.pc + DEFAULT_PC_STEP,
493 timestamp: memory.timestamp(),
494 })
495 }
496
497 fn get_opcode_name(&self, opcode: usize) -> String {
498 if opcode == VERIFY_BATCH.global_opcode().as_usize() {
499 String::from("VERIFY_BATCH")
500 } else if opcode == PERM_POS2.global_opcode().as_usize() {
501 String::from("PERM_POS2")
502 } else if opcode == COMP_POS2.global_opcode().as_usize() {
503 String::from("COMP_POS2")
504 } else {
505 unreachable!("unsupported opcode: {}", opcode)
506 }
507 }
508}