openvm_algebra_circuit/
fields.rs

1use halo2curves_axiom::ff::{Field, PrimeField};
2use num_bigint::BigUint;
3use num_traits::Num;
4
/// Identifiers for the prime fields whose arithmetic this module implements.
///
/// The discriminant (cast with `as u8`) is the value passed as the `FIELD`
/// const-generic parameter of `field_operation` / `fp2_operation`, so the numbering
/// is part of the encoding and must stay stable. `#[repr(u8)]` makes that u8
/// encoding explicit and compiler-checked.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FieldType {
    /// secp256k1 base field.
    K256Coordinate = 0,
    /// secp256k1 scalar field.
    K256Scalar = 1,
    /// secp256r1 (P-256) base field.
    P256Coordinate = 2,
    /// secp256r1 (P-256) scalar field.
    P256Scalar = 3,
    /// BN254 base field.
    BN254Coordinate = 4,
    /// BN254 scalar field.
    BN254Scalar = 5,
    /// BLS12-381 base field.
    BLS12_381Coordinate = 6,
    /// BLS12-381 scalar field.
    BLS12_381Scalar = 7,
}
16
/// Binary field operations supported by `field_operation` / `fp2_operation`.
///
/// The discriminant (cast with `as u8`) is the value passed as the `OP`
/// const-generic parameter, so the numbering is part of the encoding and must stay
/// stable. `#[repr(u8)]` makes that u8 encoding explicit and compiler-checked.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Operation {
    /// `a + b`
    Add = 0,
    /// `a - b`
    Sub = 1,
    /// `a * b`
    Mul = 2,
    /// `a * b^-1`; the implementations panic when `b` has no inverse.
    Div = 3,
}
24
25// TODO: hardcode this. it's slow
26fn get_modulus_as_bigint<F: PrimeField>() -> BigUint {
27    BigUint::from_str_radix(F::MODULUS.trim_start_matches("0x"), 16).unwrap()
28}
29
30pub fn get_field_type(modulus: &BigUint) -> Option<FieldType> {
31    if modulus == &get_modulus_as_bigint::<halo2curves_axiom::secq256k1::Fq>() {
32        return Some(FieldType::K256Coordinate);
33    }
34
35    if modulus == &get_modulus_as_bigint::<halo2curves_axiom::secq256k1::Fp>() {
36        return Some(FieldType::K256Scalar);
37    }
38
39    if modulus == &get_modulus_as_bigint::<halo2curves_axiom::secp256r1::Fp>() {
40        return Some(FieldType::P256Coordinate);
41    }
42
43    if modulus == &get_modulus_as_bigint::<halo2curves_axiom::secp256r1::Fq>() {
44        return Some(FieldType::P256Scalar);
45    }
46
47    if modulus == &get_modulus_as_bigint::<halo2curves_axiom::bn256::Fq>() {
48        return Some(FieldType::BN254Coordinate);
49    }
50
51    if modulus == &get_modulus_as_bigint::<halo2curves_axiom::bn256::Fr>() {
52        return Some(FieldType::BN254Scalar);
53    }
54
55    if modulus == &get_modulus_as_bigint::<halo2curves_axiom::bls12_381::Fq>() {
56        return Some(FieldType::BLS12_381Coordinate);
57    }
58
59    if modulus == &get_modulus_as_bigint::<halo2curves_axiom::bls12_381::Fr>() {
60        return Some(FieldType::BLS12_381Scalar);
61    }
62
63    None
64}
65
66pub fn get_fp2_field_type(modulus: &BigUint) -> Option<FieldType> {
67    if modulus == &get_modulus_as_bigint::<halo2curves_axiom::bn256::Fq>() {
68        return Some(FieldType::BN254Coordinate);
69    }
70
71    if modulus == &get_modulus_as_bigint::<halo2curves_axiom::bls12_381::Fq>() {
72        return Some(FieldType::BLS12_381Coordinate);
73    }
74
75    None
76}
77
78#[inline(always)]
79pub fn field_operation<
80    const FIELD: u8,
81    const BLOCKS: usize,
82    const BLOCK_SIZE: usize,
83    const OP: u8,
84>(
85    input_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2],
86) -> [[u8; BLOCK_SIZE]; BLOCKS] {
87    match FIELD {
88        x if x == FieldType::K256Coordinate as u8 => {
89            field_operation_256bit::<halo2curves_axiom::secq256k1::Fq, BLOCKS, BLOCK_SIZE, OP>(
90                input_data,
91            )
92        }
93        x if x == FieldType::K256Scalar as u8 => {
94            field_operation_256bit::<halo2curves_axiom::secq256k1::Fp, BLOCKS, BLOCK_SIZE, OP>(
95                input_data,
96            )
97        }
98        x if x == FieldType::P256Coordinate as u8 => {
99            field_operation_256bit::<halo2curves_axiom::secp256r1::Fp, BLOCKS, BLOCK_SIZE, OP>(
100                input_data,
101            )
102        }
103        x if x == FieldType::P256Scalar as u8 => {
104            field_operation_256bit::<halo2curves_axiom::secp256r1::Fq, BLOCKS, BLOCK_SIZE, OP>(
105                input_data,
106            )
107        }
108        x if x == FieldType::BN254Coordinate as u8 => {
109            field_operation_256bit::<halo2curves_axiom::bn256::Fq, BLOCKS, BLOCK_SIZE, OP>(
110                input_data,
111            )
112        }
113        x if x == FieldType::BN254Scalar as u8 => {
114            field_operation_256bit::<halo2curves_axiom::bn256::Fr, BLOCKS, BLOCK_SIZE, OP>(
115                input_data,
116            )
117        }
118        x if x == FieldType::BLS12_381Coordinate as u8 => {
119            field_operation_bls12_381_coordinate::<BLOCKS, BLOCK_SIZE, OP>(input_data)
120        }
121        x if x == FieldType::BLS12_381Scalar as u8 => {
122            field_operation_256bit::<halo2curves_axiom::bls12_381::Fr, BLOCKS, BLOCK_SIZE, OP>(
123                input_data,
124            )
125        }
126        _ => panic!("Unsupported field type: {FIELD}"),
127    }
128}
129
130#[inline(always)]
131pub fn fp2_operation<
132    const FIELD: u8,
133    const BLOCKS: usize,
134    const BLOCK_SIZE: usize,
135    const OP: u8,
136>(
137    input_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2],
138) -> [[u8; BLOCK_SIZE]; BLOCKS] {
139    match FIELD {
140        x if x == FieldType::BN254Coordinate as u8 => {
141            fp2_operation_bn254::<BLOCKS, BLOCK_SIZE, OP>(input_data)
142        }
143        x if x == FieldType::BLS12_381Coordinate as u8 => {
144            fp2_operation_bls12_381::<BLOCKS, BLOCK_SIZE, OP>(input_data)
145        }
146        _ => panic!("Unsupported field type for Fp2: {FIELD}"),
147    }
148}
149
150#[inline(always)]
151fn field_operation_256bit<
152    F: PrimeField<Repr = [u8; 32]>,
153    const BLOCKS: usize,
154    const BLOCK_SIZE: usize,
155    const OP: u8,
156>(
157    input_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2],
158) -> [[u8; BLOCK_SIZE]; BLOCKS] {
159    let a = blocks_to_field_element::<F>(input_data[0].as_flattened());
160    let b = blocks_to_field_element::<F>(input_data[1].as_flattened());
161    let c = match OP {
162        x if x == Operation::Add as u8 => a + b,
163        x if x == Operation::Sub as u8 => a - b,
164        x if x == Operation::Mul as u8 => a * b,
165        x if x == Operation::Div as u8 => a * b.invert().unwrap(),
166        _ => panic!("Unsupported operation: {OP}"),
167    };
168
169    let mut output = [[0u8; BLOCK_SIZE]; BLOCKS];
170    field_element_to_blocks(&c, &mut output);
171    output
172}
173
174#[inline(always)]
175fn field_operation_bls12_381_coordinate<
176    const BLOCKS: usize,
177    const BLOCK_SIZE: usize,
178    const OP: u8,
179>(
180    input_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2],
181) -> [[u8; BLOCK_SIZE]; BLOCKS] {
182    let a = blocks_to_field_element_bls12_381_coordinate(input_data[0].as_flattened());
183    let b = blocks_to_field_element_bls12_381_coordinate(input_data[1].as_flattened());
184    let c = match OP {
185        x if x == Operation::Add as u8 => a + b,
186        x if x == Operation::Sub as u8 => a - b,
187        x if x == Operation::Mul as u8 => a * b,
188        x if x == Operation::Div as u8 => a * b.invert().unwrap(),
189        _ => panic!("Unsupported operation: {OP}"),
190    };
191
192    let mut output = [[0u8; BLOCK_SIZE]; BLOCKS];
193    field_element_to_blocks_bls12_381_coordinate(&c, &mut output);
194    output
195}
196
197#[inline(always)]
198fn fp2_operation_bn254<const BLOCKS: usize, const BLOCK_SIZE: usize, const OP: u8>(
199    input_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2],
200) -> [[u8; BLOCK_SIZE]; BLOCKS] {
201    let a = blocks_to_fp2_bn254::<BLOCKS, BLOCK_SIZE>(input_data[0].as_ref());
202    let b = blocks_to_fp2_bn254::<BLOCKS, BLOCK_SIZE>(input_data[1].as_ref());
203    let c = match OP {
204        x if x == Operation::Add as u8 => a + b,
205        x if x == Operation::Sub as u8 => a - b,
206        x if x == Operation::Mul as u8 => a * b,
207        x if x == Operation::Div as u8 => a * b.invert().unwrap(),
208        _ => panic!("Unsupported operation: {OP}"),
209    };
210
211    let mut output = [[0u8; BLOCK_SIZE]; BLOCKS];
212    fp2_to_blocks_bn254(&c, &mut output);
213    output
214}
215
216#[inline(always)]
217fn fp2_operation_bls12_381<const BLOCKS: usize, const BLOCK_SIZE: usize, const OP: u8>(
218    input_data: [[[u8; BLOCK_SIZE]; BLOCKS]; 2],
219) -> [[u8; BLOCK_SIZE]; BLOCKS] {
220    let a = blocks_to_fp2_bls12_381::<BLOCKS, BLOCK_SIZE>(input_data[0].as_ref());
221    let b = blocks_to_fp2_bls12_381::<BLOCKS, BLOCK_SIZE>(input_data[1].as_ref());
222    let c = match OP {
223        x if x == Operation::Add as u8 => a + b,
224        x if x == Operation::Sub as u8 => a - b,
225        x if x == Operation::Mul as u8 => a * b,
226        x if x == Operation::Div as u8 => a * b.invert().unwrap(),
227        _ => panic!("Unsupported operation: {OP}"),
228    };
229
230    let mut output = [[0u8; BLOCK_SIZE]; BLOCKS];
231    fp2_to_blocks_bls12_381(&c, &mut output);
232    output
233}
234
235#[inline(always)]
236fn from_repr_with_reduction<F: PrimeField<Repr = [u8; 32]>>(bytes: [u8; 32]) -> F {
237    F::from_repr_vartime(bytes).unwrap_or_else(|| {
238        // Reduce modulo the field's modulus for non-canonical representations
239        let modulus = get_modulus_as_bigint::<F>();
240        let value = BigUint::from_bytes_le(&bytes);
241        let reduced = value % modulus;
242
243        let reduced_le_bytes = reduced.to_bytes_le();
244        let mut reduced_bytes = [0u8; 32];
245        reduced_bytes[..reduced_le_bytes.len()]
246            .copy_from_slice(&reduced_le_bytes[..reduced_le_bytes.len()]);
247
248        F::from_repr_vartime(reduced_bytes).unwrap()
249    })
250}
251
252#[inline(always)]
253fn from_repr_with_reduction_bls12_381_coordinate(bytes: [u8; 48]) -> blstrs::Fp {
254    blstrs::Fp::from_bytes_le(&bytes).unwrap_or_else(|| {
255        // Reduce modulo the field's modulus for non-canonical representations
256        let modulus = BigUint::from_bytes_le(&blstrs::Fp::char());
257        let value = BigUint::from_bytes_le(&bytes);
258        let reduced = value % modulus;
259
260        let reduced_le_bytes = reduced.to_bytes_le();
261        let mut reduced_bytes = [0u8; 48];
262        reduced_bytes[..reduced_le_bytes.len()]
263            .copy_from_slice(&reduced_le_bytes[..reduced_le_bytes.len()]);
264
265        blstrs::Fp::from_bytes_le(&reduced_bytes).unwrap()
266    })
267}
268
269#[inline(always)]
270pub fn blocks_to_field_element<F: PrimeField<Repr = [u8; 32]>>(blocks: &[u8]) -> F {
271    debug_assert!(blocks.len() == 32);
272    let mut bytes = [0u8; 32];
273    bytes[..blocks.len()].copy_from_slice(&blocks[..blocks.len()]);
274
275    from_repr_with_reduction::<F>(bytes)
276}
277
278#[inline(always)]
279pub fn field_element_to_blocks<F: PrimeField<Repr = [u8; 32]>, const BLOCK_SIZE: usize>(
280    field_element: &F,
281    output: &mut [[u8; BLOCK_SIZE]],
282) {
283    debug_assert!(output.len() * BLOCK_SIZE == 32);
284    let bytes = field_element.to_repr();
285    let mut byte_idx = 0;
286
287    for block in output.iter_mut() {
288        for byte in block.iter_mut() {
289            *byte = if byte_idx < bytes.len() {
290                bytes[byte_idx]
291            } else {
292                0
293            };
294            byte_idx += 1;
295        }
296    }
297}
298
299#[inline(always)]
300pub fn blocks_to_field_element_bls12_381_coordinate(blocks: &[u8]) -> blstrs::Fp {
301    debug_assert!(blocks.len() == 48);
302    let mut bytes = [0u8; 48];
303    bytes[..blocks.len()].copy_from_slice(&blocks[..blocks.len()]);
304
305    from_repr_with_reduction_bls12_381_coordinate(bytes)
306}
307
308#[inline(always)]
309pub fn field_element_to_blocks_bls12_381_coordinate<const BLOCK_SIZE: usize>(
310    field_element: &blstrs::Fp,
311    output: &mut [[u8; BLOCK_SIZE]],
312) {
313    debug_assert!(output.len() * BLOCK_SIZE == 48);
314    let bytes = field_element.to_bytes_le();
315    let mut byte_idx = 0;
316
317    for block in output.iter_mut() {
318        for byte in block.iter_mut() {
319            *byte = if byte_idx < bytes.len() {
320                bytes[byte_idx]
321            } else {
322                0
323            };
324            byte_idx += 1;
325        }
326    }
327}
328
329#[inline(always)]
330fn blocks_to_fp2_bn254<const BLOCKS: usize, const BLOCK_SIZE: usize>(
331    blocks: &[[u8; BLOCK_SIZE]],
332) -> halo2curves_axiom::bn256::Fq2 {
333    let c0 = blocks_to_field_element::<halo2curves_axiom::bn256::Fq>(
334        blocks[..BLOCKS / 2].as_flattened(),
335    );
336    let c1 = blocks_to_field_element::<halo2curves_axiom::bn256::Fq>(
337        blocks[BLOCKS / 2..].as_flattened(),
338    );
339    halo2curves_axiom::bn256::Fq2::new(c0, c1)
340}
341
342#[inline(always)]
343fn fp2_to_blocks_bn254<const BLOCKS: usize, const BLOCK_SIZE: usize>(
344    fp2: &halo2curves_axiom::bn256::Fq2,
345    output: &mut [[u8; BLOCK_SIZE]; BLOCKS],
346) {
347    field_element_to_blocks::<halo2curves_axiom::bn256::Fq, BLOCK_SIZE>(
348        &fp2.c0,
349        &mut output[..BLOCKS / 2],
350    );
351    field_element_to_blocks::<halo2curves_axiom::bn256::Fq, BLOCK_SIZE>(
352        &fp2.c1,
353        &mut output[BLOCKS / 2..],
354    );
355}
356
357#[inline(always)]
358fn blocks_to_fp2_bls12_381<const BLOCKS: usize, const BLOCK_SIZE: usize>(
359    blocks: &[[u8; BLOCK_SIZE]],
360) -> blstrs::Fp2 {
361    let c0 = blocks_to_field_element_bls12_381_coordinate(blocks[..BLOCKS / 2].as_flattened());
362    let c1 = blocks_to_field_element_bls12_381_coordinate(blocks[BLOCKS / 2..].as_flattened());
363    blstrs::Fp2::new(c0, c1)
364}
365
366#[inline(always)]
367fn fp2_to_blocks_bls12_381<const BLOCKS: usize, const BLOCK_SIZE: usize>(
368    fp2: &blstrs::Fp2,
369    output: &mut [[u8; BLOCK_SIZE]; BLOCKS],
370) {
371    field_element_to_blocks_bls12_381_coordinate(&fp2.c0(), &mut output[..BLOCKS / 2]);
372    field_element_to_blocks_bls12_381_coordinate(&fp2.c1(), &mut output[BLOCKS / 2..]);
373}