openvm_sha256_air/
columns.rs

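//! WARNING: the order of fields in these structs is significant; do not change it. The structs
//! are `#[repr(C)]`, and `Sha256RoundCols`/`Sha256DigestCols` rely on sharing the same leading
//! fields so that common constraints can read either layout.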
2
3use openvm_circuit_primitives::{utils::not, AlignedBorrow};
4use openvm_stark_backend::p3_field::FieldAlgebra;
5
6use super::{
7    SHA256_HASH_WORDS, SHA256_ROUNDS_PER_ROW, SHA256_ROW_VAR_CNT, SHA256_WORD_BITS,
8    SHA256_WORD_U16S, SHA256_WORD_U8S,
9};
10
11/// In each SHA256 block:
12/// - First 16 rows use Sha256RoundCols
13/// - Final row uses Sha256DigestCols
14///
15/// Note that for soundness, we require that there is always a padding row after the last digest row
16/// in the trace. Right now, this is true because the unpadded height is a multiple of 17, and thus
17/// not a power of 2.
18///
19/// Sha256RoundCols and Sha256DigestCols share the same first 3 fields:
20/// - flags
21/// - work_vars/hash (same type, different name)
22/// - schedule_helper
23///
24/// This design allows for:
25/// 1. Common constraints to work on either struct type by accessing these shared fields
26/// 2. Specific constraints to use the appropriate struct, with flags helping to do conditional
27///    constraints
28///
29/// Note that the `Sha256WorkVarsCols` field it is used for different purposes in the two structs.
30#[repr(C)]
31#[derive(Clone, Copy, Debug, AlignedBorrow)]
32pub struct Sha256RoundCols<T> {
33    pub flags: Sha256FlagsCols<T>,
34    /// Stores the current state of the working variables
35    pub work_vars: Sha256WorkVarsCols<T>,
36    pub schedule_helper: Sha256MessageHelperCols<T>,
37    pub message_schedule: Sha256MessageScheduleCols<T>,
38}
39
40#[repr(C)]
41#[derive(Clone, Copy, Debug, AlignedBorrow)]
42pub struct Sha256DigestCols<T> {
43    pub flags: Sha256FlagsCols<T>,
44    /// Will serve as previous hash values for the next block.
45    ///     - on non-last blocks, this is the final hash of the current block
46    ///     - on last blocks, this is the initial state constants, SHA256_H.
47    /// The work variables constraints are applied on all rows, so `carry_a` and `carry_e`
48    /// must be filled in with dummy values to ensure these constraints hold.
49    pub hash: Sha256WorkVarsCols<T>,
50    pub schedule_helper: Sha256MessageHelperCols<T>,
51    /// The actual final hash values of the given block
52    /// Note: the above `hash` will be equal to `final_hash` unless we are on the last block
53    pub final_hash: [[T; SHA256_WORD_U8S]; SHA256_HASH_WORDS],
54    /// The final hash of the previous block
55    /// Note: will be constrained using interactions with the chip itself
56    pub prev_hash: [[T; SHA256_WORD_U16S]; SHA256_HASH_WORDS],
57}
58
59#[repr(C)]
60#[derive(Clone, Copy, Debug, AlignedBorrow)]
61pub struct Sha256MessageScheduleCols<T> {
62    /// The message schedule words as 32-bit integers
63    /// The first 16 words will be the message data
64    pub w: [[T; SHA256_WORD_BITS]; SHA256_ROUNDS_PER_ROW],
65    /// Will be message schedule carries for rows 4..16 and a buffer for rows 0..4 to be used
66    /// freely by wrapper chips Note: carries are 2 bit numbers represented using 2 cells as
67    /// individual bits
68    pub carry_or_buffer: [[T; SHA256_WORD_U8S]; SHA256_ROUNDS_PER_ROW],
69}
70
71#[repr(C)]
72#[derive(Clone, Copy, Debug, AlignedBorrow)]
73pub struct Sha256WorkVarsCols<T> {
74    /// `a` and `e` after each iteration as 32-bits
75    pub a: [[T; SHA256_WORD_BITS]; SHA256_ROUNDS_PER_ROW],
76    pub e: [[T; SHA256_WORD_BITS]; SHA256_ROUNDS_PER_ROW],
77    /// The carry's used for addition during each iteration when computing `a` and `e`
78    pub carry_a: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW],
79    pub carry_e: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW],
80}
81
82/// These are the columns that are used to help with the message schedule additions
83/// Note: these need to be correctly assigned for every row even on padding rows
84#[repr(C)]
85#[derive(Clone, Copy, Debug, AlignedBorrow)]
86pub struct Sha256MessageHelperCols<T> {
87    /// The following are used to move data forward to constrain the message schedule additions
88    /// The value of `w` (message schedule word) from 3 rounds ago
89    /// In general, `w_i` means `w` from `i` rounds ago
90    pub w_3: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW - 1],
91    /// Here intermediate(i) =  w_i + sig_0(w_{i+1})
92    /// Intermed_t represents the intermediate t rounds ago
93    /// This is needed to constrain the message schedule, since we can only constrain on two rows
94    /// at a time
95    pub intermed_4: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW],
96    pub intermed_8: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW],
97    pub intermed_12: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW],
98}
99
100#[repr(C)]
101#[derive(Clone, Copy, Debug, AlignedBorrow)]
102pub struct Sha256FlagsCols<T> {
103    /// A flag that indicates if the current row is among the first 16 rows of a block.
104    pub is_round_row: T,
105    /// A flag that indicates if the current row is among the first 4 rows of a block.
106    pub is_first_4_rows: T,
107    /// A flag that indicates if the current row is the last (17th) row of a block.
108    pub is_digest_row: T,
109    // A flag that indicates if the current row is the last block of the message.
110    // This flag is only used in digest rows.
111    pub is_last_block: T,
112    /// We will encode the row index [0..17) using 5 cells
113    pub row_idx: [T; SHA256_ROW_VAR_CNT],
114    /// The index of the current block in the trace starting at 1.
115    /// Set to 0 on padding rows.
116    pub global_block_idx: T,
117    /// The index of the current block in the current message starting at 0.
118    /// Resets after every message.
119    /// Set to 0 on padding rows.
120    pub local_block_idx: T,
121}
122
123impl<O, T: Copy + core::ops::Add<Output = O>> Sha256FlagsCols<T> {
124    // This refers to the padding rows that are added to the air to make the trace length a power of
125    // 2. Not to be confused with the padding added to messages as part of the SHA hash
126    // function.
127    pub fn is_not_padding_row(&self) -> O {
128        self.is_round_row + self.is_digest_row
129    }
130
131    // This refers to the padding rows that are added to the air to make the trace length a power of
132    // 2. Not to be confused with the padding added to messages as part of the SHA hash
133    // function.
134    pub fn is_padding_row(&self) -> O
135    where
136        O: FieldAlgebra,
137    {
138        not(self.is_not_padding_row())
139    }
140}