openvm_circuit/system/poseidon2/
trace.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
use std::{borrow::BorrowMut, sync::Arc};

use openvm_circuit_primitives::utils::next_power_of_two_or_zero;
use openvm_stark_backend::{
    config::{StarkGenericConfig, Val},
    p3_air::BaseAir,
    p3_field::{AbstractField, PrimeField32},
    p3_matrix::dense::RowMajorMatrix,
    p3_maybe_rayon::prelude::*,
    prover::types::AirProofInput,
    rap::{get_air_name, AnyRap},
    Chip, ChipUsageGetter,
};

use super::{columns::*, Poseidon2PeripheryBaseChip, PERIPHERY_POSEIDON2_WIDTH};

impl<SC: StarkGenericConfig, const SBOX_REGISTERS: usize> Chip<SC>
    for Poseidon2PeripheryBaseChip<Val<SC>, SBOX_REGISTERS>
where
    Val<SC>: PrimeField32,
{
    /// Returns a shared handle to this chip's AIR.
    fn air(&self) -> Arc<dyn AnyRap<SC>> {
        self.air.clone()
    }

    /// Consumes the chip and builds its trace.
    ///
    /// Each entry of `records` (an input state plus its multiplicity counter) becomes one
    /// row. The first `inner_width` columns hold the inner Poseidon2 sub-trace for that
    /// input. The `mult` column holds the multiplicity. The row count is padded up to the
    /// next power of two with all-zero inputs and multiplicity 0.
    ///
    /// If no records were collected, a zero-row trace of the correct width is returned
    /// without invoking the inner trace generator.
    fn generate_air_proof_input(self) -> AirProofInput<SC> {
        let air = self.air();
        let height = next_power_of_two_or_zero(self.current_trace_height());
        let width = self.trace_width();

        // `height == 0` only when there are no records. Zero is not a power of two.
        // NOTE(review): the inner Poseidon2 trace generator (Plonky3 `generate_trace_rows`)
        // asserts a power-of-two input count, so an empty chip would panic there. Return an
        // empty trace and let the caller decide what to do with it.
        if height == 0 {
            return AirProofInput::simple_no_pis(
                air,
                RowMajorMatrix::new(Vec::<Val<SC>>::new(), width),
            );
        }

        // The chip is consumed by value here. Assuming no other handle to the counters
        // survives, a relaxed load is sufficient to read the final multiplicities.
        let (mut inputs, mut multiplicities): (Vec<_>, Vec<_>) = self
            .records
            .into_par_iter()
            .map(|(input, mult)| (input, mult.load(std::sync::atomic::Ordering::Relaxed)))
            .unzip();
        // Padding rows: zero input state, contributing multiplicity 0.
        inputs.resize(height, [Val::<SC>::ZERO; PERIPHERY_POSEIDON2_WIDTH]);
        multiplicities.resize(height, 0);

        // TODO: this would be more optimal if plonky3 made the generate_trace_row function public
        let inner_trace = self.subchip.generate_trace(inputs);
        let inner_width = self.air.subair.width();

        let mut values = Val::<SC>::zero_vec(height * width);
        values
            .par_chunks_mut(width)
            .zip(inner_trace.values.par_chunks(inner_width))
            .zip(multiplicities)
            .for_each(|((row, inner_row), mult)| {
                // WARNING: Poseidon2SubCols must be the first field in Poseidon2PeripheryCols
                row[..inner_width].copy_from_slice(inner_row);
                let cols: &mut Poseidon2PeripheryCols<Val<SC>, SBOX_REGISTERS> = row.borrow_mut();
                cols.mult = Val::<SC>::from_canonical_u32(mult);
            });

        AirProofInput::simple_no_pis(air, RowMajorMatrix::new(values, width))
    }
}

impl<F: PrimeField32, const SBOX_REGISTERS: usize> ChipUsageGetter
    for Poseidon2PeripheryBaseChip<F, SBOX_REGISTERS>
{
    /// Name of this chip's AIR, as computed by `get_air_name` on the stored `air` handle.
    fn air_name(&self) -> String {
        // Pass the handle exactly as stored: `get_air_name` may derive the name from the
        // argument's type, so the argument type must stay the same.
        let air_handle = &self.air;
        get_air_name(air_handle)
    }

    /// Unpadded trace height: one row per entry in `records`.
    fn current_trace_height(&self) -> usize {
        let records = &self.records;
        records.len()
    }

    /// Trace width, taken from the AIR's column count.
    fn trace_width(&self) -> usize {
        let air_handle = &self.air;
        air_handle.width()
    }
}