openvm_stark_backend/air_builders/debug/check_constraints.rs

1use itertools::izip;
2use p3_air::BaseAir;
3use p3_field::{Field, FieldAlgebra};
4use p3_matrix::{dense::RowMajorMatrixView, stack::VerticalPair, Matrix};
5use p3_maybe_rayon::prelude::*;
6
7use crate::{
8    air_builders::debug::DebugConstraintBuilder,
9    config::{StarkGenericConfig, Val},
10    interaction::{
11        debug::{generate_logical_interactions, LogicalInteractions},
12        RapPhaseSeqKind, SymbolicInteraction,
13    },
14    rap::{PartitionedBaseAir, Rap},
15};
16
17/// Check that all constraints vanish on the subgroup.
18#[allow(clippy::too_many_arguments)]
19pub fn check_constraints<R, SC>(
20    rap: &R,
21    rap_name: &str,
22    preprocessed: &Option<RowMajorMatrixView<Val<SC>>>,
23    partitioned_main: &[RowMajorMatrixView<Val<SC>>],
24    public_values: &[Val<SC>],
25) where
26    R: for<'a> Rap<DebugConstraintBuilder<'a, SC>>
27        + BaseAir<Val<SC>>
28        + PartitionedBaseAir<Val<SC>>
29        + ?Sized,
30    SC: StarkGenericConfig,
31{
32    let height = partitioned_main[0].height();
33    assert!(partitioned_main.iter().all(|mat| mat.height() == height));
34
35    // Check that constraints are satisfied.
36    (0..height).into_par_iter().for_each(|i| {
37        let i_next = (i + 1) % height;
38
39        let (preprocessed_local, preprocessed_next) = preprocessed
40            .as_ref()
41            .map(|preprocessed| {
42                (
43                    preprocessed.row_slice(i).to_vec(),
44                    preprocessed.row_slice(i_next).to_vec(),
45                )
46            })
47            .unwrap_or((vec![], vec![]));
48
49        let partitioned_main_row_pair = partitioned_main
50            .iter()
51            .map(|part| (part.row_slice(i), part.row_slice(i_next)))
52            .collect::<Vec<_>>();
53        let partitioned_main = partitioned_main_row_pair
54            .iter()
55            .map(|(local, next)| {
56                VerticalPair::new(
57                    RowMajorMatrixView::new_row(local),
58                    RowMajorMatrixView::new_row(next),
59                )
60            })
61            .collect::<Vec<_>>();
62
63        let mut builder = DebugConstraintBuilder {
64            air_name: rap_name,
65            row_index: i,
66            preprocessed: VerticalPair::new(
67                RowMajorMatrixView::new_row(preprocessed_local.as_slice()),
68                RowMajorMatrixView::new_row(preprocessed_next.as_slice()),
69            ),
70            partitioned_main,
71            after_challenge: vec![], // unreachable
72            challenges: &[],         // unreachable
73            public_values,
74            exposed_values_after_challenge: &[], // unreachable
75            is_first_row: Val::<SC>::ZERO,
76            is_last_row: Val::<SC>::ZERO,
77            is_transition: Val::<SC>::ONE,
78            rap_phase_seq_kind: RapPhaseSeqKind::FriLogUp, // unused
79            has_common_main: rap.common_main_width() > 0,
80        };
81        if i == 0 {
82            builder.is_first_row = Val::<SC>::ONE;
83        }
84        if i == height - 1 {
85            builder.is_last_row = Val::<SC>::ONE;
86            builder.is_transition = Val::<SC>::ZERO;
87        }
88
89        rap.eval(&mut builder);
90    });
91}
92
93pub fn check_logup<F: Field>(
94    air_names: &[String],
95    interactions: &[Vec<SymbolicInteraction<F>>],
96    preprocessed: &[Option<RowMajorMatrixView<F>>],
97    partitioned_main: &[Vec<RowMajorMatrixView<F>>],
98    public_values: &[Vec<F>],
99) {
100    let mut logical_interactions = LogicalInteractions::<F>::default();
101    for (air_idx, (interactions, preprocessed, partitioned_main, public_values)) in
102        izip!(interactions, preprocessed, partitioned_main, public_values).enumerate()
103    {
104        generate_logical_interactions(
105            air_idx,
106            interactions,
107            preprocessed,
108            partitioned_main,
109            public_values,
110            &mut logical_interactions,
111        );
112    }
113
114    let mut logup_failed = false;
115    // For each bus, check each `fields` key by summing up multiplicities.
116    for (bus_idx, bus_interactions) in logical_interactions.at_bus.into_iter() {
117        for (fields, connections) in bus_interactions.into_iter() {
118            let sum: F = connections.iter().map(|(_, count)| *count).sum();
119            if !sum.is_zero() {
120                logup_failed = true;
121                println!(
122                    "Bus {} failed to balance the multiplicities for fields={:?}. The bus connections for this were:",
123                    bus_idx, fields
124                );
125                for (air_idx, count) in connections {
126                    println!(
127                        "   Air idx: {}, Air name: {}, count: {:?}",
128                        air_idx, air_names[air_idx], count
129                    );
130                }
131            }
132        }
133    }
134    if logup_failed {
135        panic!("LogUp multiset equality check failed.");
136    }
137}