openvm_native_compiler/asm/compiler.rs

use alloc::{collections::BTreeMap, vec};

use openvm_circuit::arch::instructions::instruction::DebugInfo;
use openvm_stark_backend::p3_field::{ExtensionField, Field, PrimeField32, TwoAdicField};

use super::{config::AsmConfig, AssemblyCode, BasicBlock, IndexTriple, ValueOrConst};
use crate::{
    asm::AsmInstruction,
    ir::{Array, DslIr, Ext, Felt, Ptr, RVar, Usize, Var},
    prelude::TracedVec,
};

pub const MEMORY_BITS: usize = 29;
/// The memory location for the top of memory.
pub const MEMORY_TOP: u32 = (1 << MEMORY_BITS) - 4;

/// The memory location for the start of the heap.
pub(crate) const HEAP_START_ADDRESS: i32 = 1 << 24;

/// The heap pointer address.
pub(crate) const HEAP_PTR: i32 = HEAP_START_ADDRESS - 4;

const HEAP_SIZE: u32 = MEMORY_TOP - HEAP_START_ADDRESS as u32;

/// Utility register.
pub const A0: i32 = HEAP_START_ADDRESS - 8;

/// The memory location for the top of the stack.
pub(crate) const STACK_TOP: i32 = HEAP_START_ADDRESS - 64;
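// Memory layout implied by the constants above (a sketch, for orientation only):
//   (0, STACK_TOP)                    stack slots, addressed downward from STACK_TOP via the
//                                     `fp()` helpers below
//   [STACK_TOP, HEAP_START_ADDRESS)   reserved scratch cells, including A0 and HEAP_PTR
//   [HEAP_START_ADDRESS, MEMORY_TOP)  heap of HEAP_SIZE cells, bump-allocated through HEAP_PTR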

/// The assembly compiler.
// #[derive(Debug, Clone, Default)]
pub struct AsmCompiler<F, EF> {
    basic_blocks: Vec<BasicBlock<F, EF>>,
    function_labels: BTreeMap<String, F>,
    trap_label: F,
    word_size: usize,
}

impl<F> Var<F> {
    /// Gets the frame pointer for a var.
    pub const fn fp(&self) -> i32 {
        // Vars are stored in stack positions 1, 2, 9, 10, 17, 18, ...
        let offset = (8 * (self.0 / 2) + 1 + (self.0 % 2)) as i32;
        assert!(offset < STACK_TOP, "Var fp overflow");
        STACK_TOP - offset
    }
}

impl<F> Felt<F> {
    /// Gets the frame pointer for a felt.
    pub const fn fp(&self) -> i32 {
        // Felts are stored in stack positions 3, 4, 11, 12, 19, 20, ...
        let offset = (((self.0 >> 1) << 3) + 3 + (self.0 & 1)) as i32;
        assert!(offset < STACK_TOP, "Felt fp overflow");
        STACK_TOP - offset
    }
}

impl<F, EF> Ext<F, EF> {
    /// Gets the frame pointer for an extension element.
    pub const fn fp(&self) -> i32 {
        // Exts are stored in stack positions 5-8, 13-16, 21-24, ...
        let offset = 8 * self.0 as i32;
        assert!(offset < STACK_TOP, "Ext fp overflow");
        STACK_TOP - offset
    }
}
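// Worked example of the frame-pointer arithmetic above (indices are assigned by the IR builder,
// not here): with EF::D == 4, as the position comments assume, each group of eight stack slots
// holds two vars, two felts, and one ext, e.g.
//   var index 0  -> fp = STACK_TOP - 1      var index 2  -> fp = STACK_TOP - 9
//   felt index 0 -> fp = STACK_TOP - 3      felt index 2 -> fp = STACK_TOP - 11
//   ext index 1  -> fp = STACK_TOP - 8      (its four words sit at offsets 5..=8)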

impl<F> Ptr<F> {
    /// Gets the frame pointer for a pointer.
    pub const fn fp(&self) -> i32 {
        self.address.fp()
    }
}

impl<F: PrimeField32 + TwoAdicField, EF: ExtensionField<F> + TwoAdicField> AsmCompiler<F, EF> {
    /// Creates a new [AsmCompiler].
    pub fn new(word_size: usize) -> Self {
        Self {
            basic_blocks: vec![BasicBlock::new()],
            function_labels: BTreeMap::new(),
            trap_label: F::ONE,
            word_size,
        }
    }

    /// Builds the operations into assembly instructions.
    pub fn build(&mut self, operations: TracedVec<DslIr<AsmConfig<F, EF>>>) {
        if self.block_label().is_zero() {
            // Initialize the heap pointer value.
            let heap_start = F::from_canonical_u32(HEAP_START_ADDRESS as u32);
            self.push(AsmInstruction::ImmF(HEAP_PTR, heap_start), None);
            // Jump over the TRAP instruction we are about to add.
            self.push(AsmInstruction::j(self.trap_label + F::ONE), None);
            self.basic_block();
            // Add a TRAP instruction used as jump destination for all failed assertions.
            assert_eq!(self.block_label(), self.trap_label);
            self.push(AsmInstruction::Trap, None);
            self.basic_block();
        }
        // For each operation, generate assembly instructions.
        for (op, trace) in operations.clone() {
            let debug_info = Some(DebugInfo::new(op.to_string(), trace));
            match op {
                DslIr::ImmV(dst, src) => {
                    self.push(AsmInstruction::ImmF(dst.fp(), src), debug_info);
                }
                DslIr::ImmF(dst, src) => {
                    self.push(AsmInstruction::ImmF(dst.fp(), src), debug_info);
                }
                DslIr::ImmE(dst, src) => {
                    self.assign_exti(dst.fp(), src, debug_info);
                }
                DslIr::AddV(dst, lhs, rhs) => {
                    self.push(
                        AsmInstruction::AddF(dst.fp(), lhs.fp(), rhs.fp()),
                        debug_info,
                    );
                }
                DslIr::AddVI(dst, lhs, rhs) => {
                    self.push(AsmInstruction::AddFI(dst.fp(), lhs.fp(), rhs), debug_info);
                }
                DslIr::AddF(dst, lhs, rhs) => {
                    self.push(
                        AsmInstruction::AddF(dst.fp(), lhs.fp(), rhs.fp()),
                        debug_info,
                    );
                }
                DslIr::AddFI(dst, lhs, rhs) => {
                    self.push(AsmInstruction::AddFI(dst.fp(), lhs.fp(), rhs), debug_info);
                }
                DslIr::AddE(dst, lhs, rhs) => {
                    self.push(
                        AsmInstruction::AddE(dst.fp(), lhs.fp(), rhs.fp()),
                        debug_info,
                    );
                }
                DslIr::AddEI(dst, lhs, rhs) => {
                    self.add_ext_exti(dst, lhs, rhs, debug_info);
                }
                DslIr::AddEF(dst, lhs, rhs) => {
                    self.add_ext_felt(dst, lhs, rhs, debug_info);
                }
                DslIr::AddEFFI(dst, lhs, rhs) => {
                    self.add_felt_exti(dst, lhs, rhs, debug_info);
                }
                DslIr::AddEFI(dst, lhs, rhs) => {
                    self.add_ext_exti(dst, lhs, EF::from_base(rhs), debug_info);
                }
                DslIr::SubV(dst, lhs, rhs) => {
                    self.push(
                        AsmInstruction::SubF(dst.fp(), lhs.fp(), rhs.fp()),
                        debug_info,
                    );
                }
                DslIr::SubVI(dst, lhs, rhs) => {
                    self.push(AsmInstruction::SubFI(dst.fp(), lhs.fp(), rhs), debug_info);
                }
                DslIr::SubVIN(dst, lhs, rhs) => {
                    self.push(
                        AsmInstruction::SubFIN(dst.fp(), lhs, rhs.fp()),
                        debug_info.clone(),
                    );
                }
                DslIr::SubF(dst, lhs, rhs) => {
                    self.push(
                        AsmInstruction::SubF(dst.fp(), lhs.fp(), rhs.fp()),
                        debug_info,
                    );
                }
                DslIr::SubFI(dst, lhs, rhs) => {
                    self.push(AsmInstruction::SubFI(dst.fp(), lhs.fp(), rhs), debug_info);
                }
                DslIr::SubFIN(dst, lhs, rhs) => {
                    self.push(
                        AsmInstruction::SubFIN(dst.fp(), lhs, rhs.fp()),
                        debug_info.clone(),
                    );
                }
                DslIr::NegV(dst, src) => {
                    self.push(
                        AsmInstruction::MulFI(dst.fp(), src.fp(), F::NEG_ONE),
                        debug_info,
                    );
                }
                DslIr::NegF(dst, src) => {
                    self.push(
                        AsmInstruction::MulFI(dst.fp(), src.fp(), F::NEG_ONE),
                        debug_info,
                    );
                }
                DslIr::DivF(dst, lhs, rhs) => {
                    self.push(
                        AsmInstruction::DivF(dst.fp(), lhs.fp(), rhs.fp()),
                        debug_info,
                    );
                }
                DslIr::DivFI(dst, lhs, rhs) => {
                    self.push(AsmInstruction::DivFI(dst.fp(), lhs.fp(), rhs), debug_info);
                }
                DslIr::DivFIN(dst, lhs, rhs) => {
                    self.push(AsmInstruction::DivFIN(dst.fp(), lhs, rhs.fp()), debug_info);
                }
                DslIr::DivEIN(dst, lhs, rhs) => {
                    self.assign_exti(A0, lhs, debug_info.clone());
                    self.push(AsmInstruction::DivE(dst.fp(), A0, rhs.fp()), debug_info);
                }
                DslIr::DivE(dst, lhs, rhs) => {
                    self.push(
                        AsmInstruction::DivE(dst.fp(), lhs.fp(), rhs.fp()),
                        debug_info,
                    );
                }
                DslIr::DivEI(dst, lhs, rhs) => {
                    self.assign_exti(A0, rhs, debug_info.clone());
                    self.push(AsmInstruction::DivE(dst.fp(), lhs.fp(), A0), debug_info);
                }
                DslIr::DivEF(dst, lhs, rhs) => {
                    self.div_ext_felt(dst, lhs, rhs, debug_info);
                }
                DslIr::DivEFI(dst, lhs, rhs) => {
                    self.mul_ext_felti(dst, lhs, rhs.inverse(), debug_info);
                }
                DslIr::SubEF(dst, lhs, rhs) => {
                    self.sub_ext_felt(dst, lhs, rhs, debug_info);
                }
                DslIr::SubEFI(dst, lhs, rhs) => {
                    self.add_ext_exti(dst, lhs, EF::from_base(rhs.neg()), debug_info);
                }
                DslIr::SubEIN(dst, lhs, rhs) => {
                    self.sub_exti_ext(dst, lhs, rhs, debug_info.clone());
                }
                DslIr::SubE(dst, lhs, rhs) => {
                    self.push(
                        AsmInstruction::SubE(dst.fp(), lhs.fp(), rhs.fp()),
                        debug_info,
                    );
                }
                DslIr::SubEI(dst, lhs, rhs) => {
                    self.add_ext_exti(dst, lhs, rhs.neg(), debug_info);
                }
                DslIr::NegE(dst, src) => {
                    self.mul_ext_felti(dst, src, F::NEG_ONE, debug_info);
                }
                DslIr::MulV(dst, lhs, rhs) => {
                    self.push(
                        AsmInstruction::MulF(dst.fp(), lhs.fp(), rhs.fp()),
                        debug_info,
                    );
                }
                DslIr::MulVI(dst, lhs, rhs) => {
                    self.push(AsmInstruction::MulFI(dst.fp(), lhs.fp(), rhs), debug_info);
                }
                DslIr::MulF(dst, lhs, rhs) => {
                    self.push(
                        AsmInstruction::MulF(dst.fp(), lhs.fp(), rhs.fp()),
                        debug_info,
                    );
                }
                DslIr::MulFI(dst, lhs, rhs) => {
                    self.push(AsmInstruction::MulFI(dst.fp(), lhs.fp(), rhs), debug_info);
                }
                DslIr::MulE(dst, lhs, rhs) => {
                    self.push(
                        AsmInstruction::MulE(dst.fp(), lhs.fp(), rhs.fp()),
                        debug_info,
                    );
                }
                DslIr::MulEI(dst, lhs, rhs) => {
                    self.assign_exti(A0, rhs, debug_info.clone());
                    self.push(AsmInstruction::MulE(dst.fp(), lhs.fp(), A0), debug_info);
                }
                DslIr::MulEF(dst, lhs, rhs) => {
                    self.mul_ext_felt(dst, lhs, rhs, debug_info);
                }
                DslIr::MulEFI(dst, lhs, rhs) => {
                    self.mul_ext_felti(dst, lhs, rhs, debug_info);
                }
                DslIr::CastFV(dst, src) => {
                    self.push(
                        AsmInstruction::AddFI(dst.fp(), src.fp(), F::ZERO),
                        debug_info,
                    );
                }
                DslIr::UnsafeCastVF(dst, src) => {
                    self.push(
                        AsmInstruction::AddFI(dst.fp(), src.fp(), F::ZERO),
                        debug_info,
                    );
                }
                DslIr::IfEq(lhs, rhs, then_block, else_block) => {
                    let if_compiler = IfCompiler {
                        compiler: self,
                        lhs: lhs.fp(),
                        rhs: ValueOrConst::Val(rhs.fp()),
                        is_eq: true,
                    };
                    if else_block.is_empty() {
                        if_compiler.then(|builder| builder.build(then_block), debug_info);
                    } else {
                        if_compiler.then_or_else(
                            |builder| builder.build(then_block),
                            |builder| builder.build(else_block),
                            debug_info,
                        );
                    }
                }
                DslIr::IfNe(lhs, rhs, then_block, else_block) => {
                    let if_compiler = IfCompiler {
                        compiler: self,
                        lhs: lhs.fp(),
                        rhs: ValueOrConst::Val(rhs.fp()),
                        is_eq: false,
                    };
                    if else_block.is_empty() {
                        if_compiler.then(|builder| builder.build(then_block), debug_info);
                    } else {
                        if_compiler.then_or_else(
                            |builder| builder.build(then_block),
                            |builder| builder.build(else_block),
                            debug_info,
                        );
                    }
                }
                DslIr::IfEqI(lhs, rhs, then_block, else_block) => {
                    let if_compiler = IfCompiler {
                        compiler: self,
                        lhs: lhs.fp(),
                        rhs: ValueOrConst::Const(rhs),
                        is_eq: true,
                    };
                    if else_block.is_empty() {
                        if_compiler.then(|builder| builder.build(then_block), debug_info);
                    } else {
                        if_compiler.then_or_else(
                            |builder| builder.build(then_block),
                            |builder| builder.build(else_block),
                            debug_info,
                        );
                    }
                }
                DslIr::IfNeI(lhs, rhs, then_block, else_block) => {
                    let if_compiler = IfCompiler {
                        compiler: self,
                        lhs: lhs.fp(),
                        rhs: ValueOrConst::Const(rhs),
                        is_eq: false,
                    };
                    if else_block.is_empty() {
                        if_compiler.then(|builder| builder.build(then_block), debug_info);
                    } else {
                        if_compiler.then_or_else(
                            |builder| builder.build(then_block),
                            |builder| builder.build(else_block),
                            debug_info,
                        );
                    }
                }
                DslIr::ZipFor(starts, end0, step_sizes, loop_vars, block) => {
                    let zip_for_compiler = ZipForCompiler {
                        compiler: self,
                        starts,
                        end0,
                        step_sizes,
                        loop_vars,
                    };
                    zip_for_compiler.for_each(move |_, builder| builder.build(block), debug_info);
                }
                DslIr::AssertEqV(lhs, rhs) => {
                    // If lhs != rhs, execute TRAP
                    self.assert(lhs.fp(), ValueOrConst::Val(rhs.fp()), false, debug_info)
                }
                DslIr::AssertEqVI(lhs, rhs) => {
                    // If lhs != rhs, execute TRAP
                    self.assert(lhs.fp(), ValueOrConst::Const(rhs), false, debug_info)
                }
                DslIr::AssertEqF(lhs, rhs) => {
                    // If lhs != rhs, execute TRAP
                    self.assert(lhs.fp(), ValueOrConst::Val(rhs.fp()), false, debug_info)
                }
                DslIr::AssertEqFI(lhs, rhs) => {
                    // If lhs != rhs, execute TRAP
                    self.assert(lhs.fp(), ValueOrConst::Const(rhs), false, debug_info)
                }
                DslIr::AssertEqE(lhs, rhs) => {
                    // If lhs != rhs, execute TRAP
                    self.assert(lhs.fp(), ValueOrConst::ExtVal(rhs.fp()), false, debug_info)
                }
                DslIr::AssertEqEI(lhs, rhs) => {
                    // If lhs != rhs, execute TRAP
                    self.assert(lhs.fp(), ValueOrConst::ExtConst(rhs), false, debug_info)
                }
                DslIr::AssertNonZero(u) => {
                    // If u == 0, execute TRAP
                    match u {
                        Usize::Const(_) => self.assert(
                            u.value() as i32,
                            ValueOrConst::Const(F::ZERO),
                            true,
                            debug_info,
                        ),
                        Usize::Var(v) => {
                            self.assert(v.fp(), ValueOrConst::Const(F::ZERO), true, debug_info)
                        }
                    }
                }
                DslIr::Alloc(ptr, len, size) => {
                    self.alloc(ptr, len, size, debug_info);
                }
                DslIr::LoadV(var, ptr, index) => match index.fp() {
                    IndexTriple::Const(index, offset, size) => self.push(
                        AsmInstruction::LoadFI(var.fp(), ptr.fp(), index, size, offset),
                        debug_info.clone(),
                    ),
                    IndexTriple::Var(index, offset, size) => {
                        self.add_scaled(A0, ptr.fp(), index, size, debug_info.clone());
                        self.push(
                            AsmInstruction::LoadFI(var.fp(), A0, F::ZERO, F::ZERO, offset),
                            debug_info.clone(),
                        )
                    }
                },
                DslIr::LoadF(var, ptr, index) => match index.fp() {
                    IndexTriple::Const(index, offset, size) => self.push(
                        AsmInstruction::LoadFI(var.fp(), ptr.fp(), index, size, offset),
                        debug_info.clone(),
                    ),
                    IndexTriple::Var(index, offset, size) => {
                        self.add_scaled(A0, ptr.fp(), index, size, debug_info.clone());
                        self.push(
                            AsmInstruction::LoadFI(var.fp(), A0, F::ZERO, F::ZERO, offset),
                            debug_info.clone(),
                        )
                    }
                },
                DslIr::LoadE(var, ptr, index) => match index.fp() {
                    IndexTriple::Const(index, offset, size) => {
                        self.load_ext(var, ptr.fp(), index * size + offset, debug_info)
                    }
                    IndexTriple::Var(index, offset, size) => {
                        self.add_scaled(A0, ptr.fp(), index, size, debug_info.clone());
                        self.load_ext(var, A0, offset, debug_info)
                    }
                },
                DslIr::LoadHeapPtr(ptr) => self.push(
                    AsmInstruction::AddFI(ptr.fp(), HEAP_PTR, F::ZERO),
                    debug_info,
                ),
                DslIr::StoreV(var, ptr, index) => match index.fp() {
                    IndexTriple::Const(index, offset, size) => self.push(
                        AsmInstruction::StoreFI(var.fp(), ptr.fp(), index, size, offset),
                        debug_info.clone(),
                    ),
                    IndexTriple::Var(index, offset, size) => {
                        self.add_scaled(A0, ptr.fp(), index, size, debug_info.clone());
                        self.push(
                            AsmInstruction::StoreFI(var.fp(), A0, F::ZERO, F::ZERO, offset),
                            debug_info.clone(),
                        )
                    }
                },
                DslIr::StoreF(var, ptr, index) => match index.fp() {
                    IndexTriple::Const(index, offset, size) => self.push(
                        AsmInstruction::StoreFI(var.fp(), ptr.fp(), index, size, offset),
                        debug_info.clone(),
                    ),
                    IndexTriple::Var(index, offset, size) => {
                        self.add_scaled(A0, ptr.fp(), index, size, debug_info.clone());
                        self.push(
                            AsmInstruction::StoreFI(var.fp(), A0, F::ZERO, F::ZERO, offset),
                            debug_info.clone(),
                        )
                    }
                },
                DslIr::StoreE(var, ptr, index) => match index.fp() {
                    IndexTriple::Const(index, offset, size) => {
                        self.store_ext(var, ptr.fp(), index * size + offset, debug_info)
                    }
                    IndexTriple::Var(index, offset, size) => {
                        self.add_scaled(A0, ptr.fp(), index, size, debug_info.clone());
                        self.store_ext(var, A0, offset, debug_info)
                    }
                },
                DslIr::StoreHeapPtr(ptr) => self.push(
                    AsmInstruction::AddFI(HEAP_PTR, ptr.fp(), F::ZERO),
                    debug_info,
                ),
                DslIr::HintBitsF(var, len) => {
                    self.push(AsmInstruction::HintBits(var.fp(), len), debug_info);
                }
                DslIr::Poseidon2PermuteBabyBear(dst, src) => match (dst, src) {
                    (Array::Dyn(dst, _), Array::Dyn(src, _)) => self.push(
                        AsmInstruction::Poseidon2Permute(dst.fp(), src.fp()),
                        debug_info,
                    ),
                    _ => unimplemented!(),
                },
                DslIr::Poseidon2CompressBabyBear(result, left, right) => {
                    match (result, left, right) {
                        (Array::Dyn(result, _), Array::Dyn(left, _), Array::Dyn(right, _)) => self
                            .push(
                                AsmInstruction::Poseidon2Compress(
                                    result.fp(),
                                    left.fp(),
                                    right.fp(),
                                ),
                                debug_info,
                            ),
                        _ => unimplemented!(),
                    }
                }
                DslIr::Error() => self.push(AsmInstruction::j(self.trap_label), debug_info),
                DslIr::PrintF(dst) => {
                    self.push(AsmInstruction::PrintF(dst.fp()), debug_info);
                }
                DslIr::PrintV(dst) => {
                    self.push(AsmInstruction::PrintV(dst.fp()), debug_info);
                }
                DslIr::PrintE(dst) => {
                    self.push(AsmInstruction::PrintE(dst.fp()), debug_info);
                }
                DslIr::HintInputVec() => {
                    self.push(AsmInstruction::HintInputVec(), debug_info);
                }
                DslIr::HintFelt() => {
                    self.push(AsmInstruction::HintFelt(), debug_info);
                }
                DslIr::StoreHintWord(ptr, index) => match index.fp() {
                    IndexTriple::Const(index, offset, size) => self.push(
                        AsmInstruction::StoreHintWordI(ptr.fp(), size * index + offset),
                        debug_info.clone(),
                    ),
                    IndexTriple::Var(index, offset, size) => {
                        self.add_scaled(A0, ptr.fp(), index, size, debug_info.clone());
                        self.push(AsmInstruction::StoreHintWordI(A0, offset), debug_info)
                    }
                },
                DslIr::HintLoad() => {
                    self.push(AsmInstruction::HintLoad(), debug_info);
                }
                DslIr::Publish(val, index) => {
                    self.push(AsmInstruction::Publish(val.fp(), index.fp()), debug_info);
                }
                DslIr::CycleTrackerStart(name) => {
                    self.push(
                        AsmInstruction::CycleTrackerStart(),
                        Some(DebugInfo {
                            dsl_instruction: format!("CT-{}", name),
                            trace: None,
                        }),
                    );
                }
                DslIr::CycleTrackerEnd(name) => {
                    self.push(
                        AsmInstruction::CycleTrackerEnd(),
                        Some(DebugInfo {
                            dsl_instruction: format!("CT-{}", name),
                            trace: None,
                        }),
                    );
                }
                DslIr::Halt => {
                    self.push(AsmInstruction::Halt, debug_info);
                }
                DslIr::FriReducedOpening(
                    alpha,
                    hint_id,
                    is_init,
                    at_x_array,
                    at_z_array,
                    result,
                ) => {
                    self.push(
                        AsmInstruction::FriReducedOpening(
                            at_x_array.ptr().fp(),
                            at_z_array.ptr().fp(),
                            at_z_array.len().get_var().fp(),
                            alpha.fp(),
                            result.fp(),
                            hint_id.fp(),
                            is_init.fp(),
                        ),
                        debug_info,
                    );
                }
                DslIr::VerifyBatchFelt(dim, opened, proof_id, index, commit) => {
                    self.push(
                        AsmInstruction::VerifyBatchFelt(
                            dim.ptr().fp(),
                            opened.ptr().fp(),
                            opened.len().get_var().fp(),
                            proof_id.fp(),
                            index.ptr().fp(),
                            commit.ptr().fp(),
                        ),
                        debug_info,
                    );
                }
                DslIr::VerifyBatchExt(dim, opened, proof_id, index, commit) => {
                    self.push(
                        AsmInstruction::VerifyBatchExt(
                            dim.ptr().fp(),
                            opened.ptr().fp(),
                            opened.len().get_var().fp(),
                            proof_id.fp(),
                            index.ptr().fp(),
                            commit.ptr().fp(),
                        ),
                        debug_info,
                    );
                }
                DslIr::RangeCheckV(v, num_bits) => {
                    let (lo_bits, hi_bits) = lo_hi_bits(num_bits as u32);
                    self.push(
                        AsmInstruction::RangeCheck(v.fp(), lo_bits, hi_bits),
                        debug_info,
                    );
                }
                _ => unimplemented!(),
            }
        }
    }

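    /// Bump-allocates room on the heap for `len` elements of `size` words each: the old heap
    /// pointer is copied into `ptr`, the heap pointer is advanced by the allocation size
    /// (aligned up to a multiple of `word_size`), and the new heap pointer is range checked
    /// to fit in `MEMORY_BITS` bits.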
    pub fn alloc(
        &mut self,
        ptr: Ptr<F>,
        len: impl Into<RVar<F>>,
        size: usize,
        debug_info: Option<DebugInfo>,
    ) {
        let word_size = self.word_size;
        let align = |x: usize| x.div_ceil(word_size) * word_size;
        // Copy the current heap pointer into `ptr`, then advance the heap pointer by the
        // aligned allocation size.
        let len = len.into();
        match len {
            RVar::Const(len) => {
                self.push(
                    AsmInstruction::CopyF(ptr.fp(), HEAP_PTR),
                    debug_info.clone(),
                );
                let inc = align((len.as_canonical_u32() as usize) * size);
                assert!((inc as u32) < HEAP_SIZE, "Allocation size too large");
                let inc_f = F::from_canonical_usize(inc);
                self.push(
                    AsmInstruction::AddFI(HEAP_PTR, HEAP_PTR, inc_f),
                    debug_info.clone(),
                );
                let (lo_bits, hi_bits) = lo_hi_bits(MEMORY_BITS as u32);
                self.push(
                    AsmInstruction::RangeCheck(HEAP_PTR, lo_bits, hi_bits),
                    debug_info,
                );
            }
            RVar::Val(len) => {
                self.push(
                    AsmInstruction::CopyF(ptr.fp(), HEAP_PTR),
                    debug_info.clone(),
                );
                let size: usize = align(size);
                // Avoid (len * size) overflowing
                {
                    let size_bit = usize::BITS - size.leading_zeros();
                    assert!(MEMORY_BITS as u32 > size_bit);
                    let (lo_bits, hi_bits) = lo_hi_bits(MEMORY_BITS as u32 - size_bit);
                    self.push(
                        AsmInstruction::RangeCheck(len.fp(), lo_bits, hi_bits),
                        debug_info.clone(),
                    );
                }
                let size_f = F::from_canonical_usize(size);
                self.push(
                    AsmInstruction::MulFI(A0, len.fp(), size_f),
                    debug_info.clone(),
                );
                self.push(
                    AsmInstruction::AddF(HEAP_PTR, HEAP_PTR, A0),
                    debug_info.clone(),
                );
                let (lo_bits, hi_bits) = lo_hi_bits(MEMORY_BITS as u32);
                // Avoid HEAP_PTR overflowing
                self.push(
                    AsmInstruction::RangeCheck(HEAP_PTR, lo_bits, hi_bits),
                    debug_info,
                );
            }
        }
    }

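    /// Emits a conditional branch to the trap block: with `is_eq == false` this asserts
    /// `lhs == rhs` (trapping on inequality), with `is_eq == true` it asserts `lhs != rhs`
    /// (trapping on equality).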
    pub fn assert(
        &mut self,
        lhs: i32,
        rhs: ValueOrConst<F, EF>,
        is_eq: bool,
        debug_info: Option<DebugInfo>,
    ) {
        let trap_label = self.trap_label;
        let if_compiler = IfCompiler {
            compiler: self,
            lhs,
            rhs,
            is_eq: !is_eq,
        };
        if_compiler.then_label(trap_label, debug_info);
    }

    pub fn code(self) -> AssemblyCode<F, EF> {
        let labels = self
            .function_labels
            .into_iter()
            .map(|(k, v)| (v, k))
            .collect();
        AssemblyCode::new(self.basic_blocks, labels)
    }

    fn basic_block(&mut self) {
        self.basic_blocks.push(BasicBlock::new());
    }

    fn block_label(&mut self) -> F {
        F::from_canonical_usize(self.basic_blocks.len() - 1)
    }

    fn push_to_block(
        &mut self,
        block_label: F,
        instruction: AsmInstruction<F, EF>,
        debug_info: Option<DebugInfo>,
    ) {
        self.basic_blocks
            .get_mut(block_label.as_canonical_u32() as usize)
            .unwrap_or_else(|| panic!("Missing block at label: {:?}", block_label))
            .push(instruction, debug_info);
    }

    fn push(&mut self, instruction: AsmInstruction<F, EF>, debug_info: Option<DebugInfo>) {
        self.basic_blocks
            .last_mut()
            .unwrap()
            .push(instruction, debug_info);
    }

    // mem[dst] <- mem[src] + c * mem[val]
    // Assumes dst != src: when c != 1, dst is overwritten by the multiply before src is added.
    fn add_scaled(&mut self, dst: i32, src: i32, val: i32, c: F, debug_info: Option<DebugInfo>) {
        if c == F::ONE {
            self.push(AsmInstruction::AddF(dst, src, val), debug_info);
        } else {
            self.push(AsmInstruction::MulFI(dst, val, c), debug_info.clone());
            self.push(AsmInstruction::AddF(dst, dst, src), debug_info);
        }
    }
}

pub struct IfCompiler<'a, F, EF> {
    compiler: &'a mut AsmCompiler<F, EF>,
    lhs: i32,
    rhs: ValueOrConst<F, EF>,
    is_eq: bool,
}

impl<F: PrimeField32 + TwoAdicField, EF: ExtensionField<F> + TwoAdicField> IfCompiler<'_, F, EF> {
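    /// Emits the `then` branch only: the current block ends with a conditional branch over the
    /// newly generated then-blocks (taken when the condition does not hold), landing on the
    /// block where the main flow resumes.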
    pub fn then<Func>(self, f: Func, debug_info: Option<DebugInfo>)
    where
        Func: FnOnce(&mut AsmCompiler<F, EF>),
    {
        let Self {
            compiler,
            lhs,
            rhs,
            is_eq,
        } = self;

        // Get the label for the current block.
        let current_block = compiler.block_label();

        // Generate the blocks for the then branch.
        compiler.basic_block();
        f(compiler);

        // Generate the block for returning to the main flow.
        compiler.basic_block();
        let after_if_block = compiler.block_label();

        // Get the branch instruction to push to the `current_block`.
        let instr = Self::branch(lhs, rhs, is_eq, after_if_block);
        compiler.push_to_block(current_block, instr, debug_info);
    }

    pub fn then_label(self, label: F, debug_info: Option<DebugInfo>) {
        let Self {
            compiler,
            lhs,
            rhs,
            is_eq,
        } = self;

        // Get the label for the current block.
        let current_block = compiler.block_label();

        // Get the branch instruction to push to the `current_block`.
        let instr = Self::branch(lhs, rhs, is_eq, label);
        compiler.push_to_block(current_block, instr, debug_info);
    }

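    /// Emits both branches: the current block ends with a conditional branch to the else-blocks
    /// (taken when the condition does not hold); the then-blocks fall through into a jump over
    /// the else-blocks to the block where the main flow resumes.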
    pub fn then_or_else<ThenFunc, ElseFunc>(
        self,
        then_f: ThenFunc,
        else_f: ElseFunc,
        debug_info: Option<DebugInfo>,
    ) where
        ThenFunc: FnOnce(&mut AsmCompiler<F, EF>),
        ElseFunc: FnOnce(&mut AsmCompiler<F, EF>),
    {
        let Self {
            compiler,
            lhs,
            rhs,
            is_eq,
        } = self;

        // Remember the label of the current block; it will receive the conditional branch to
        // the else block, taken if the condition is not met.
        let if_branching_block = compiler.block_label();

        // Generate the block for the then branch.
        compiler.basic_block();
        then_f(compiler);
        let last_if_block = compiler.block_label();

        // Generate the block for the else branch.
        compiler.basic_block();
        let else_block = compiler.block_label();
        else_f(compiler);

        // Emit the conditional branch to the else block from the branching block.
        let instr = Self::branch(lhs, rhs, is_eq, else_block);
        compiler.push_to_block(if_branching_block, instr, debug_info.clone());

        // Generate the block for returning to the main flow, and make the then branch jump to it.
        compiler.basic_block();
        let main_flow_block = compiler.block_label();
        let instr = AsmInstruction::j(main_flow_block);
        compiler.push_to_block(last_if_block, instr, debug_info.clone());
    }

    const fn branch(
        lhs: i32,
        rhs: ValueOrConst<F, EF>,
        is_eq: bool,
        block: F,
    ) -> AsmInstruction<F, EF> {
        match (rhs, is_eq) {
            (ValueOrConst::Const(rhs), true) => AsmInstruction::BneI(block, lhs, rhs),
            (ValueOrConst::Const(rhs), false) => AsmInstruction::BeqI(block, lhs, rhs),
            (ValueOrConst::ExtConst(rhs), true) => AsmInstruction::BneEI(block, lhs, rhs),
            (ValueOrConst::ExtConst(rhs), false) => AsmInstruction::BeqEI(block, lhs, rhs),
            (ValueOrConst::Val(rhs), true) => AsmInstruction::Bne(block, lhs, rhs),
            (ValueOrConst::Val(rhs), false) => AsmInstruction::Beq(block, lhs, rhs),
            (ValueOrConst::ExtVal(rhs), true) => AsmInstruction::BneE(block, lhs, rhs),
            (ValueOrConst::ExtVal(rhs), false) => AsmInstruction::BeqE(block, lhs, rhs),
        }
    }
}

/// A zipped for loop: iteration is driven by the first range, `starts[0]..end0`; the remaining
/// loop variables are advanced in lockstep by their own step sizes.
///
/// ATTENTION: starting with `starts[0] > end0` leads to undefined behavior.
pub struct ZipForCompiler<'a, F: Field, EF> {
    compiler: &'a mut AsmCompiler<F, EF>,
    starts: Vec<RVar<F>>,
    end0: RVar<F>,
    step_sizes: Vec<F>,
    loop_vars: Vec<Var<F>>,
}

impl<F: PrimeField32 + TwoAdicField, EF: ExtensionField<F> + TwoAdicField>
    ZipForCompiler<'_, F, EF>
{
    /// This assumes that the number of steps in `range(starts[0], end0, step_sizes[0])` is
    /// minimal among all ranges.
    ///
    /// It is the responsibility of the caller to ensure that this precondition holds.
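    ///
    /// Emitted structure (sketch): the loop variables are initialized in the current block,
    /// which jumps to a condition block; the condition block branches back to the loop body
    /// while `loop_vars[0] != end0`, and execution falls through to a fresh block once they
    /// are equal. The body blocks end by incrementing each loop variable by its step size.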
    pub(super) fn for_each(
        self,
        f: impl FnOnce(Vec<Var<F>>, &mut AsmCompiler<F, EF>),
        debug_info: Option<DebugInfo>,
    ) {
        // initialize the loop variables
        self.starts
            .iter()
            .zip(self.loop_vars.iter())
            .for_each(|(start, loop_var)| match start {
                RVar::Const(start) => {
                    self.compiler.push(
                        AsmInstruction::ImmF(loop_var.fp(), *start),
                        debug_info.clone(),
                    );
                }
                RVar::Val(start) => {
                    self.compiler.push(
                        AsmInstruction::CopyF(loop_var.fp(), start.fp()),
                        debug_info.clone(),
                    );
                }
            });

        let loop_call_label = self.compiler.block_label();

        self.compiler.basic_block();
        let loop_label = self.compiler.block_label();

        f(self.loop_vars.clone(), self.compiler);

        self.loop_vars
            .iter()
            .zip(self.step_sizes.iter())
            .for_each(|(loop_var, step_size)| {
                self.compiler.push(
                    AsmInstruction::AddFI(loop_var.fp(), loop_var.fp(), *step_size),
                    debug_info.clone(),
                );
            });

        self.compiler.basic_block();
        let end = self.end0;
        let loop_var = self.loop_vars[0];
        match end {
            RVar::Const(end) => {
                self.compiler.push(
                    AsmInstruction::BneI(loop_label, loop_var.fp(), end),
                    debug_info.clone(),
                );
            }
            RVar::Val(end) => {
                self.compiler.push(
                    AsmInstruction::Bne(loop_label, loop_var.fp(), end.fp()),
                    debug_info.clone(),
                );
            }
        };

        let label = self.compiler.block_label();
        let instr = AsmInstruction::j(label);
        self.compiler
            .push_to_block(loop_call_label, instr, debug_info.clone());

        self.compiler.basic_block();
    }
}

// Ext compiler logic.
impl<F: PrimeField32 + TwoAdicField, EF: ExtensionField<F> + TwoAdicField> AsmCompiler<F, EF> {
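    /// Writes the immediate extension element `imm` coordinate by coordinate into the `EF::D`
    /// consecutive cells starting at `dst`; the other helpers below follow the same
    /// one-cell-per-base-coordinate layout.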
    fn assign_exti(&mut self, dst: i32, imm: EF, debug_info: Option<DebugInfo>) {
        let imm = imm.as_base_slice();
        for i in 0..EF::D {
            self.push(
                AsmInstruction::ImmF(dst + i as i32, imm[i]),
                debug_info.clone(),
            );
        }
    }

    fn load_ext(&mut self, val: Ext<F, EF>, addr: i32, offset: F, debug_info: Option<DebugInfo>) {
        self.push(
            AsmInstruction::LoadEI(val.fp(), addr, F::ZERO, F::ONE, offset),
            debug_info.clone(),
        );
    }

    fn store_ext(&mut self, val: Ext<F, EF>, addr: i32, offset: F, debug_info: Option<DebugInfo>) {
        self.push(
            AsmInstruction::StoreEI(val.fp(), addr, F::ZERO, F::ONE, offset),
            debug_info.clone(),
        );
    }

    fn add_ext_exti(
        &mut self,
        dst: Ext<F, EF>,
        lhs: Ext<F, EF>,
        rhs: EF,
        debug_info: Option<DebugInfo>,
    ) {
        let rhs = rhs.as_base_slice();
        for i in 0..EF::D {
            let j = i as i32;
            self.push(
                AsmInstruction::AddFI(dst.fp() + j, lhs.fp() + j, rhs[i]),
                debug_info.clone(),
            );
        }
    }

    fn sub_exti_ext(
        &mut self,
        dst: Ext<F, EF>,
        lhs: EF,
        rhs: Ext<F, EF>,
        debug_info: Option<DebugInfo>,
    ) {
        let lhs = lhs.as_base_slice();
        for i in 0..EF::D {
            let j = i as i32;
            self.push(
                AsmInstruction::SubFIN(dst.fp() + j, lhs[i], rhs.fp() + j),
                debug_info.clone(),
            );
        }
    }

    fn add_ext_felt(
        &mut self,
        dst: Ext<F, EF>,
        lhs: Ext<F, EF>,
        rhs: Felt<F>,
        debug_info: Option<DebugInfo>,
    ) {
        self.push(
            AsmInstruction::AddF(dst.fp(), lhs.fp(), rhs.fp()),
            debug_info.clone(),
        );
        for i in 1..EF::D {
            let j = i as i32;
            self.push(
                AsmInstruction::CopyF(dst.fp() + j, lhs.fp() + j),
                debug_info.clone(),
            );
        }
    }

    fn sub_ext_felt(
        &mut self,
        dst: Ext<F, EF>,
        lhs: Ext<F, EF>,
        rhs: Felt<F>,
        debug_info: Option<DebugInfo>,
    ) {
        self.push(
            AsmInstruction::SubF(dst.fp(), lhs.fp(), rhs.fp()),
            debug_info.clone(),
        );
        for i in 1..EF::D {
            let j = i as i32;
            self.push(
                AsmInstruction::CopyF(dst.fp() + j, lhs.fp() + j),
                debug_info.clone(),
            );
        }
    }

    fn add_felt_exti(
        &mut self,
        dst: Ext<F, EF>,
        lhs: Felt<F>,
        rhs: EF,
        debug_info: Option<DebugInfo>,
    ) {
        let rhs = rhs.as_base_slice();

        self.push(
            AsmInstruction::CopyF(dst.fp(), lhs.fp()),
            debug_info.clone(),
        );

        for i in 1..EF::D {
            let j = i as i32;
            self.push(
                AsmInstruction::ImmF(dst.fp() + j, rhs[i]),
                debug_info.clone(),
            );
        }
    }

    fn mul_ext_felt(
        &mut self,
        dst: Ext<F, EF>,
        lhs: Ext<F, EF>,
        rhs: Felt<F>,
        debug_info: Option<DebugInfo>,
    ) {
        for i in 0..EF::D {
            let j = i as i32;
            self.push(
                AsmInstruction::MulF(dst.fp() + j, lhs.fp() + j, rhs.fp()),
                debug_info.clone(),
            );
        }
    }

    fn mul_ext_felti(
        &mut self,
        dst: Ext<F, EF>,
        lhs: Ext<F, EF>,
        rhs: F,
        debug_info: Option<DebugInfo>,
    ) {
        for i in 0..EF::D {
            let j = i as i32;
            self.push(
                AsmInstruction::MulFI(dst.fp() + j, lhs.fp() + j, rhs),
                debug_info.clone(),
            );
        }
    }

    fn div_ext_felt(
        &mut self,
        dst: Ext<F, EF>,
        lhs: Ext<F, EF>,
        rhs: Felt<F>,
        debug_info: Option<DebugInfo>,
    ) {
        for i in 0..EF::D {
            let j = i as i32;
            self.push(
                AsmInstruction::DivF(dst.fp() + j, lhs.fp() + j, rhs.fp()),
                debug_info.clone(),
            );
        }
    }
}

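/// Splits a bit width into `(lo, hi)` limb widths with `lo <= 16` and `lo + hi == bits`, used
/// as the limb-width arguments of `AsmInstruction::RangeCheck`.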
fn lo_hi_bits(bits: u32) -> (i32, i32) {
    let lo_bits = bits.min(16);
    let hi_bits = bits.max(16) - 16;
    (lo_bits as i32, hi_bits as i32)
}
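
// A minimal sanity-check sketch for `lo_hi_bits`; the test module and the cases below are
// illustrative additions, not part of the existing test suite.
#[cfg(test)]
mod lo_hi_bits_tests {
    use super::{lo_hi_bits, MEMORY_BITS};

    #[test]
    fn splits_bit_widths_into_low_and_high_limbs() {
        // Widths of at most 16 bits fit entirely in the low limb.
        assert_eq!(lo_hi_bits(12), (12, 0));
        assert_eq!(lo_hi_bits(16), (16, 0));
        // Wider widths keep 16 bits in the low limb and put the rest in the high limb.
        assert_eq!(lo_hi_bits(MEMORY_BITS as u32), (16, 13));
    }
}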