// openvm_native_circuit/poseidon2/chip.rs

1use std::sync::{Arc, Mutex};
2
3use openvm_circuit::{
4    arch::{
5        ExecutionBridge, ExecutionError, ExecutionState, InstructionExecutor, Streams, SystemPort,
6    },
7    system::memory::{MemoryController, OfflineMemory, RecordId},
8};
9use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode};
10use openvm_native_compiler::{
11    conversion::AS,
12    Poseidon2Opcode::{COMP_POS2, PERM_POS2},
13    VerifyBatchOpcode::VERIFY_BATCH,
14};
15use openvm_poseidon2_air::{Poseidon2Config, Poseidon2SubAir, Poseidon2SubChip};
16use openvm_stark_backend::{
17    p3_field::{Field, PrimeField32},
18    p3_maybe_rayon::prelude::{ParallelIterator, ParallelSlice},
19};
20use serde::{Deserialize, Serialize};
21
22use crate::poseidon2::{
23    air::{NativePoseidon2Air, VerifyBatchBus},
24    CHUNK,
25};
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
28#[serde(bound = "F: Field")]
29pub struct VerifyBatchRecord<F: Field> {
30    pub from_state: ExecutionState<u32>,
31    pub instruction: Instruction<F>,
32
33    pub dim_base_pointer: F,
34    pub opened_base_pointer: F,
35    pub opened_length: usize,
36    pub index_base_pointer: F,
37    pub commit_pointer: F,
38
39    pub dim_base_pointer_read: RecordId,
40    pub opened_base_pointer_read: RecordId,
41    pub opened_length_read: RecordId,
42    pub index_base_pointer_read: RecordId,
43    pub commit_pointer_read: RecordId,
44
45    pub commit_read: RecordId,
46    pub initial_log_height: usize,
47    pub top_level: Vec<TopLevelRecord<F>>,
48}
49
50impl<F: PrimeField32> VerifyBatchRecord<F> {
51    pub fn opened_element_size_inv(&self) -> F {
52        self.instruction.g
53    }
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
57#[serde(bound = "F: Field")]
58pub struct TopLevelRecord<F: Field> {
59    // must be present in first record
60    pub incorporate_row: Option<IncorporateRowRecord<F>>,
61    // must be present in all bust last record
62    pub incorporate_sibling: Option<IncorporateSiblingRecord<F>>,
63}
64
65#[repr(C)]
66#[derive(Debug, Clone, Serialize, Deserialize)]
67#[serde(bound = "F: Field")]
68pub struct IncorporateSiblingRecord<F: Field> {
69    pub read_sibling_is_on_right: RecordId,
70    pub sibling_is_on_right: bool,
71    pub p2_input: [F; 2 * CHUNK],
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
75#[serde(bound = "F: Field")]
76pub struct IncorporateRowRecord<F: Field> {
77    pub chunks: Vec<InsideRowRecord<F>>,
78    pub initial_opened_index: usize,
79    pub final_opened_index: usize,
80    pub initial_height_read: RecordId,
81    pub final_height_read: RecordId,
82    pub p2_input: [F; 2 * CHUNK],
83}
84
85#[derive(Debug, Clone, Serialize, Deserialize)]
86#[serde(bound = "F: Field")]
87pub struct InsideRowRecord<F: Field> {
88    pub cells: Vec<CellRecord>,
89    pub p2_input: [F; 2 * CHUNK],
90}
91
92#[repr(C)]
93#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct CellRecord {
95    pub read: RecordId,
96    pub opened_index: usize,
97    pub read_row_pointer_and_length: Option<RecordId>,
98    pub row_pointer: usize,
99    pub row_end: usize,
100}
101
102#[repr(C)]
103#[derive(Debug, Clone, Serialize, Deserialize)]
104#[serde(bound = "F: Field")]
105pub struct SimplePoseidonRecord<F: Field> {
106    pub from_state: ExecutionState<u32>,
107    pub instruction: Instruction<F>,
108
109    pub read_input_pointer_1: RecordId,
110    pub read_input_pointer_2: Option<RecordId>,
111    pub read_output_pointer: RecordId,
112    pub read_data_1: RecordId,
113    pub read_data_2: RecordId,
114    pub write_data_1: RecordId,
115    pub write_data_2: Option<RecordId>,
116
117    pub input_pointer_1: F,
118    pub input_pointer_2: F,
119    pub output_pointer: F,
120    pub p2_input: [F; 2 * CHUNK],
121}
122
123#[derive(Debug, Clone, Serialize, Deserialize, Default)]
124#[serde(bound = "F: Field")]
125pub struct NativePoseidon2RecordSet<F: Field> {
126    pub verify_batch_records: Vec<VerifyBatchRecord<F>>,
127    pub simple_permute_records: Vec<SimplePoseidonRecord<F>>,
128}
129
130pub struct NativePoseidon2Chip<F: Field, const SBOX_REGISTERS: usize> {
131    pub(super) air: NativePoseidon2Air<F, SBOX_REGISTERS>,
132    pub record_set: NativePoseidon2RecordSet<F>,
133    pub height: usize,
134    pub(super) offline_memory: Arc<Mutex<OfflineMemory<F>>>,
135    pub(super) subchip: Poseidon2SubChip<F, SBOX_REGISTERS>,
136    pub(super) streams: Arc<Mutex<Streams<F>>>,
137}
138
139impl<F: PrimeField32, const SBOX_REGISTERS: usize> NativePoseidon2Chip<F, SBOX_REGISTERS> {
140    pub fn new(
141        port: SystemPort,
142        offline_memory: Arc<Mutex<OfflineMemory<F>>>,
143        poseidon2_config: Poseidon2Config<F>,
144        verify_batch_bus: VerifyBatchBus,
145        streams: Arc<Mutex<Streams<F>>>,
146    ) -> Self {
147        let air = NativePoseidon2Air {
148            execution_bridge: ExecutionBridge::new(port.execution_bus, port.program_bus),
149            memory_bridge: port.memory_bridge,
150            internal_bus: verify_batch_bus,
151            subair: Arc::new(Poseidon2SubAir::new(poseidon2_config.constants.into())),
152            address_space: F::from_canonical_u32(AS::Native as u32),
153        };
154        Self {
155            record_set: Default::default(),
156            air,
157            height: 0,
158            offline_memory,
159            subchip: Poseidon2SubChip::new(poseidon2_config.constants),
160            streams,
161        }
162    }
163
164    fn compress(&self, left: [F; CHUNK], right: [F; CHUNK]) -> ([F; 2 * CHUNK], [F; CHUNK]) {
165        let concatenated =
166            std::array::from_fn(|i| if i < CHUNK { left[i] } else { right[i - CHUNK] });
167        let permuted = self.subchip.permute(concatenated);
168        (concatenated, std::array::from_fn(|i| permuted[i]))
169    }
170}
171
172pub(super) const NUM_INITIAL_READS: usize = 6;
173pub(super) const NUM_SIMPLE_ACCESSES: u32 = 7;
174
175impl<F: PrimeField32, const SBOX_REGISTERS: usize> InstructionExecutor<F>
176    for NativePoseidon2Chip<F, SBOX_REGISTERS>
177{
178    fn execute(
179        &mut self,
180        memory: &mut MemoryController<F>,
181        instruction: &Instruction<F>,
182        from_state: ExecutionState<u32>,
183    ) -> Result<ExecutionState<u32>, ExecutionError> {
184        if instruction.opcode == PERM_POS2.global_opcode()
185            || instruction.opcode == COMP_POS2.global_opcode()
186        {
187            let &Instruction {
188                a: output_register,
189                b: input_register_1,
190                c: input_register_2,
191                d: register_address_space,
192                e: data_address_space,
193                ..
194            } = instruction;
195
196            let (read_output_pointer, output_pointer) =
197                memory.read_cell(register_address_space, output_register);
198            let (read_input_pointer_1, input_pointer_1) =
199                memory.read_cell(register_address_space, input_register_1);
200            let (read_input_pointer_2, input_pointer_2) =
201                if instruction.opcode == PERM_POS2.global_opcode() {
202                    memory.increment_timestamp();
203                    (None, input_pointer_1 + F::from_canonical_usize(CHUNK))
204                } else {
205                    let (read_input_pointer_2, input_pointer_2) =
206                        memory.read_cell(register_address_space, input_register_2);
207                    (Some(read_input_pointer_2), input_pointer_2)
208                };
209            let (read_data_1, data_1) = memory.read::<CHUNK>(data_address_space, input_pointer_1);
210            let (read_data_2, data_2) = memory.read::<CHUNK>(data_address_space, input_pointer_2);
211            let p2_input = std::array::from_fn(|i| {
212                if i < CHUNK {
213                    data_1[i]
214                } else {
215                    data_2[i - CHUNK]
216                }
217            });
218            let output = self.subchip.permute(p2_input);
219            let (write_data_1, _) = memory.write::<CHUNK>(
220                data_address_space,
221                output_pointer,
222                std::array::from_fn(|i| output[i]),
223            );
224            let write_data_2 = if instruction.opcode == PERM_POS2.global_opcode() {
225                Some(
226                    memory
227                        .write::<CHUNK>(
228                            data_address_space,
229                            output_pointer + F::from_canonical_usize(CHUNK),
230                            std::array::from_fn(|i| output[CHUNK + i]),
231                        )
232                        .0,
233                )
234            } else {
235                memory.increment_timestamp();
236                None
237            };
238
239            assert_eq!(
240                memory.timestamp(),
241                from_state.timestamp + NUM_SIMPLE_ACCESSES
242            );
243
244            self.record_set
245                .simple_permute_records
246                .push(SimplePoseidonRecord {
247                    from_state,
248                    instruction: instruction.clone(),
249                    read_input_pointer_1,
250                    read_input_pointer_2,
251                    read_output_pointer,
252                    read_data_1,
253                    read_data_2,
254                    write_data_1,
255                    write_data_2,
256                    input_pointer_1,
257                    input_pointer_2,
258                    output_pointer,
259                    p2_input,
260                });
261            self.height += 1;
262        } else if instruction.opcode == VERIFY_BATCH.global_opcode() {
263            let &Instruction {
264                a: dim_register,
265                b: opened_register,
266                c: opened_length_register,
267                d: proof_id_ptr,
268                e: index_register,
269                f: commit_register,
270                g: opened_element_size_inv,
271                ..
272            } = instruction;
273            let address_space = self.air.address_space;
274            // calc inverse fast assuming opened_element_size in {1, 4}
275            let mut opened_element_size = F::ONE;
276            while opened_element_size * opened_element_size_inv != F::ONE {
277                opened_element_size += F::ONE;
278            }
279
280            let proof_id = memory.unsafe_read_cell(address_space, proof_id_ptr);
281            let (dim_base_pointer_read, dim_base_pointer) =
282                memory.read_cell(address_space, dim_register);
283            let (opened_base_pointer_read, opened_base_pointer) =
284                memory.read_cell(address_space, opened_register);
285            let (opened_length_read, opened_length) =
286                memory.read_cell(address_space, opened_length_register);
287            let (index_base_pointer_read, index_base_pointer) =
288                memory.read_cell(address_space, index_register);
289            let (commit_pointer_read, commit_pointer) =
290                memory.read_cell(address_space, commit_register);
291            let (commit_read, commit) = memory.read(address_space, commit_pointer);
292
293            let opened_length = opened_length.as_canonical_u32() as usize;
294
295            let initial_log_height = memory
296                .unsafe_read_cell(address_space, dim_base_pointer)
297                .as_canonical_u32();
298            let mut log_height = initial_log_height as i32;
299            let mut sibling_index = 0;
300            let mut opened_index = 0;
301            let mut top_level = vec![];
302
303            let mut root = [F::ZERO; CHUNK];
304            let sibling_proof: Vec<[F; CHUNK]> = {
305                let streams = self.streams.lock().unwrap();
306                let proof_idx = proof_id.as_canonical_u32() as usize;
307                streams.hint_space[proof_idx]
308                    .par_chunks(CHUNK)
309                    .map(|c| c.try_into().unwrap())
310                    .collect()
311            };
312
313            while log_height >= 0 {
314                let incorporate_row = if opened_index < opened_length
315                    && memory.unsafe_read_cell(
316                        address_space,
317                        dim_base_pointer + F::from_canonical_usize(opened_index),
318                    ) == F::from_canonical_u32(log_height as u32)
319                {
320                    let initial_opened_index = opened_index;
321                    for _ in 0..NUM_INITIAL_READS {
322                        memory.increment_timestamp();
323                    }
324                    let mut chunks = vec![];
325
326                    let mut row_pointer = 0;
327                    let mut row_end = 0;
328
329                    let mut prev_rolling_hash: Option<[F; 2 * CHUNK]> = None;
330                    let mut rolling_hash = [F::ZERO; 2 * CHUNK];
331
332                    let mut is_first_in_segment = true;
333
334                    loop {
335                        let mut cells = vec![];
336                        for chunk_elem in rolling_hash.iter_mut().take(CHUNK) {
337                            let read_row_pointer_and_length = if is_first_in_segment
338                                || row_pointer == row_end
339                            {
340                                if is_first_in_segment {
341                                    is_first_in_segment = false;
342                                } else {
343                                    opened_index += 1;
344                                    if opened_index == opened_length
345                                        || memory.unsafe_read_cell(
346                                            address_space,
347                                            dim_base_pointer
348                                                + F::from_canonical_usize(opened_index),
349                                        ) != F::from_canonical_u32(log_height as u32)
350                                    {
351                                        break;
352                                    }
353                                }
354                                let (result, [new_row_pointer, row_len]) = memory.read(
355                                    address_space,
356                                    opened_base_pointer + F::from_canonical_usize(2 * opened_index),
357                                );
358                                row_pointer = new_row_pointer.as_canonical_u32() as usize;
359                                row_end = row_pointer
360                                    + (opened_element_size * row_len).as_canonical_u32() as usize;
361                                Some(result)
362                            } else {
363                                memory.increment_timestamp();
364                                None
365                            };
366                            let (read, value) = memory
367                                .read_cell(address_space, F::from_canonical_usize(row_pointer));
368                            cells.push(CellRecord {
369                                read,
370                                opened_index,
371                                read_row_pointer_and_length,
372                                row_pointer,
373                                row_end,
374                            });
375                            *chunk_elem = value;
376                            row_pointer += 1;
377                        }
378                        if cells.is_empty() {
379                            break;
380                        }
381                        let cells_len = cells.len();
382                        chunks.push(InsideRowRecord {
383                            cells,
384                            p2_input: rolling_hash,
385                        });
386                        self.height += 1;
387                        prev_rolling_hash = Some(rolling_hash);
388                        self.subchip.permute_mut(&mut rolling_hash);
389                        if cells_len < CHUNK {
390                            for _ in 0..CHUNK - cells_len {
391                                memory.increment_timestamp();
392                                memory.increment_timestamp();
393                            }
394                            break;
395                        }
396                    }
397                    let final_opened_index = opened_index - 1;
398                    let (initial_height_read, height_check) = memory.read_cell(
399                        address_space,
400                        dim_base_pointer + F::from_canonical_usize(initial_opened_index),
401                    );
402                    assert_eq!(height_check, F::from_canonical_u32(log_height as u32));
403                    let (final_height_read, height_check) = memory.read_cell(
404                        address_space,
405                        dim_base_pointer + F::from_canonical_usize(final_opened_index),
406                    );
407                    assert_eq!(height_check, F::from_canonical_u32(log_height as u32));
408
409                    let hash: [F; CHUNK] = std::array::from_fn(|i| rolling_hash[i]);
410
411                    let (p2_input, new_root) = if log_height as u32 == initial_log_height {
412                        (prev_rolling_hash.unwrap(), hash)
413                    } else {
414                        self.compress(root, hash)
415                    };
416                    root = new_root;
417
418                    self.height += 1;
419                    Some(IncorporateRowRecord {
420                        chunks,
421                        initial_opened_index,
422                        final_opened_index,
423                        initial_height_read,
424                        final_height_read,
425                        p2_input,
426                    })
427                } else {
428                    None
429                };
430
431                let incorporate_sibling = if log_height == 0 {
432                    None
433                } else {
434                    for _ in 0..NUM_INITIAL_READS {
435                        memory.increment_timestamp();
436                    }
437
438                    let (read_sibling_is_on_right, sibling_is_on_right) = memory.read_cell(
439                        address_space,
440                        index_base_pointer + F::from_canonical_usize(sibling_index),
441                    );
442                    let sibling_is_on_right = sibling_is_on_right == F::ONE;
443                    let sibling = sibling_proof[sibling_index];
444                    let (p2_input, new_root) = if sibling_is_on_right {
445                        self.compress(sibling, root)
446                    } else {
447                        self.compress(root, sibling)
448                    };
449                    root = new_root;
450
451                    self.height += 1;
452                    Some(IncorporateSiblingRecord {
453                        read_sibling_is_on_right,
454                        sibling_is_on_right,
455                        p2_input,
456                    })
457                };
458
459                top_level.push(TopLevelRecord {
460                    incorporate_row,
461                    incorporate_sibling,
462                });
463
464                log_height -= 1;
465                sibling_index += 1;
466            }
467
468            assert_eq!(commit, root);
469            self.record_set
470                .verify_batch_records
471                .push(VerifyBatchRecord {
472                    from_state,
473                    instruction: instruction.clone(),
474                    dim_base_pointer,
475                    opened_base_pointer,
476                    opened_length,
477                    index_base_pointer,
478                    commit_pointer,
479                    dim_base_pointer_read,
480                    opened_base_pointer_read,
481                    opened_length_read,
482                    index_base_pointer_read,
483                    commit_pointer_read,
484                    commit_read,
485                    initial_log_height: initial_log_height as usize,
486                    top_level,
487                });
488        } else {
489            unreachable!()
490        }
491        Ok(ExecutionState {
492            pc: from_state.pc + DEFAULT_PC_STEP,
493            timestamp: memory.timestamp(),
494        })
495    }
496
497    fn get_opcode_name(&self, opcode: usize) -> String {
498        if opcode == VERIFY_BATCH.global_opcode().as_usize() {
499            String::from("VERIFY_BATCH")
500        } else if opcode == PERM_POS2.global_opcode().as_usize() {
501            String::from("PERM_POS2")
502        } else if opcode == COMP_POS2.global_opcode().as_usize() {
503            String::from("COMP_POS2")
504        } else {
505            unreachable!("unsupported opcode: {}", opcode)
506        }
507    }
508}