openvm_algebra_circuit/extension/
modular.rs

1use std::{array, sync::Arc};
2
3use num_bigint::{BigUint, RandBigInt};
4use num_traits::{FromPrimitive, One};
5use openvm_algebra_transpiler::{ModularPhantom, Rv32ModularArithmeticOpcode};
6use openvm_circuit::{
7    self,
8    arch::{
9        AirInventory, AirInventoryError, ChipInventory, ChipInventoryError, ExecutionBridge,
10        ExecutorInventoryBuilder, ExecutorInventoryError, RowMajorMatrixArena, VmCircuitExtension,
11        VmExecutionExtension, VmProverExtension,
12    },
13    system::{memory::SharedMemoryHelper, SystemPort},
14};
15use openvm_circuit_derive::{AnyEnum, Executor, MeteredExecutor, PreflightExecutor};
16use openvm_circuit_primitives::{
17    bigint::utils::big_uint_to_limbs,
18    bitwise_op_lookup::{
19        BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip,
20        SharedBitwiseOperationLookupChip,
21    },
22    var_range::VariableRangeCheckerBus,
23};
24use openvm_instructions::{LocalOpcode, PhantomDiscriminant, VmOpcode};
25use openvm_mod_circuit_builder::ExprBuilderConfig;
26use openvm_rv32_adapters::{
27    Rv32IsEqualModAdapterAir, Rv32IsEqualModAdapterExecutor, Rv32IsEqualModAdapterFiller,
28};
29use openvm_stark_backend::{
30    config::{StarkGenericConfig, Val},
31    p3_field::PrimeField32,
32    prover::cpu::{CpuBackend, CpuDevice},
33};
34use openvm_stark_sdk::engine::StarkEngine;
35use rand::Rng;
36use serde::{Deserialize, Serialize};
37use serde_with::{serde_as, DisplayFromStr};
38use strum::EnumCount;
39
40use crate::{
41    modular_chip::{
42        get_modular_addsub_air, get_modular_addsub_chip, get_modular_addsub_step,
43        get_modular_muldiv_air, get_modular_muldiv_chip, get_modular_muldiv_step, ModularAir,
44        ModularExecutor, ModularIsEqualAir, ModularIsEqualChip, ModularIsEqualCoreAir,
45        ModularIsEqualFiller, VmModularIsEqualExecutor,
46    },
47    AlgebraCpuProverExt,
48};
49
/// Configuration of the modular arithmetic VM extension.
///
/// Each entry of `supported_moduli` is assigned its own block of
/// `Rv32ModularArithmeticOpcode` opcodes, selected by the entry's index in the vector
/// (`CLASS_OFFSET + index * COUNT`), so the order of the moduli is significant.
/// The extension implementations in this file panic with "Modulus too large" for any
/// modulus that needs more than 48 bytes.
#[serde_as]
#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)]
pub struct ModularExtension {
    /// The moduli supported by this extension instance.
    ///
    /// (De)serialized through each value's `Display`/`FromStr` string form.
    #[serde_as(as = "Vec<DisplayFromStr>")]
    pub supported_moduli: Vec<BigUint>,
}
56
57impl ModularExtension {
58    // Generates a call to the moduli_init! macro with moduli in the correct order
59    pub fn generate_moduli_init(&self) -> String {
60        let supported_moduli = self
61            .supported_moduli
62            .iter()
63            .map(|modulus| format!("\"{modulus}\""))
64            .collect::<Vec<String>>()
65            .join(", ");
66
67        format!("openvm_algebra_guest::moduli_macros::moduli_init! {{ {supported_moduli} }}",)
68    }
69}
70
/// Executors contributed by [`ModularExtension`]: one variant per operation kind
/// (add/sub, mul/div, is-equal) and per modulus size class (32 or 48 limbs of 8 bits).
#[derive(Clone, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)]
#[cfg_attr(
    feature = "aot",
    derive(
        openvm_circuit_derive::AotExecutor,
        openvm_circuit_derive::AotMeteredExecutor
    )
)]
pub enum ModularExtensionExecutor {
    // Moduli that fit in 32 bytes: represented with 32 limbs of 8 bits each.
    ModularAddSubRv32_32(ModularExecutor<1, 32>), // ModularAddSub
    ModularMulDivRv32_32(ModularExecutor<1, 32>), // ModularMulDiv
    ModularIsEqualRv32_32(VmModularIsEqualExecutor<1, 32, 32>), // ModularIsEqual
    // Moduli that need 33 to 48 bytes: represented with 48 limbs of 8 bits each.
    ModularAddSubRv32_48(ModularExecutor<3, 16>), // ModularAddSub
    ModularMulDivRv32_48(ModularExecutor<3, 16>), // ModularMulDiv
    ModularIsEqualRv32_48(VmModularIsEqualExecutor<3, 16, 48>), // ModularIsEqual
}
89
90impl<F: PrimeField32> VmExecutionExtension<F> for ModularExtension {
91    type Executor = ModularExtensionExecutor;
92
93    fn extend_execution(
94        &self,
95        inventory: &mut ExecutorInventoryBuilder<F, ModularExtensionExecutor>,
96    ) -> Result<(), ExecutorInventoryError> {
97        let pointer_max_bits = inventory.pointer_max_bits();
98        // TODO: somehow get the range checker bus from `ExecutorInventory`
99        let dummy_range_checker_bus = VariableRangeCheckerBus::new(u16::MAX, 16);
100        for (i, modulus) in self.supported_moduli.iter().enumerate() {
101            // determine the number of bytes needed to represent a prime field element
102            let bytes = modulus.bits().div_ceil(8);
103            let start_offset =
104                Rv32ModularArithmeticOpcode::CLASS_OFFSET + i * Rv32ModularArithmeticOpcode::COUNT;
105            let modulus_limbs = big_uint_to_limbs(modulus, 8);
106            if bytes <= 32 {
107                let config = ExprBuilderConfig {
108                    modulus: modulus.clone(),
109                    num_limbs: 32,
110                    limb_bits: 8,
111                };
112                let addsub = get_modular_addsub_step(
113                    config.clone(),
114                    dummy_range_checker_bus,
115                    pointer_max_bits,
116                    start_offset,
117                );
118
119                inventory.add_executor(
120                    ModularExtensionExecutor::ModularAddSubRv32_32(addsub),
121                    ((Rv32ModularArithmeticOpcode::ADD as usize)
122                        ..=(Rv32ModularArithmeticOpcode::SETUP_ADDSUB as usize))
123                        .map(|x| VmOpcode::from_usize(x + start_offset)),
124                )?;
125
126                let muldiv = get_modular_muldiv_step(
127                    config,
128                    dummy_range_checker_bus,
129                    pointer_max_bits,
130                    start_offset,
131                );
132
133                inventory.add_executor(
134                    ModularExtensionExecutor::ModularMulDivRv32_32(muldiv),
135                    ((Rv32ModularArithmeticOpcode::MUL as usize)
136                        ..=(Rv32ModularArithmeticOpcode::SETUP_MULDIV as usize))
137                        .map(|x| VmOpcode::from_usize(x + start_offset)),
138                )?;
139
140                let modulus_limbs = array::from_fn(|i| {
141                    if i < modulus_limbs.len() {
142                        modulus_limbs[i] as u8
143                    } else {
144                        0
145                    }
146                });
147
148                let is_eq = VmModularIsEqualExecutor::new(
149                    Rv32IsEqualModAdapterExecutor::new(pointer_max_bits),
150                    start_offset,
151                    modulus_limbs,
152                );
153
154                inventory.add_executor(
155                    ModularExtensionExecutor::ModularIsEqualRv32_32(is_eq),
156                    ((Rv32ModularArithmeticOpcode::IS_EQ as usize)
157                        ..=(Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize))
158                        .map(|x| VmOpcode::from_usize(x + start_offset)),
159                )?;
160            } else if bytes <= 48 {
161                let config = ExprBuilderConfig {
162                    modulus: modulus.clone(),
163                    num_limbs: 48,
164                    limb_bits: 8,
165                };
166                let addsub = get_modular_addsub_step(
167                    config.clone(),
168                    dummy_range_checker_bus,
169                    pointer_max_bits,
170                    start_offset,
171                );
172
173                inventory.add_executor(
174                    ModularExtensionExecutor::ModularAddSubRv32_48(addsub),
175                    ((Rv32ModularArithmeticOpcode::ADD as usize)
176                        ..=(Rv32ModularArithmeticOpcode::SETUP_ADDSUB as usize))
177                        .map(|x| VmOpcode::from_usize(x + start_offset)),
178                )?;
179
180                let muldiv = get_modular_muldiv_step(
181                    config,
182                    dummy_range_checker_bus,
183                    pointer_max_bits,
184                    start_offset,
185                );
186
187                inventory.add_executor(
188                    ModularExtensionExecutor::ModularMulDivRv32_48(muldiv),
189                    ((Rv32ModularArithmeticOpcode::MUL as usize)
190                        ..=(Rv32ModularArithmeticOpcode::SETUP_MULDIV as usize))
191                        .map(|x| VmOpcode::from_usize(x + start_offset)),
192                )?;
193
194                let modulus_limbs = array::from_fn(|i| {
195                    if i < modulus_limbs.len() {
196                        modulus_limbs[i] as u8
197                    } else {
198                        0
199                    }
200                });
201
202                let is_eq = VmModularIsEqualExecutor::new(
203                    Rv32IsEqualModAdapterExecutor::new(pointer_max_bits),
204                    start_offset,
205                    modulus_limbs,
206                );
207
208                inventory.add_executor(
209                    ModularExtensionExecutor::ModularIsEqualRv32_48(is_eq),
210                    ((Rv32ModularArithmeticOpcode::IS_EQ as usize)
211                        ..=(Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize))
212                        .map(|x| VmOpcode::from_usize(x + start_offset)),
213                )?;
214            } else {
215                panic!("Modulus too large");
216            }
217        }
218
219        let non_qr_hint_sub_ex = phantom::NonQrHintSubEx::new(self.supported_moduli.clone());
220        inventory.add_phantom_sub_executor(
221            non_qr_hint_sub_ex.clone(),
222            PhantomDiscriminant(ModularPhantom::HintNonQr as u16),
223        )?;
224
225        let sqrt_hint_sub_ex = phantom::SqrtHintSubEx::new(non_qr_hint_sub_ex);
226        inventory.add_phantom_sub_executor(
227            sqrt_hint_sub_ex,
228            PhantomDiscriminant(ModularPhantom::HintSqrt as u16),
229        )?;
230
231        Ok(())
232    }
233}
234
235impl<SC: StarkGenericConfig> VmCircuitExtension<SC> for ModularExtension {
236    fn extend_circuit(&self, inventory: &mut AirInventory<SC>) -> Result<(), AirInventoryError> {
237        let SystemPort {
238            execution_bus,
239            program_bus,
240            memory_bridge,
241        } = inventory.system().port();
242
243        let exec_bridge = ExecutionBridge::new(execution_bus, program_bus);
244        let range_checker_bus = inventory.range_checker().bus;
245        let pointer_max_bits = inventory.pointer_max_bits();
246
247        let bitwise_lu = {
248            // A trick to get around Rust's borrow rules
249            let existing_air = inventory.find_air::<BitwiseOperationLookupAir<8>>().next();
250            if let Some(air) = existing_air {
251                air.bus
252            } else {
253                let bus = BitwiseOperationLookupBus::new(inventory.new_bus_idx());
254                let air = BitwiseOperationLookupAir::<8>::new(bus);
255                inventory.add_air(air);
256                air.bus
257            }
258        };
259        for (i, modulus) in self.supported_moduli.iter().enumerate() {
260            // determine the number of bytes needed to represent a prime field element
261            let bytes = modulus.bits().div_ceil(8);
262            let start_offset =
263                Rv32ModularArithmeticOpcode::CLASS_OFFSET + i * Rv32ModularArithmeticOpcode::COUNT;
264
265            if bytes <= 32 {
266                let config = ExprBuilderConfig {
267                    modulus: modulus.clone(),
268                    num_limbs: 32,
269                    limb_bits: 8,
270                };
271
272                let addsub = get_modular_addsub_air::<1, 32>(
273                    exec_bridge,
274                    memory_bridge,
275                    config.clone(),
276                    range_checker_bus,
277                    bitwise_lu,
278                    pointer_max_bits,
279                    start_offset,
280                );
281                inventory.add_air(addsub);
282
283                let muldiv = get_modular_muldiv_air::<1, 32>(
284                    exec_bridge,
285                    memory_bridge,
286                    config,
287                    range_checker_bus,
288                    bitwise_lu,
289                    pointer_max_bits,
290                    start_offset,
291                );
292                inventory.add_air(muldiv);
293
294                let is_eq = ModularIsEqualAir::<1, 32, 32>::new(
295                    Rv32IsEqualModAdapterAir::new(
296                        exec_bridge,
297                        memory_bridge,
298                        bitwise_lu,
299                        pointer_max_bits,
300                    ),
301                    ModularIsEqualCoreAir::new(modulus.clone(), bitwise_lu, start_offset),
302                );
303                inventory.add_air(is_eq);
304            } else if bytes <= 48 {
305                let config = ExprBuilderConfig {
306                    modulus: modulus.clone(),
307                    num_limbs: 48,
308                    limb_bits: 8,
309                };
310
311                let addsub = get_modular_addsub_air::<3, 16>(
312                    exec_bridge,
313                    memory_bridge,
314                    config.clone(),
315                    range_checker_bus,
316                    bitwise_lu,
317                    pointer_max_bits,
318                    start_offset,
319                );
320                inventory.add_air(addsub);
321
322                let muldiv = get_modular_muldiv_air::<3, 16>(
323                    exec_bridge,
324                    memory_bridge,
325                    config,
326                    range_checker_bus,
327                    bitwise_lu,
328                    pointer_max_bits,
329                    start_offset,
330                );
331                inventory.add_air(muldiv);
332
333                let is_eq = ModularIsEqualAir::<3, 16, 48>::new(
334                    Rv32IsEqualModAdapterAir::new(
335                        exec_bridge,
336                        memory_bridge,
337                        bitwise_lu,
338                        pointer_max_bits,
339                    ),
340                    ModularIsEqualCoreAir::new(modulus.clone(), bitwise_lu, start_offset),
341                );
342                inventory.add_air(is_eq);
343            } else {
344                panic!("Modulus too large");
345            }
346        }
347
348        Ok(())
349    }
350}
351
352// This implementation is specific to CpuBackend because the lookup chips (VariableRangeChecker,
353// BitwiseOperationLookupChip) are specific to CpuBackend.
354impl<E, SC, RA> VmProverExtension<E, RA, ModularExtension> for AlgebraCpuProverExt
355where
356    SC: StarkGenericConfig,
357    E: StarkEngine<SC = SC, PB = CpuBackend<SC>, PD = CpuDevice<SC>>,
358    RA: RowMajorMatrixArena<Val<SC>>,
359    Val<SC>: PrimeField32,
360{
361    fn extend_prover(
362        &self,
363        extension: &ModularExtension,
364        inventory: &mut ChipInventory<SC, RA, CpuBackend<SC>>,
365    ) -> Result<(), ChipInventoryError> {
366        let range_checker = inventory.range_checker()?.clone();
367        let timestamp_max_bits = inventory.timestamp_max_bits();
368        let pointer_max_bits = inventory.airs().pointer_max_bits();
369        let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits);
370        let bitwise_lu = {
371            let existing_chip = inventory
372                .find_chip::<SharedBitwiseOperationLookupChip<8>>()
373                .next();
374            if let Some(chip) = existing_chip {
375                chip.clone()
376            } else {
377                let air: &BitwiseOperationLookupAir<8> = inventory.next_air()?;
378                let chip = Arc::new(BitwiseOperationLookupChip::new(air.bus));
379                inventory.add_periphery_chip(chip.clone());
380                chip
381            }
382        };
383        for (i, modulus) in extension.supported_moduli.iter().enumerate() {
384            // determine the number of bytes needed to represent a prime field element
385            let bytes = modulus.bits().div_ceil(8);
386            let start_offset =
387                Rv32ModularArithmeticOpcode::CLASS_OFFSET + i * Rv32ModularArithmeticOpcode::COUNT;
388
389            let modulus_limbs = big_uint_to_limbs(modulus, 8);
390
391            if bytes <= 32 {
392                let config = ExprBuilderConfig {
393                    modulus: modulus.clone(),
394                    num_limbs: 32,
395                    limb_bits: 8,
396                };
397
398                inventory.next_air::<ModularAir<1, 32>>()?;
399                let addsub = get_modular_addsub_chip::<Val<SC>, 1, 32>(
400                    config.clone(),
401                    mem_helper.clone(),
402                    range_checker.clone(),
403                    bitwise_lu.clone(),
404                    pointer_max_bits,
405                );
406                inventory.add_executor_chip(addsub);
407
408                inventory.next_air::<ModularAir<1, 32>>()?;
409                let muldiv = get_modular_muldiv_chip::<Val<SC>, 1, 32>(
410                    config,
411                    mem_helper.clone(),
412                    range_checker.clone(),
413                    bitwise_lu.clone(),
414                    pointer_max_bits,
415                );
416                inventory.add_executor_chip(muldiv);
417
418                let modulus_limbs = array::from_fn(|i| {
419                    if i < modulus_limbs.len() {
420                        modulus_limbs[i] as u8
421                    } else {
422                        0
423                    }
424                });
425                inventory.next_air::<ModularIsEqualAir<1, 32, 32>>()?;
426                let is_eq = ModularIsEqualChip::<Val<SC>, 1, 32, 32>::new(
427                    ModularIsEqualFiller::new(
428                        Rv32IsEqualModAdapterFiller::new(pointer_max_bits, bitwise_lu.clone()),
429                        start_offset,
430                        modulus_limbs,
431                        bitwise_lu.clone(),
432                    ),
433                    mem_helper.clone(),
434                );
435                inventory.add_executor_chip(is_eq);
436            } else if bytes <= 48 {
437                let config = ExprBuilderConfig {
438                    modulus: modulus.clone(),
439                    num_limbs: 48,
440                    limb_bits: 8,
441                };
442
443                inventory.next_air::<ModularAir<3, 16>>()?;
444                let addsub = get_modular_addsub_chip::<Val<SC>, 3, 16>(
445                    config.clone(),
446                    mem_helper.clone(),
447                    range_checker.clone(),
448                    bitwise_lu.clone(),
449                    pointer_max_bits,
450                );
451                inventory.add_executor_chip(addsub);
452
453                inventory.next_air::<ModularAir<3, 16>>()?;
454                let muldiv = get_modular_muldiv_chip::<Val<SC>, 3, 16>(
455                    config,
456                    mem_helper.clone(),
457                    range_checker.clone(),
458                    bitwise_lu.clone(),
459                    pointer_max_bits,
460                );
461                inventory.add_executor_chip(muldiv);
462
463                let modulus_limbs = array::from_fn(|i| {
464                    if i < modulus_limbs.len() {
465                        modulus_limbs[i] as u8
466                    } else {
467                        0
468                    }
469                });
470                inventory.next_air::<ModularIsEqualAir<3, 16, 48>>()?;
471                let is_eq = ModularIsEqualChip::<Val<SC>, 3, 16, 48>::new(
472                    ModularIsEqualFiller::new(
473                        Rv32IsEqualModAdapterFiller::new(pointer_max_bits, bitwise_lu.clone()),
474                        start_offset,
475                        modulus_limbs,
476                        bitwise_lu.clone(),
477                    ),
478                    mem_helper.clone(),
479                );
480                inventory.add_executor_chip(is_eq);
481            } else {
482                panic!("Modulus too large");
483            }
484        }
485
486        Ok(())
487    }
488}
489
490pub(crate) mod phantom {
491    use std::{
492        iter::{once, repeat},
493        ops::Deref,
494    };
495
496    use eyre::bail;
497    use num_bigint::BigUint;
498    use openvm_circuit::{
499        arch::{PhantomSubExecutor, Streams},
500        system::memory::online::GuestMemory,
501    };
502    use openvm_instructions::{riscv::RV32_MEMORY_AS, PhantomDiscriminant};
503    use openvm_rv32im_circuit::adapters::read_rv32_register;
504    use openvm_stark_backend::p3_field::PrimeField32;
505    use rand::{rngs::StdRng, SeedableRng};
506
507    use super::{find_non_qr, mod_sqrt};
508
509    #[derive(derive_new::new)]
510    pub struct SqrtHintSubEx(NonQrHintSubEx);
511
512    impl Deref for SqrtHintSubEx {
513        type Target = NonQrHintSubEx;
514
515        fn deref(&self) -> &NonQrHintSubEx {
516            &self.0
517        }
518    }
519
520    // Given x returns either a sqrt of x or a sqrt of x * non_qr, whichever exists.
521    // Note that non_qr is fixed for each modulus.
522    impl<F: PrimeField32> PhantomSubExecutor<F> for SqrtHintSubEx {
523        fn phantom_execute(
524            &self,
525            memory: &GuestMemory,
526            streams: &mut Streams<F>,
527            _: &mut StdRng,
528            _: PhantomDiscriminant,
529            a: u32,
530            _: u32,
531            c_upper: u16,
532        ) -> eyre::Result<()> {
533            let mod_idx = c_upper as usize;
534            if mod_idx >= self.supported_moduli.len() {
535                bail!(
536                    "Modulus index {mod_idx} out of range: {} supported moduli",
537                    self.supported_moduli.len()
538                );
539            }
540            let modulus = &self.supported_moduli[mod_idx];
541            let num_limbs: usize = if modulus.bits().div_ceil(8) <= 32 {
542                32
543            } else if modulus.bits().div_ceil(8) <= 48 {
544                48
545            } else {
546                bail!("Modulus too large")
547            };
548
549            let rs1 = read_rv32_register(memory, a);
550            // SAFETY:
551            // - MEMORY_AS consists of `u8`s
552            // - MEMORY_AS is in bounds
553            let x_limbs: Vec<u8> =
554                unsafe { memory.memory.get_slice((RV32_MEMORY_AS, rs1), num_limbs) }.to_vec();
555            let x = BigUint::from_bytes_le(&x_limbs);
556
557            let (success, sqrt) = match mod_sqrt(&x, modulus, &self.non_qrs[mod_idx]) {
558                Some(sqrt) => (true, sqrt),
559                None => {
560                    let sqrt = mod_sqrt(
561                        &(&x * &self.non_qrs[mod_idx]),
562                        modulus,
563                        &self.non_qrs[mod_idx],
564                    )
565                    .expect("Either x or x * non_qr should be a square");
566                    (false, sqrt)
567                }
568            };
569
570            let hint_bytes = once(F::from_bool(success))
571                .chain(repeat(F::ZERO))
572                .take(4)
573                .chain(
574                    sqrt.to_bytes_le()
575                        .into_iter()
576                        .map(F::from_canonical_u8)
577                        .chain(repeat(F::ZERO))
578                        .take(num_limbs),
579                )
580                .collect();
581            streams.hint_stream = hint_bytes;
582            Ok(())
583        }
584    }
585
586    #[derive(Clone)]
587    pub struct NonQrHintSubEx {
588        pub supported_moduli: Vec<BigUint>,
589        pub non_qrs: Vec<BigUint>,
590    }
591
592    impl NonQrHintSubEx {
593        pub fn new(supported_moduli: Vec<BigUint>) -> Self {
594            // Use deterministic seed so that the non-QR are deterministic between different
595            // instances of the VM. The seed determines the runtime of Tonelli-Shanks, if the
596            // algorithm is necessary, which affects the time it takes to construct and initialize
597            // the VM but does not affect the runtime.
598            let mut rng = StdRng::from_seed([0u8; 32]);
599            let non_qrs = supported_moduli
600                .iter()
601                .map(|modulus| find_non_qr(modulus, &mut rng))
602                .collect();
603            Self {
604                supported_moduli,
605                non_qrs,
606            }
607        }
608    }
609
610    impl<F: PrimeField32> PhantomSubExecutor<F> for NonQrHintSubEx {
611        fn phantom_execute(
612            &self,
613            _: &GuestMemory,
614            streams: &mut Streams<F>,
615            _: &mut StdRng,
616            _: PhantomDiscriminant,
617            _: u32,
618            _: u32,
619            c_upper: u16,
620        ) -> eyre::Result<()> {
621            let mod_idx = c_upper as usize;
622            if mod_idx >= self.supported_moduli.len() {
623                bail!(
624                    "Modulus index {mod_idx} out of range: {} supported moduli",
625                    self.supported_moduli.len()
626                );
627            }
628            let modulus = &self.supported_moduli[mod_idx];
629
630            let num_limbs: usize = if modulus.bits().div_ceil(8) <= 32 {
631                32
632            } else if modulus.bits().div_ceil(8) <= 48 {
633                48
634            } else {
635                bail!("Modulus too large")
636            };
637
638            let hint_bytes = self.non_qrs[mod_idx]
639                .to_bytes_le()
640                .into_iter()
641                .map(F::from_canonical_u8)
642                .chain(repeat(F::ZERO))
643                .take(num_limbs)
644                .collect();
645            streams.hint_stream = hint_bytes;
646            Ok(())
647        }
648    }
649}
650
651/// Find the square root of `x` modulo `modulus` with `non_qr` a
652/// quadratic nonresidue of the field.
653pub fn mod_sqrt(x: &BigUint, modulus: &BigUint, non_qr: &BigUint) -> Option<BigUint> {
654    if modulus % 4u32 == BigUint::from_u8(3).unwrap() {
655        // x^(1/2) = x^((p+1)/4) when p = 3 mod 4
656        let exponent = (modulus + BigUint::one()) >> 2;
657        let ret = x.modpow(&exponent, modulus);
658        if &ret * &ret % modulus == x % modulus {
659            Some(ret)
660        } else {
661            None
662        }
663    } else {
664        // Tonelli-Shanks algorithm
665        // https://en.wikipedia.org/wiki/Tonelli%E2%80%93Shanks_algorithm#The_algorithm
666        let mut q = modulus - BigUint::one();
667        let mut s = 0;
668        while &q % 2u32 == BigUint::ZERO {
669            s += 1;
670            q /= 2u32;
671        }
672        let z = non_qr;
673        let mut m = s;
674        let mut c = z.modpow(&q, modulus);
675        let mut t = x.modpow(&q, modulus);
676        let mut r = x.modpow(&((q + BigUint::one()) >> 1), modulus);
677        loop {
678            if t == BigUint::ZERO {
679                return Some(BigUint::ZERO);
680            }
681            if t == BigUint::one() {
682                return Some(r);
683            }
684            let mut i = 0;
685            let mut tmp = t.clone();
686            while tmp != BigUint::one() && i < m {
687                tmp = &tmp * &tmp % modulus;
688                i += 1;
689            }
690            if i == m {
691                // self is not a quadratic residue
692                return None;
693            }
694            for _ in 0..m - i - 1 {
695                c = &c * &c % modulus;
696            }
697            let b = c;
698            m = i;
699            c = &b * &b % modulus;
700            t = ((t * &b % modulus) * &b) % modulus;
701            r = (r * b) % modulus;
702        }
703    }
704}
705
706// Returns a non-quadratic residue in the field
707pub fn find_non_qr(modulus: &BigUint, rng: &mut impl Rng) -> BigUint {
708    if modulus % 4u32 == BigUint::from(3u8) {
709        // p = 3 mod 4 then -1 is a quadratic residue
710        modulus - BigUint::one()
711    } else if modulus % 8u32 == BigUint::from(5u8) {
712        // p = 5 mod 8 then 2 is a non-quadratic residue
713        // since 2^((p-1)/2) = (-1)^((p^2-1)/8)
714        BigUint::from_u8(2u8).unwrap()
715    } else {
716        let mut non_qr = rng.gen_biguint_range(
717            &BigUint::from_u8(2).unwrap(),
718            &(modulus - BigUint::from_u8(1).unwrap()),
719        );
720        // To check if non_qr is a quadratic nonresidue, we compute non_qr^((p-1)/2)
721        // If the result is p-1, then non_qr is a quadratic nonresidue
722        // Otherwise, non_qr is a quadratic residue
723        let exponent = (modulus - BigUint::one()) >> 1;
724        while non_qr.modpow(&exponent, modulus) != modulus - BigUint::one() {
725            non_qr = rng.gen_biguint_range(
726                &BigUint::from_u8(2).unwrap(),
727                &(modulus - BigUint::from_u8(1).unwrap()),
728            );
729        }
730        non_qr
731    }
732}