openvm_circuit/system/poseidon2/
trace.rs

1use std::borrow::BorrowMut;
2
3use openvm_circuit_primitives::utils::next_power_of_two_or_zero;
4use openvm_stark_backend::{
5    config::{StarkGenericConfig, Val},
6    p3_air::BaseAir,
7    p3_field::{FieldAlgebra, PrimeField32},
8    p3_matrix::dense::RowMajorMatrix,
9    p3_maybe_rayon::prelude::*,
10    prover::types::AirProofInput,
11    rap::get_air_name,
12    AirRef, Chip, ChipUsageGetter,
13};
14
15use super::{columns::*, Poseidon2PeripheryBaseChip, PERIPHERY_POSEIDON2_WIDTH};
16
17impl<SC: StarkGenericConfig, const SBOX_REGISTERS: usize> Chip<SC>
18    for Poseidon2PeripheryBaseChip<Val<SC>, SBOX_REGISTERS>
19where
20    Val<SC>: PrimeField32,
21{
22    fn air(&self) -> AirRef<SC> {
23        self.air.clone()
24    }
25
26    fn generate_air_proof_input(self) -> AirProofInput<SC> {
27        let height = next_power_of_two_or_zero(self.current_trace_height());
28        let width = self.trace_width();
29
30        let mut inputs = Vec::with_capacity(height);
31        let mut multiplicities = Vec::with_capacity(height);
32        let (actual_inputs, actual_multiplicities): (Vec<_>, Vec<_>) = self
33            .records
34            .into_par_iter()
35            .map(|(input, mult)| (input, mult.load(std::sync::atomic::Ordering::Relaxed)))
36            .unzip();
37        inputs.extend(actual_inputs);
38        multiplicities.extend(actual_multiplicities);
39        inputs.resize(height, [Val::<SC>::ZERO; PERIPHERY_POSEIDON2_WIDTH]);
40        multiplicities.resize(height, 0);
41
42        // TODO: this would be more optimal if plonky3 made the generate_trace_row function public
43        let inner_trace = self.subchip.generate_trace(inputs);
44        let inner_width = self.air.subair.width();
45
46        let mut values = Val::<SC>::zero_vec(height * width);
47        values
48            .par_chunks_mut(width)
49            .zip(inner_trace.values.par_chunks(inner_width))
50            .zip(multiplicities)
51            .for_each(|((row, inner_row), mult)| {
52                // WARNING: Poseidon2SubCols must be the first field in Poseidon2PeripheryCols
53                row[..inner_width].copy_from_slice(inner_row);
54                let cols: &mut Poseidon2PeripheryCols<Val<SC>, SBOX_REGISTERS> = row.borrow_mut();
55                cols.mult = Val::<SC>::from_canonical_u32(mult);
56            });
57
58        AirProofInput::simple_no_pis(RowMajorMatrix::new(values, width))
59    }
60}
61
62impl<F: PrimeField32, const SBOX_REGISTERS: usize> ChipUsageGetter
63    for Poseidon2PeripheryBaseChip<F, SBOX_REGISTERS>
64{
65    fn air_name(&self) -> String {
66        get_air_name(&self.air)
67    }
68
69    fn current_trace_height(&self) -> usize {
70        self.records.len()
71    }
72
73    fn trace_width(&self) -> usize {
74        self.air.width()
75    }
76}