// openvm_circuit/system/poseidon2/trace.rs

1use std::{borrow::BorrowMut, sync::Arc};
2
3use openvm_circuit_primitives::utils::next_power_of_two_or_zero;
4use openvm_stark_backend::{
5    config::{StarkGenericConfig, Val},
6    p3_air::BaseAir,
7    p3_field::PrimeCharacteristicRing,
8    p3_matrix::dense::RowMajorMatrix,
9    p3_maybe_rayon::prelude::*,
10    prover::{cpu::CpuBackend, types::AirProvingContext},
11    Chip, ChipUsageGetter,
12};
13
14use super::{columns::*, Poseidon2PeripheryBaseChip, PERIPHERY_POSEIDON2_WIDTH};
15use crate::arch::VmField;
16
17impl<RA, SC: StarkGenericConfig, const SBOX_REGISTERS: usize> Chip<RA, CpuBackend<SC>>
18    for Poseidon2PeripheryBaseChip<Val<SC>, SBOX_REGISTERS>
19where
20    Val<SC>: VmField,
21{
22    /// Generates trace and clears internal records state.
23    fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<CpuBackend<SC>> {
24        let height = next_power_of_two_or_zero(self.current_trace_height());
25        let width = self.trace_width();
26
27        let mut inputs = Vec::with_capacity(height);
28        let mut multiplicities = Vec::with_capacity(height);
29        #[cfg(feature = "parallel")]
30        let records_iter = self.records.par_iter();
31        #[cfg(not(feature = "parallel"))]
32        let records_iter = self.records.iter();
33        let (actual_inputs, actual_multiplicities): (Vec<_>, Vec<_>) = records_iter
34            .map(|r| {
35                let (input, mult) = r.pair();
36                (*input, mult.load(std::sync::atomic::Ordering::Relaxed))
37            })
38            .unzip();
39        inputs.extend(actual_inputs);
40        multiplicities.extend(actual_multiplicities);
41        inputs.resize(height, [Val::<SC>::ZERO; PERIPHERY_POSEIDON2_WIDTH]);
42        multiplicities.resize(height, 0);
43
44        // TODO: this would be more optimal if plonky3 made the generate_trace_row function public
45        let inner_trace = self.subchip.generate_trace(inputs);
46        let inner_width = self.air.subair.width();
47
48        let mut values = Val::<SC>::zero_vec(height * width);
49        values
50            .par_chunks_mut(width)
51            .zip(inner_trace.values.par_chunks(inner_width))
52            .zip(multiplicities)
53            .for_each(|((row, inner_row), mult)| {
54                // WARNING: Poseidon2SubCols must be the first field in Poseidon2PeripheryCols
55                row[..inner_width].copy_from_slice(inner_row);
56                let cols: &mut Poseidon2PeripheryCols<Val<SC>, SBOX_REGISTERS> = row.borrow_mut();
57                cols.mult = Val::<SC>::from_u32(mult);
58            });
59        self.records.clear();
60
61        AirProvingContext::simple_no_pis(Arc::new(RowMajorMatrix::new(values, width)))
62    }
63}
64
65impl<F: VmField, const SBOX_REGISTERS: usize> ChipUsageGetter
66    for Poseidon2PeripheryBaseChip<F, SBOX_REGISTERS>
67{
68    fn air_name(&self) -> String {
69        format!("Poseidon2PeripheryAir<F, {SBOX_REGISTERS}>")
70    }
71
72    fn current_trace_height(&self) -> usize {
73        if self.nonempty.load(std::sync::atomic::Ordering::Relaxed) {
74            // Not to call `DashMap::len` too often
75            self.records.len()
76        } else {
77            0
78        }
79    }
80
81    fn trace_width(&self) -> usize {
82        self.air.width()
83    }
84}