openvm_circuit/system/poseidon2/trace.rs

1use std::{borrow::BorrowMut, sync::Arc};
2
3use openvm_circuit_primitives::utils::next_power_of_two_or_zero;
4use openvm_stark_backend::{
5    config::{StarkGenericConfig, Val},
6    p3_air::BaseAir,
7    p3_field::{FieldAlgebra, PrimeField32},
8    p3_matrix::dense::RowMajorMatrix,
9    p3_maybe_rayon::prelude::*,
10    prover::{cpu::CpuBackend, types::AirProvingContext},
11    Chip, ChipUsageGetter,
12};
13
14use super::{columns::*, Poseidon2PeripheryBaseChip, PERIPHERY_POSEIDON2_WIDTH};
15
16impl<RA, SC: StarkGenericConfig, const SBOX_REGISTERS: usize> Chip<RA, CpuBackend<SC>>
17    for Poseidon2PeripheryBaseChip<Val<SC>, SBOX_REGISTERS>
18where
19    Val<SC>: PrimeField32,
20{
21    /// Generates trace and clears internal records state.
22    fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<CpuBackend<SC>> {
23        let height = next_power_of_two_or_zero(self.current_trace_height());
24        let width = self.trace_width();
25
26        let mut inputs = Vec::with_capacity(height);
27        let mut multiplicities = Vec::with_capacity(height);
28        #[cfg(feature = "parallel")]
29        let records_iter = self.records.par_iter();
30        #[cfg(not(feature = "parallel"))]
31        let records_iter = self.records.iter();
32        let (actual_inputs, actual_multiplicities): (Vec<_>, Vec<_>) = records_iter
33            .map(|r| {
34                let (input, mult) = r.pair();
35                (*input, mult.load(std::sync::atomic::Ordering::Relaxed))
36            })
37            .unzip();
38        inputs.extend(actual_inputs);
39        multiplicities.extend(actual_multiplicities);
40        inputs.resize(height, [Val::<SC>::ZERO; PERIPHERY_POSEIDON2_WIDTH]);
41        multiplicities.resize(height, 0);
42
43        // TODO: this would be more optimal if plonky3 made the generate_trace_row function public
44        let inner_trace = self.subchip.generate_trace(inputs);
45        let inner_width = self.air.subair.width();
46
47        let mut values = Val::<SC>::zero_vec(height * width);
48        values
49            .par_chunks_mut(width)
50            .zip(inner_trace.values.par_chunks(inner_width))
51            .zip(multiplicities)
52            .for_each(|((row, inner_row), mult)| {
53                // WARNING: Poseidon2SubCols must be the first field in Poseidon2PeripheryCols
54                row[..inner_width].copy_from_slice(inner_row);
55                let cols: &mut Poseidon2PeripheryCols<Val<SC>, SBOX_REGISTERS> = row.borrow_mut();
56                cols.mult = Val::<SC>::from_canonical_u32(mult);
57            });
58        self.records.clear();
59
60        AirProvingContext::simple_no_pis(Arc::new(RowMajorMatrix::new(values, width)))
61    }
62}
63
64impl<F: PrimeField32, const SBOX_REGISTERS: usize> ChipUsageGetter
65    for Poseidon2PeripheryBaseChip<F, SBOX_REGISTERS>
66{
67    fn air_name(&self) -> String {
68        format!("Poseidon2PeripheryAir<F, {}>", SBOX_REGISTERS)
69    }
70
71    fn current_trace_height(&self) -> usize {
72        if self.nonempty.load(std::sync::atomic::Ordering::Relaxed) {
73            // Not to call `DashMap::len` too often
74            self.records.len()
75        } else {
76            0
77        }
78    }
79
80    fn trace_width(&self) -> usize {
81        self.air.width()
82    }
83}