openvm_algebra_circuit/
fields.rs

1use halo2curves_axiom::ff::PrimeField;
2use num_bigint::BigUint;
3use num_traits::Num;
4
/// Tag identifying one of the prime fields this module can dispatch on.
///
/// Each variant is cast to `u8` and compared against the `FIELD` const generic
/// in [`field_operation`] and [`fp2_operation`], so the explicit discriminants
/// are the values callers pass there.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FieldType {
    /// Returned by [`get_field_type`] for the modulus of `halo2curves_axiom::secq256k1::Fq`.
    K256Coordinate = 0,
    /// Returned by [`get_field_type`] for the modulus of `halo2curves_axiom::secq256k1::Fp`.
    K256Scalar = 1,
    /// Returned by [`get_field_type`] for the modulus of `halo2curves_axiom::secp256r1::Fp`.
    P256Coordinate = 2,
    /// Returned by [`get_field_type`] for the modulus of `halo2curves_axiom::secp256r1::Fq`.
    P256Scalar = 3,
    /// Returned by [`get_field_type`] for the modulus of `halo2curves_axiom::bn256::Fq`.
    BN254Coordinate = 4,
    /// Returned by [`get_field_type`] for the modulus of `halo2curves_axiom::bn256::Fr`.
    BN254Scalar = 5,
    /// Returned by [`get_field_type`] for the modulus of `halo2curves_axiom::bls12_381::Fq`.
    BLS12_381Coordinate = 6,
    /// Returned by [`get_field_type`] for the modulus of `halo2curves_axiom::bls12_381::Fr`.
    BLS12_381Scalar = 7,
}
16
/// Arithmetic operation selector.
///
/// Each variant is cast to `u8` and compared against the `OP` const generic of
/// the `*_operation*` functions in this file.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Operation {
    /// Selects `a + b`.
    Add = 0,
    /// Selects `a - b`.
    Sub = 1,
    /// Selects `a * b`.
    Mul = 2,
    /// Selects `a * b.invert()`.
    Div = 3,
}
24
25fn get_modulus_as_bigint<F: PrimeField>() -> BigUint {
26    BigUint::from_str_radix(F::MODULUS.trim_start_matches("0x"), 16).unwrap()
27}
28
29pub fn get_field_type(modulus: &BigUint) -> Option<FieldType> {
30    if modulus == &get_modulus_as_bigint::<halo2curves_axiom::secq256k1::Fq>() {
31        return Some(FieldType::K256Coordinate);
32    }
33
34    if modulus == &get_modulus_as_bigint::<halo2curves_axiom::secq256k1::Fp>() {
35        return Some(FieldType::K256Scalar);
36    }
37
38    if modulus == &get_modulus_as_bigint::<halo2curves_axiom::secp256r1::Fp>() {
39        return Some(FieldType::P256Coordinate);
40    }
41
42    if modulus == &get_modulus_as_bigint::<halo2curves_axiom::secp256r1::Fq>() {
43        return Some(FieldType::P256Scalar);
44    }
45
46    if modulus == &get_modulus_as_bigint::<halo2curves_axiom::bn256::Fq>() {
47        return Some(FieldType::BN254Coordinate);
48    }
49
50    if modulus == &get_modulus_as_bigint::<halo2curves_axiom::bn256::Fr>() {
51        return Some(FieldType::BN254Scalar);
52    }
53
54    if modulus == &get_modulus_as_bigint::<halo2curves_axiom::bls12_381::Fq>() {
55        return Some(FieldType::BLS12_381Coordinate);
56    }
57
58    if modulus == &get_modulus_as_bigint::<halo2curves_axiom::bls12_381::Fr>() {
59        return Some(FieldType::BLS12_381Scalar);
60    }
61
62    None
63}
64
65pub fn get_fp2_field_type(modulus: &BigUint) -> Option<FieldType> {
66    if modulus == &get_modulus_as_bigint::<halo2curves_axiom::bn256::Fq>() {
67        return Some(FieldType::BN254Coordinate);
68    }
69
70    if modulus == &get_modulus_as_bigint::<halo2curves_axiom::bls12_381::Fq>() {
71        return Some(FieldType::BLS12_381Coordinate);
72    }
73
74    None
75}
76
77#[inline(always)]
78pub fn field_operation<
79    const FIELD: u8,
80    const BLOCKS: usize,
81    const BLOCK_SIZE: usize,
82    const OP: u8,
83>(
84    input_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2],
85) -> [[u8; BLOCK_SIZE]; BLOCKS] {
86    match FIELD {
87        x if x == FieldType::K256Coordinate as u8 => {
88            field_operation_256bit::<halo2curves_axiom::secq256k1::Fq, BLOCKS, BLOCK_SIZE, OP>(
89                input_data,
90            )
91        }
92        x if x == FieldType::K256Scalar as u8 => {
93            field_operation_256bit::<halo2curves_axiom::secq256k1::Fp, BLOCKS, BLOCK_SIZE, OP>(
94                input_data,
95            )
96        }
97        x if x == FieldType::P256Coordinate as u8 => {
98            field_operation_256bit::<halo2curves_axiom::secp256r1::Fp, BLOCKS, BLOCK_SIZE, OP>(
99                input_data,
100            )
101        }
102        x if x == FieldType::P256Scalar as u8 => {
103            field_operation_256bit::<halo2curves_axiom::secp256r1::Fq, BLOCKS, BLOCK_SIZE, OP>(
104                input_data,
105            )
106        }
107        x if x == FieldType::BN254Coordinate as u8 => {
108            field_operation_256bit::<halo2curves_axiom::bn256::Fq, BLOCKS, BLOCK_SIZE, OP>(
109                input_data,
110            )
111        }
112        x if x == FieldType::BN254Scalar as u8 => {
113            field_operation_256bit::<halo2curves_axiom::bn256::Fr, BLOCKS, BLOCK_SIZE, OP>(
114                input_data,
115            )
116        }
117        x if x == FieldType::BLS12_381Coordinate as u8 => {
118            field_operation_bls12_381_coordinate::<BLOCKS, BLOCK_SIZE, OP>(input_data)
119        }
120        x if x == FieldType::BLS12_381Scalar as u8 => {
121            field_operation_256bit::<halo2curves_axiom::bls12_381::Fr, BLOCKS, BLOCK_SIZE, OP>(
122                input_data,
123            )
124        }
125        _ => panic!("Unsupported field type: {}", FIELD),
126    }
127}
128
129#[inline(always)]
130pub fn fp2_operation<
131    const FIELD: u8,
132    const BLOCKS: usize,
133    const BLOCK_SIZE: usize,
134    const OP: u8,
135>(
136    input_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2],
137) -> [[u8; BLOCK_SIZE]; BLOCKS] {
138    match FIELD {
139        x if x == FieldType::BN254Coordinate as u8 => {
140            fp2_operation_bn254::<BLOCKS, BLOCK_SIZE, OP>(input_data)
141        }
142        x if x == FieldType::BLS12_381Coordinate as u8 => {
143            fp2_operation_bls12_381::<BLOCKS, BLOCK_SIZE, OP>(input_data)
144        }
145        _ => panic!("Unsupported field type for Fp2: {}", FIELD),
146    }
147}
148
149#[inline(always)]
150fn field_operation_256bit<
151    F: PrimeField<Repr = [u8; 32]>,
152    const BLOCKS: usize,
153    const BLOCK_SIZE: usize,
154    const OP: u8,
155>(
156    input_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2],
157) -> [[u8; BLOCK_SIZE]; BLOCKS] {
158    let a = blocks_to_field_element::<F>(input_data[0].as_flattened());
159    let b = blocks_to_field_element::<F>(input_data[1].as_flattened());
160    let c = match OP {
161        x if x == Operation::Add as u8 => a + b,
162        x if x == Operation::Sub as u8 => a - b,
163        x if x == Operation::Mul as u8 => a * b,
164        x if x == Operation::Div as u8 => a * b.invert().unwrap(),
165        _ => panic!("Unsupported operation: {}", OP),
166    };
167
168    let mut output = [[0u8; BLOCK_SIZE]; BLOCKS];
169    field_element_to_blocks(&c, &mut output);
170    output
171}
172
173#[inline(always)]
174fn field_operation_bls12_381_coordinate<
175    const BLOCKS: usize,
176    const BLOCK_SIZE: usize,
177    const OP: u8,
178>(
179    input_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2],
180) -> [[u8; BLOCK_SIZE]; BLOCKS] {
181    let a = blocks_to_field_element_bls12_381_coordinate(input_data[0].as_flattened());
182    let b = blocks_to_field_element_bls12_381_coordinate(input_data[1].as_flattened());
183    let c = match OP {
184        x if x == Operation::Add as u8 => a + b,
185        x if x == Operation::Sub as u8 => a - b,
186        x if x == Operation::Mul as u8 => a * b,
187        x if x == Operation::Div as u8 => a * b.invert().unwrap(),
188        _ => panic!("Unsupported operation: {}", OP),
189    };
190
191    let mut output = [[0u8; BLOCK_SIZE]; BLOCKS];
192    field_element_to_blocks_bls12_381_coordinate(&c, &mut output);
193    output
194}
195
196#[inline(always)]
197fn fp2_operation_bn254<const BLOCKS: usize, const BLOCK_SIZE: usize, const OP: u8>(
198    input_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2],
199) -> [[u8; BLOCK_SIZE]; BLOCKS] {
200    let a = blocks_to_fp2_bn254::<BLOCKS, BLOCK_SIZE>(input_data[0].as_ref());
201    let b = blocks_to_fp2_bn254::<BLOCKS, BLOCK_SIZE>(input_data[1].as_ref());
202    let c = match OP {
203        x if x == Operation::Add as u8 => a + b,
204        x if x == Operation::Sub as u8 => a - b,
205        x if x == Operation::Mul as u8 => a * b,
206        x if x == Operation::Div as u8 => a * b.invert().unwrap(),
207        _ => panic!("Unsupported operation: {}", OP),
208    };
209
210    let mut output = [[0u8; BLOCK_SIZE]; BLOCKS];
211    fp2_to_blocks_bn254(&c, &mut output);
212    output
213}
214
215#[inline(always)]
216fn fp2_operation_bls12_381<const BLOCKS: usize, const BLOCK_SIZE: usize, const OP: u8>(
217    input_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2],
218) -> [[u8; BLOCK_SIZE]; BLOCKS] {
219    let a = blocks_to_fp2_bls12_381::<BLOCKS, BLOCK_SIZE>(input_data[0].as_ref());
220    let b = blocks_to_fp2_bls12_381::<BLOCKS, BLOCK_SIZE>(input_data[1].as_ref());
221    let c = match OP {
222        x if x == Operation::Add as u8 => a + b,
223        x if x == Operation::Sub as u8 => a - b,
224        x if x == Operation::Mul as u8 => a * b,
225        x if x == Operation::Div as u8 => a * b.invert().unwrap(),
226        _ => panic!("Unsupported operation: {}", OP),
227    };
228
229    let mut output = [[0u8; BLOCK_SIZE]; BLOCKS];
230    fp2_to_blocks_bls12_381(&c, &mut output);
231    output
232}
233
234#[inline(always)]
235fn from_repr_with_reduction<F: PrimeField<Repr = [u8; 32]>>(bytes: [u8; 32]) -> F {
236    F::from_repr_vartime(bytes).unwrap_or_else(|| {
237        // Reduce modulo the field's modulus for non-canonical representations
238        let modulus = get_modulus_as_bigint::<F>();
239        let value = BigUint::from_bytes_le(&bytes);
240        let reduced = value % modulus;
241
242        let reduced_le_bytes = reduced.to_bytes_le();
243        let mut reduced_bytes = [0u8; 32];
244        reduced_bytes[..reduced_le_bytes.len()]
245            .copy_from_slice(&reduced_le_bytes[..reduced_le_bytes.len()]);
246
247        F::from_repr_vartime(reduced_bytes).unwrap()
248    })
249}
250
251#[inline(always)]
252fn from_repr_with_reduction_bls12_381_coordinate(
253    bytes: [u8; 48],
254) -> halo2curves_axiom::bls12_381::Fq {
255    halo2curves_axiom::bls12_381::Fq::from_bytes(&bytes).unwrap_or_else(|| {
256        // Reduce modulo the field's modulus for non-canonical representations
257        let modulus = get_modulus_as_bigint::<halo2curves_axiom::bls12_381::Fq>();
258        let value = BigUint::from_bytes_le(&bytes);
259        let reduced = value % modulus;
260
261        let reduced_le_bytes = reduced.to_bytes_le();
262        let mut reduced_bytes = [0u8; 48];
263        reduced_bytes[..reduced_le_bytes.len()]
264            .copy_from_slice(&reduced_le_bytes[..reduced_le_bytes.len()]);
265
266        halo2curves_axiom::bls12_381::Fq::from_bytes(&reduced_bytes).unwrap()
267    })
268}
269
270#[inline(always)]
271pub fn blocks_to_field_element<F: PrimeField<Repr = [u8; 32]>>(blocks: &[u8]) -> F {
272    debug_assert!(blocks.len() == 32);
273    let mut bytes = [0u8; 32];
274    bytes[..blocks.len()].copy_from_slice(&blocks[..blocks.len()]);
275
276    from_repr_with_reduction::<F>(bytes)
277}
278
279#[inline(always)]
280pub fn field_element_to_blocks<F: PrimeField<Repr = [u8; 32]>, const BLOCK_SIZE: usize>(
281    field_element: &F,
282    output: &mut [[u8; BLOCK_SIZE]],
283) {
284    debug_assert!(output.len() * BLOCK_SIZE == 32);
285    let bytes = field_element.to_repr();
286    let mut byte_idx = 0;
287
288    for block in output.iter_mut() {
289        for byte in block.iter_mut() {
290            *byte = if byte_idx < bytes.len() {
291                bytes[byte_idx]
292            } else {
293                0
294            };
295            byte_idx += 1;
296        }
297    }
298}
299
300#[inline(always)]
301pub fn blocks_to_field_element_bls12_381_coordinate(
302    blocks: &[u8],
303) -> halo2curves_axiom::bls12_381::Fq {
304    debug_assert!(blocks.len() == 48);
305    let mut bytes = [0u8; 48];
306    bytes[..blocks.len()].copy_from_slice(&blocks[..blocks.len()]);
307
308    from_repr_with_reduction_bls12_381_coordinate(bytes)
309}
310
311#[inline(always)]
312pub fn field_element_to_blocks_bls12_381_coordinate<const BLOCK_SIZE: usize>(
313    field_element: &halo2curves_axiom::bls12_381::Fq,
314    output: &mut [[u8; BLOCK_SIZE]],
315) {
316    debug_assert!(output.len() * BLOCK_SIZE == 48);
317    let bytes = field_element.to_bytes();
318    let mut byte_idx = 0;
319
320    for block in output.iter_mut() {
321        for byte in block.iter_mut() {
322            *byte = if byte_idx < bytes.len() {
323                bytes[byte_idx]
324            } else {
325                0
326            };
327            byte_idx += 1;
328        }
329    }
330}
331
332#[inline(always)]
333fn blocks_to_fp2_bn254<const BLOCKS: usize, const BLOCK_SIZE: usize>(
334    blocks: &[[u8; BLOCK_SIZE]],
335) -> halo2curves_axiom::bn256::Fq2 {
336    let c0 = blocks_to_field_element::<halo2curves_axiom::bn256::Fq>(
337        blocks[..BLOCKS / 2].as_flattened(),
338    );
339    let c1 = blocks_to_field_element::<halo2curves_axiom::bn256::Fq>(
340        blocks[BLOCKS / 2..].as_flattened(),
341    );
342    halo2curves_axiom::bn256::Fq2::new(c0, c1)
343}
344
345#[inline(always)]
346fn fp2_to_blocks_bn254<const BLOCKS: usize, const BLOCK_SIZE: usize>(
347    fp2: &halo2curves_axiom::bn256::Fq2,
348    output: &mut [[u8; BLOCK_SIZE]; BLOCKS],
349) {
350    field_element_to_blocks::<halo2curves_axiom::bn256::Fq, BLOCK_SIZE>(
351        &fp2.c0,
352        &mut output[..BLOCKS / 2],
353    );
354    field_element_to_blocks::<halo2curves_axiom::bn256::Fq, BLOCK_SIZE>(
355        &fp2.c1,
356        &mut output[BLOCKS / 2..],
357    );
358}
359
360#[inline(always)]
361fn blocks_to_fp2_bls12_381<const BLOCKS: usize, const BLOCK_SIZE: usize>(
362    blocks: &[[u8; BLOCK_SIZE]],
363) -> halo2curves_axiom::bls12_381::Fq2 {
364    let c0 = blocks_to_field_element_bls12_381_coordinate(blocks[..BLOCKS / 2].as_flattened());
365    let c1 = blocks_to_field_element_bls12_381_coordinate(blocks[BLOCKS / 2..].as_flattened());
366    halo2curves_axiom::bls12_381::Fq2 { c0, c1 }
367}
368
369#[inline(always)]
370fn fp2_to_blocks_bls12_381<const BLOCKS: usize, const BLOCK_SIZE: usize>(
371    fp2: &halo2curves_axiom::bls12_381::Fq2,
372    output: &mut [[u8; BLOCK_SIZE]; BLOCKS],
373) {
374    field_element_to_blocks_bls12_381_coordinate(&fp2.c0, &mut output[..BLOCKS / 2]);
375    field_element_to_blocks_bls12_381_coordinate(&fp2.c1, &mut output[BLOCKS / 2..]);
376}