openvm_algebra_circuit/extension/modular.rs

1use std::{array, sync::Arc};
2
3use num_bigint::{BigUint, RandBigInt};
4use num_traits::{FromPrimitive, One};
5use openvm_algebra_transpiler::{ModularPhantom, Rv32ModularArithmeticOpcode};
6use openvm_circuit::{
7    self,
8    arch::{
9        AirInventory, AirInventoryError, ChipInventory, ChipInventoryError, ExecutionBridge,
10        ExecutorInventoryBuilder, ExecutorInventoryError, RowMajorMatrixArena, VmCircuitExtension,
11        VmExecutionExtension, VmProverExtension,
12    },
13    system::{memory::SharedMemoryHelper, SystemPort},
14};
15use openvm_circuit_derive::{AnyEnum, Executor, MeteredExecutor, PreflightExecutor};
16use openvm_circuit_primitives::{
17    bigint::utils::big_uint_to_limbs,
18    bitwise_op_lookup::{
19        BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip,
20        SharedBitwiseOperationLookupChip,
21    },
22    var_range::VariableRangeCheckerBus,
23};
24use openvm_instructions::{LocalOpcode, PhantomDiscriminant, VmOpcode};
25use openvm_mod_circuit_builder::ExprBuilderConfig;
26use openvm_rv32_adapters::{
27    Rv32IsEqualModAdapterAir, Rv32IsEqualModAdapterExecutor, Rv32IsEqualModAdapterFiller,
28};
29use openvm_stark_backend::{
30    config::{StarkGenericConfig, Val},
31    p3_field::PrimeField32,
32    prover::cpu::{CpuBackend, CpuDevice},
33};
34use openvm_stark_sdk::engine::StarkEngine;
35use rand::Rng;
36use serde::{Deserialize, Serialize};
37use serde_with::{serde_as, DisplayFromStr};
38use strum::EnumCount;
39
40use crate::{
41    modular_chip::{
42        get_modular_addsub_air, get_modular_addsub_chip, get_modular_addsub_step,
43        get_modular_muldiv_air, get_modular_muldiv_chip, get_modular_muldiv_step, ModularAir,
44        ModularExecutor, ModularIsEqualAir, ModularIsEqualChip, ModularIsEqualCoreAir,
45        ModularIsEqualFiller, VmModularIsEqualExecutor,
46    },
47    AlgebraCpuProverExt,
48};
49
50#[serde_as]
51#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)]
52pub struct ModularExtension {
53    #[serde_as(as = "Vec<DisplayFromStr>")]
54    pub supported_moduli: Vec<BigUint>,
55}
56
57impl ModularExtension {
58    // Generates a call to the moduli_init! macro with moduli in the correct order
59    pub fn generate_moduli_init(&self) -> String {
60        let supported_moduli = self
61            .supported_moduli
62            .iter()
63            .map(|modulus| format!("\"{}\"", modulus))
64            .collect::<Vec<String>>()
65            .join(", ");
66
67        format!("openvm_algebra_guest::moduli_macros::moduli_init! {{ {supported_moduli} }}",)
68    }
69}
70
71#[derive(Clone, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)]
72pub enum ModularExtensionExecutor {
73    // 32 limbs prime
74    ModularAddSubRv32_32(ModularExecutor<1, 32>), // ModularAddSub
75    ModularMulDivRv32_32(ModularExecutor<1, 32>), // ModularMulDiv
76    ModularIsEqualRv32_32(VmModularIsEqualExecutor<1, 32, 32>), // ModularIsEqual
77    // 48 limbs prime
78    ModularAddSubRv32_48(ModularExecutor<3, 16>), // ModularAddSub
79    ModularMulDivRv32_48(ModularExecutor<3, 16>), // ModularMulDiv
80    ModularIsEqualRv32_48(VmModularIsEqualExecutor<3, 16, 48>), // ModularIsEqual
81}
82
83impl<F: PrimeField32> VmExecutionExtension<F> for ModularExtension {
84    type Executor = ModularExtensionExecutor;
85
86    fn extend_execution(
87        &self,
88        inventory: &mut ExecutorInventoryBuilder<F, ModularExtensionExecutor>,
89    ) -> Result<(), ExecutorInventoryError> {
90        let pointer_max_bits = inventory.pointer_max_bits();
91        // TODO: somehow get the range checker bus from `ExecutorInventory`
92        let dummy_range_checker_bus = VariableRangeCheckerBus::new(u16::MAX, 16);
93        for (i, modulus) in self.supported_moduli.iter().enumerate() {
94            // determine the number of bytes needed to represent a prime field element
95            let bytes = modulus.bits().div_ceil(8);
96            let start_offset =
97                Rv32ModularArithmeticOpcode::CLASS_OFFSET + i * Rv32ModularArithmeticOpcode::COUNT;
98            let modulus_limbs = big_uint_to_limbs(modulus, 8);
99            if bytes <= 32 {
100                let config = ExprBuilderConfig {
101                    modulus: modulus.clone(),
102                    num_limbs: 32,
103                    limb_bits: 8,
104                };
105                let addsub = get_modular_addsub_step(
106                    config.clone(),
107                    dummy_range_checker_bus,
108                    pointer_max_bits,
109                    start_offset,
110                );
111
112                inventory.add_executor(
113                    ModularExtensionExecutor::ModularAddSubRv32_32(addsub),
114                    ((Rv32ModularArithmeticOpcode::ADD as usize)
115                        ..=(Rv32ModularArithmeticOpcode::SETUP_ADDSUB as usize))
116                        .map(|x| VmOpcode::from_usize(x + start_offset)),
117                )?;
118
119                let muldiv = get_modular_muldiv_step(
120                    config,
121                    dummy_range_checker_bus,
122                    pointer_max_bits,
123                    start_offset,
124                );
125
126                inventory.add_executor(
127                    ModularExtensionExecutor::ModularMulDivRv32_32(muldiv),
128                    ((Rv32ModularArithmeticOpcode::MUL as usize)
129                        ..=(Rv32ModularArithmeticOpcode::SETUP_MULDIV as usize))
130                        .map(|x| VmOpcode::from_usize(x + start_offset)),
131                )?;
132
133                let modulus_limbs = array::from_fn(|i| {
134                    if i < modulus_limbs.len() {
135                        modulus_limbs[i] as u8
136                    } else {
137                        0
138                    }
139                });
140
141                let is_eq = VmModularIsEqualExecutor::new(
142                    Rv32IsEqualModAdapterExecutor::new(pointer_max_bits),
143                    start_offset,
144                    modulus_limbs,
145                );
146
147                inventory.add_executor(
148                    ModularExtensionExecutor::ModularIsEqualRv32_32(is_eq),
149                    ((Rv32ModularArithmeticOpcode::IS_EQ as usize)
150                        ..=(Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize))
151                        .map(|x| VmOpcode::from_usize(x + start_offset)),
152                )?;
153            } else if bytes <= 48 {
154                let config = ExprBuilderConfig {
155                    modulus: modulus.clone(),
156                    num_limbs: 48,
157                    limb_bits: 8,
158                };
159                let addsub = get_modular_addsub_step(
160                    config.clone(),
161                    dummy_range_checker_bus,
162                    pointer_max_bits,
163                    start_offset,
164                );
165
166                inventory.add_executor(
167                    ModularExtensionExecutor::ModularAddSubRv32_48(addsub),
168                    ((Rv32ModularArithmeticOpcode::ADD as usize)
169                        ..=(Rv32ModularArithmeticOpcode::SETUP_ADDSUB as usize))
170                        .map(|x| VmOpcode::from_usize(x + start_offset)),
171                )?;
172
173                let muldiv = get_modular_muldiv_step(
174                    config,
175                    dummy_range_checker_bus,
176                    pointer_max_bits,
177                    start_offset,
178                );
179
180                inventory.add_executor(
181                    ModularExtensionExecutor::ModularMulDivRv32_48(muldiv),
182                    ((Rv32ModularArithmeticOpcode::MUL as usize)
183                        ..=(Rv32ModularArithmeticOpcode::SETUP_MULDIV as usize))
184                        .map(|x| VmOpcode::from_usize(x + start_offset)),
185                )?;
186
187                let modulus_limbs = array::from_fn(|i| {
188                    if i < modulus_limbs.len() {
189                        modulus_limbs[i] as u8
190                    } else {
191                        0
192                    }
193                });
194
195                let is_eq = VmModularIsEqualExecutor::new(
196                    Rv32IsEqualModAdapterExecutor::new(pointer_max_bits),
197                    start_offset,
198                    modulus_limbs,
199                );
200
201                inventory.add_executor(
202                    ModularExtensionExecutor::ModularIsEqualRv32_48(is_eq),
203                    ((Rv32ModularArithmeticOpcode::IS_EQ as usize)
204                        ..=(Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize))
205                        .map(|x| VmOpcode::from_usize(x + start_offset)),
206                )?;
207            } else {
208                panic!("Modulus too large");
209            }
210        }
211
212        let non_qr_hint_sub_ex = phantom::NonQrHintSubEx::new(self.supported_moduli.clone());
213        inventory.add_phantom_sub_executor(
214            non_qr_hint_sub_ex.clone(),
215            PhantomDiscriminant(ModularPhantom::HintNonQr as u16),
216        )?;
217
218        let sqrt_hint_sub_ex = phantom::SqrtHintSubEx::new(non_qr_hint_sub_ex);
219        inventory.add_phantom_sub_executor(
220            sqrt_hint_sub_ex,
221            PhantomDiscriminant(ModularPhantom::HintSqrt as u16),
222        )?;
223
224        Ok(())
225    }
226}
227
228impl<SC: StarkGenericConfig> VmCircuitExtension<SC> for ModularExtension {
229    fn extend_circuit(&self, inventory: &mut AirInventory<SC>) -> Result<(), AirInventoryError> {
230        let SystemPort {
231            execution_bus,
232            program_bus,
233            memory_bridge,
234        } = inventory.system().port();
235
236        let exec_bridge = ExecutionBridge::new(execution_bus, program_bus);
237        let range_checker_bus = inventory.range_checker().bus;
238        let pointer_max_bits = inventory.pointer_max_bits();
239
240        let bitwise_lu = {
241            // A trick to get around Rust's borrow rules
242            let existing_air = inventory.find_air::<BitwiseOperationLookupAir<8>>().next();
243            if let Some(air) = existing_air {
244                air.bus
245            } else {
246                let bus = BitwiseOperationLookupBus::new(inventory.new_bus_idx());
247                let air = BitwiseOperationLookupAir::<8>::new(bus);
248                inventory.add_air(air);
249                air.bus
250            }
251        };
252        for (i, modulus) in self.supported_moduli.iter().enumerate() {
253            // determine the number of bytes needed to represent a prime field element
254            let bytes = modulus.bits().div_ceil(8);
255            let start_offset =
256                Rv32ModularArithmeticOpcode::CLASS_OFFSET + i * Rv32ModularArithmeticOpcode::COUNT;
257
258            if bytes <= 32 {
259                let config = ExprBuilderConfig {
260                    modulus: modulus.clone(),
261                    num_limbs: 32,
262                    limb_bits: 8,
263                };
264
265                let addsub = get_modular_addsub_air::<1, 32>(
266                    exec_bridge,
267                    memory_bridge,
268                    config.clone(),
269                    range_checker_bus,
270                    bitwise_lu,
271                    pointer_max_bits,
272                    start_offset,
273                );
274                inventory.add_air(addsub);
275
276                let muldiv = get_modular_muldiv_air::<1, 32>(
277                    exec_bridge,
278                    memory_bridge,
279                    config,
280                    range_checker_bus,
281                    bitwise_lu,
282                    pointer_max_bits,
283                    start_offset,
284                );
285                inventory.add_air(muldiv);
286
287                let is_eq = ModularIsEqualAir::<1, 32, 32>::new(
288                    Rv32IsEqualModAdapterAir::new(
289                        exec_bridge,
290                        memory_bridge,
291                        bitwise_lu,
292                        pointer_max_bits,
293                    ),
294                    ModularIsEqualCoreAir::new(modulus.clone(), bitwise_lu, start_offset),
295                );
296                inventory.add_air(is_eq);
297            } else if bytes <= 48 {
298                let config = ExprBuilderConfig {
299                    modulus: modulus.clone(),
300                    num_limbs: 48,
301                    limb_bits: 8,
302                };
303
304                let addsub = get_modular_addsub_air::<3, 16>(
305                    exec_bridge,
306                    memory_bridge,
307                    config.clone(),
308                    range_checker_bus,
309                    bitwise_lu,
310                    pointer_max_bits,
311                    start_offset,
312                );
313                inventory.add_air(addsub);
314
315                let muldiv = get_modular_muldiv_air::<3, 16>(
316                    exec_bridge,
317                    memory_bridge,
318                    config,
319                    range_checker_bus,
320                    bitwise_lu,
321                    pointer_max_bits,
322                    start_offset,
323                );
324                inventory.add_air(muldiv);
325
326                let is_eq = ModularIsEqualAir::<3, 16, 48>::new(
327                    Rv32IsEqualModAdapterAir::new(
328                        exec_bridge,
329                        memory_bridge,
330                        bitwise_lu,
331                        pointer_max_bits,
332                    ),
333                    ModularIsEqualCoreAir::new(modulus.clone(), bitwise_lu, start_offset),
334                );
335                inventory.add_air(is_eq);
336            } else {
337                panic!("Modulus too large");
338            }
339        }
340
341        Ok(())
342    }
343}
344
// This implementation is specific to CpuBackend because the lookup chips (VariableRangeChecker,
// BitwiseOperationLookupChip) are specific to CpuBackend.
impl<E, SC, RA> VmProverExtension<E, RA, ModularExtension> for AlgebraCpuProverExt
where
    SC: StarkGenericConfig,
    E: StarkEngine<SC = SC, PB = CpuBackend<SC>, PD = CpuDevice<SC>>,
    RA: RowMajorMatrixArena<Val<SC>>,
    Val<SC>: PrimeField32,
{
    /// Creates the CPU chips for every supported modulus.
    ///
    /// For each modulus (in `supported_moduli` order) this calls `inventory.next_air::<..>()`
    /// and then adds the matching executor chip, in this order: add/sub, mul/div, is-equal.
    /// Moduli of at most 32 bytes use the `<1, 32>` variants and moduli of at most 48 bytes use
    /// the `<3, 16>` variants.
    ///
    /// NOTE(review): the sequence of `next_air` calls presumably has to mirror the order in
    /// which the circuit extension added the AIRs. Confirm before reordering anything here.
    ///
    /// # Panics
    /// Panics with "Modulus too large" if a modulus needs more than 48 bytes.
    fn extend_prover(
        &self,
        extension: &ModularExtension,
        inventory: &mut ChipInventory<SC, RA, CpuBackend<SC>>,
    ) -> Result<(), ChipInventoryError> {
        let range_checker = inventory.range_checker()?.clone();
        let timestamp_max_bits = inventory.timestamp_max_bits();
        let pointer_max_bits = inventory.airs().pointer_max_bits();
        let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits);
        // Reuse an existing shared 8-bit bitwise lookup chip if one was already registered.
        // Otherwise build one from the next AIR (typed as `BitwiseOperationLookupAir<8>`) and
        // register it as a periphery chip.
        let bitwise_lu = {
            let existing_chip = inventory
                .find_chip::<SharedBitwiseOperationLookupChip<8>>()
                .next();
            if let Some(chip) = existing_chip {
                chip.clone()
            } else {
                let air: &BitwiseOperationLookupAir<8> = inventory.next_air()?;
                let chip = Arc::new(BitwiseOperationLookupChip::new(air.bus));
                inventory.add_periphery_chip(chip.clone());
                chip
            }
        };
        for (i, modulus) in extension.supported_moduli.iter().enumerate() {
            // determine the number of bytes needed to represent a prime field element
            let bytes = modulus.bits().div_ceil(8);
            // Each modulus owns a contiguous opcode slot of `COUNT` opcodes.
            let start_offset =
                Rv32ModularArithmeticOpcode::CLASS_OFFSET + i * Rv32ModularArithmeticOpcode::COUNT;

            // Modulus split into 8-bit limbs; zero-padded to a fixed-size array below.
            let modulus_limbs = big_uint_to_limbs(modulus, 8);

            if bytes <= 32 {
                let config = ExprBuilderConfig {
                    modulus: modulus.clone(),
                    num_limbs: 32,
                    limb_bits: 8,
                };

                inventory.next_air::<ModularAir<1, 32>>()?;
                let addsub = get_modular_addsub_chip::<Val<SC>, 1, 32>(
                    config.clone(),
                    mem_helper.clone(),
                    range_checker.clone(),
                    bitwise_lu.clone(),
                    pointer_max_bits,
                );
                inventory.add_executor_chip(addsub);

                inventory.next_air::<ModularAir<1, 32>>()?;
                let muldiv = get_modular_muldiv_chip::<Val<SC>, 1, 32>(
                    config,
                    mem_helper.clone(),
                    range_checker.clone(),
                    bitwise_lu.clone(),
                    pointer_max_bits,
                );
                inventory.add_executor_chip(muldiv);

                // Pad the limbs with zeros to the chip's fixed width. The closure's `i` is the
                // limb index and shadows the loop's modulus index.
                let modulus_limbs = array::from_fn(|i| {
                    if i < modulus_limbs.len() {
                        modulus_limbs[i] as u8
                    } else {
                        0
                    }
                });
                inventory.next_air::<ModularIsEqualAir<1, 32, 32>>()?;
                let is_eq = ModularIsEqualChip::<Val<SC>, 1, 32, 32>::new(
                    ModularIsEqualFiller::new(
                        Rv32IsEqualModAdapterFiller::new(pointer_max_bits, bitwise_lu.clone()),
                        start_offset,
                        modulus_limbs,
                        bitwise_lu.clone(),
                    ),
                    mem_helper.clone(),
                );
                inventory.add_executor_chip(is_eq);
            } else if bytes <= 48 {
                let config = ExprBuilderConfig {
                    modulus: modulus.clone(),
                    num_limbs: 48,
                    limb_bits: 8,
                };

                inventory.next_air::<ModularAir<3, 16>>()?;
                let addsub = get_modular_addsub_chip::<Val<SC>, 3, 16>(
                    config.clone(),
                    mem_helper.clone(),
                    range_checker.clone(),
                    bitwise_lu.clone(),
                    pointer_max_bits,
                );
                inventory.add_executor_chip(addsub);

                inventory.next_air::<ModularAir<3, 16>>()?;
                let muldiv = get_modular_muldiv_chip::<Val<SC>, 3, 16>(
                    config,
                    mem_helper.clone(),
                    range_checker.clone(),
                    bitwise_lu.clone(),
                    pointer_max_bits,
                );
                inventory.add_executor_chip(muldiv);

                // Pad the limbs with zeros to the chip's fixed width. The closure's `i` is the
                // limb index and shadows the loop's modulus index.
                let modulus_limbs = array::from_fn(|i| {
                    if i < modulus_limbs.len() {
                        modulus_limbs[i] as u8
                    } else {
                        0
                    }
                });
                inventory.next_air::<ModularIsEqualAir<3, 16, 48>>()?;
                let is_eq = ModularIsEqualChip::<Val<SC>, 3, 16, 48>::new(
                    ModularIsEqualFiller::new(
                        Rv32IsEqualModAdapterFiller::new(pointer_max_bits, bitwise_lu.clone()),
                        start_offset,
                        modulus_limbs,
                        bitwise_lu.clone(),
                    ),
                    mem_helper.clone(),
                );
                inventory.add_executor_chip(is_eq);
            } else {
                panic!("Modulus too large");
            }
        }

        Ok(())
    }
}
482
483pub(crate) mod phantom {
484    use std::{
485        iter::{once, repeat},
486        ops::Deref,
487    };
488
489    use eyre::bail;
490    use num_bigint::BigUint;
491    use openvm_circuit::{
492        arch::{PhantomSubExecutor, Streams},
493        system::memory::online::GuestMemory,
494    };
495    use openvm_instructions::{riscv::RV32_MEMORY_AS, PhantomDiscriminant};
496    use openvm_rv32im_circuit::adapters::read_rv32_register;
497    use openvm_stark_backend::p3_field::PrimeField32;
498    use rand::{rngs::StdRng, SeedableRng};
499
500    use super::{find_non_qr, mod_sqrt};
501
502    #[derive(derive_new::new)]
503    pub struct SqrtHintSubEx(NonQrHintSubEx);
504
505    impl Deref for SqrtHintSubEx {
506        type Target = NonQrHintSubEx;
507
508        fn deref(&self) -> &NonQrHintSubEx {
509            &self.0
510        }
511    }
512
513    // Given x returns either a sqrt of x or a sqrt of x * non_qr, whichever exists.
514    // Note that non_qr is fixed for each modulus.
515    impl<F: PrimeField32> PhantomSubExecutor<F> for SqrtHintSubEx {
516        fn phantom_execute(
517            &self,
518            memory: &GuestMemory,
519            streams: &mut Streams<F>,
520            _: &mut StdRng,
521            _: PhantomDiscriminant,
522            a: u32,
523            _: u32,
524            c_upper: u16,
525        ) -> eyre::Result<()> {
526            let mod_idx = c_upper as usize;
527            if mod_idx >= self.supported_moduli.len() {
528                bail!(
529                    "Modulus index {mod_idx} out of range: {} supported moduli",
530                    self.supported_moduli.len()
531                );
532            }
533            let modulus = &self.supported_moduli[mod_idx];
534            let num_limbs: usize = if modulus.bits().div_ceil(8) <= 32 {
535                32
536            } else if modulus.bits().div_ceil(8) <= 48 {
537                48
538            } else {
539                bail!("Modulus too large")
540            };
541
542            let rs1 = read_rv32_register(memory, a);
543            // SAFETY:
544            // - MEMORY_AS consists of `u8`s
545            // - MEMORY_AS is in bounds
546            let x_limbs: Vec<u8> =
547                unsafe { memory.memory.get_slice((RV32_MEMORY_AS, rs1), num_limbs) }.to_vec();
548            let x = BigUint::from_bytes_le(&x_limbs);
549
550            let (success, sqrt) = match mod_sqrt(&x, modulus, &self.non_qrs[mod_idx]) {
551                Some(sqrt) => (true, sqrt),
552                None => {
553                    let sqrt = mod_sqrt(
554                        &(&x * &self.non_qrs[mod_idx]),
555                        modulus,
556                        &self.non_qrs[mod_idx],
557                    )
558                    .expect("Either x or x * non_qr should be a square");
559                    (false, sqrt)
560                }
561            };
562
563            let hint_bytes = once(F::from_bool(success))
564                .chain(repeat(F::ZERO))
565                .take(4)
566                .chain(
567                    sqrt.to_bytes_le()
568                        .into_iter()
569                        .map(F::from_canonical_u8)
570                        .chain(repeat(F::ZERO))
571                        .take(num_limbs),
572                )
573                .collect();
574            streams.hint_stream = hint_bytes;
575            Ok(())
576        }
577    }
578
579    #[derive(Clone)]
580    pub struct NonQrHintSubEx {
581        pub supported_moduli: Vec<BigUint>,
582        pub non_qrs: Vec<BigUint>,
583    }
584
585    impl NonQrHintSubEx {
586        pub fn new(supported_moduli: Vec<BigUint>) -> Self {
587            // Use deterministic seed so that the non-QR are deterministic between different
588            // instances of the VM. The seed determines the runtime of Tonelli-Shanks, if the
589            // algorithm is necessary, which affects the time it takes to construct and initialize
590            // the VM but does not affect the runtime.
591            let mut rng = StdRng::from_seed([0u8; 32]);
592            let non_qrs = supported_moduli
593                .iter()
594                .map(|modulus| find_non_qr(modulus, &mut rng))
595                .collect();
596            Self {
597                supported_moduli,
598                non_qrs,
599            }
600        }
601    }
602
603    impl<F: PrimeField32> PhantomSubExecutor<F> for NonQrHintSubEx {
604        fn phantom_execute(
605            &self,
606            _: &GuestMemory,
607            streams: &mut Streams<F>,
608            _: &mut StdRng,
609            _: PhantomDiscriminant,
610            _: u32,
611            _: u32,
612            c_upper: u16,
613        ) -> eyre::Result<()> {
614            let mod_idx = c_upper as usize;
615            if mod_idx >= self.supported_moduli.len() {
616                bail!(
617                    "Modulus index {mod_idx} out of range: {} supported moduli",
618                    self.supported_moduli.len()
619                );
620            }
621            let modulus = &self.supported_moduli[mod_idx];
622
623            let num_limbs: usize = if modulus.bits().div_ceil(8) <= 32 {
624                32
625            } else if modulus.bits().div_ceil(8) <= 48 {
626                48
627            } else {
628                bail!("Modulus too large")
629            };
630
631            let hint_bytes = self.non_qrs[mod_idx]
632                .to_bytes_le()
633                .into_iter()
634                .map(F::from_canonical_u8)
635                .chain(repeat(F::ZERO))
636                .take(num_limbs)
637                .collect();
638            streams.hint_stream = hint_bytes;
639            Ok(())
640        }
641    }
642}
643
644/// Find the square root of `x` modulo `modulus` with `non_qr` a
645/// quadratic nonresidue of the field.
646pub fn mod_sqrt(x: &BigUint, modulus: &BigUint, non_qr: &BigUint) -> Option<BigUint> {
647    if modulus % 4u32 == BigUint::from_u8(3).unwrap() {
648        // x^(1/2) = x^((p+1)/4) when p = 3 mod 4
649        let exponent = (modulus + BigUint::one()) >> 2;
650        let ret = x.modpow(&exponent, modulus);
651        if &ret * &ret % modulus == x % modulus {
652            Some(ret)
653        } else {
654            None
655        }
656    } else {
657        // Tonelli-Shanks algorithm
658        // https://en.wikipedia.org/wiki/Tonelli%E2%80%93Shanks_algorithm#The_algorithm
659        let mut q = modulus - BigUint::one();
660        let mut s = 0;
661        while &q % 2u32 == BigUint::ZERO {
662            s += 1;
663            q /= 2u32;
664        }
665        let z = non_qr;
666        let mut m = s;
667        let mut c = z.modpow(&q, modulus);
668        let mut t = x.modpow(&q, modulus);
669        let mut r = x.modpow(&((q + BigUint::one()) >> 1), modulus);
670        loop {
671            if t == BigUint::ZERO {
672                return Some(BigUint::ZERO);
673            }
674            if t == BigUint::one() {
675                return Some(r);
676            }
677            let mut i = 0;
678            let mut tmp = t.clone();
679            while tmp != BigUint::one() && i < m {
680                tmp = &tmp * &tmp % modulus;
681                i += 1;
682            }
683            if i == m {
684                // self is not a quadratic residue
685                return None;
686            }
687            for _ in 0..m - i - 1 {
688                c = &c * &c % modulus;
689            }
690            let b = c;
691            m = i;
692            c = &b * &b % modulus;
693            t = ((t * &b % modulus) * &b) % modulus;
694            r = (r * b) % modulus;
695        }
696    }
697}
698
699// Returns a non-quadratic residue in the field
700pub fn find_non_qr(modulus: &BigUint, rng: &mut impl Rng) -> BigUint {
701    if modulus % 4u32 == BigUint::from(3u8) {
702        // p = 3 mod 4 then -1 is a quadratic residue
703        modulus - BigUint::one()
704    } else if modulus % 8u32 == BigUint::from(5u8) {
705        // p = 5 mod 8 then 2 is a non-quadratic residue
706        // since 2^((p-1)/2) = (-1)^((p^2-1)/8)
707        BigUint::from_u8(2u8).unwrap()
708    } else {
709        let mut non_qr = rng.gen_biguint_range(
710            &BigUint::from_u8(2).unwrap(),
711            &(modulus - BigUint::from_u8(1).unwrap()),
712        );
713        // To check if non_qr is a quadratic nonresidue, we compute non_qr^((p-1)/2)
714        // If the result is p-1, then non_qr is a quadratic nonresidue
715        // Otherwise, non_qr is a quadratic residue
716        let exponent = (modulus - BigUint::one()) >> 1;
717        while non_qr.modpow(&exponent, modulus) != modulus - BigUint::one() {
718            non_qr = rng.gen_biguint_range(
719                &BigUint::from_u8(2).unwrap(),
720                &(modulus - BigUint::from_u8(1).unwrap()),
721            );
722        }
723        non_qr
724    }
725}