wahgex_core/compile/
transition.rs

//! This module contains types and functions related to the NFA transition
//! function.
3
4use std::{
5    alloc::{Layout, LayoutError},
6    collections::HashMap,
7    mem,
8};
9
10use regex_automata::{
11    nfa::thompson::{DenseTransitions, SparseTransitions, State, Transition, NFA},
12    util::primitives::StateID,
13};
14use wasm_encoder::{BlockType, InstructionSink, MemArg, NameMap, ValType};
15
16use crate::util::repeat;
17
18use super::{
19    context::{
20        ActiveDataSegment, BlockSignature, Function, FunctionDefinition, FunctionIdx,
21        FunctionSignature, TypeIdx,
22    },
23    epsilon_closure::EpsilonClosureFunctions,
24    instructions::InstructionSinkExt,
25    CompileContext,
26};
27
/// Layout of one entry in a sparse range lookup table: a `(start, end)` pair
/// of bytes stored back to back.
const SPARSE_RANGE_LOOKUP_TABLE_ELEM: Layout = Layout::new::<[u8; 2]>();
34
/// This struct contains the layout for the lookup tables used by the
/// [`TransitionFunctions`].
///
/// Each transition function may have a differently shaped lookup table.
///  - For [`Transition`]s, there is no lookup table and we just embed the
///    `start` and `end` directly into the function.
///  - For [`SparseTransitions`], the table is represented as 2 arrays. The
///    first array contains the (`start`, `end`) tuples from the
///    [`Transition`]s. The second array contains the [`StateID`]s arranged to
///    match the same order as the tuples.
///  - For [`DenseTransitions`], it will be an array of length 256 containing
///    [`StateID`]s.
#[derive(Debug)]
pub struct TransitionLayout {
    /// Lookup table description for every state that has one, keyed by the
    /// state's ID. States without a lookup table have no entry.
    lookup_tables: HashMap<StateID, LookupTable>,
}
51
/// This enum represents the different types of lookup tables and their
/// offsets.
///
/// See [`TransitionLayout`] for more details.
#[derive(Debug, Clone, Copy)]
enum LookupTable {
    /// Tables for a sparse state: a range table plus a matching state ID
    /// table.
    Sparse(SparseTable),
    /// Table for a dense state: a single array of state IDs indexed by byte.
    Dense(DenseTable),
}
60
61impl LookupTable {
62    fn unwrap_sparse(self) -> SparseTable {
63        match self {
64            LookupTable::Sparse(table) => table,
65            _ => panic!("not a sparse table offset"),
66        }
67    }
68
69    fn unwrap_dense(self) -> DenseTable {
70        match self {
71            LookupTable::Dense(table) => table,
72            _ => panic!("not a dense table offset"),
73        }
74    }
75}
76
/// Placement of the two lookup tables emitted for a sparse state.
#[derive(Debug, Clone, Copy)]
struct SparseTable {
    /// Offset of the `(start, end)` range table, as returned when the table
    /// was appended to the overall layout.
    range_table_pos: usize,
    /// Number of entries in the range table (one per sparse transition).
    range_table_len: usize,
    /// Stride between consecutive entries of the range table; the generated
    /// code multiplies the entry index by this value.
    range_lookup_table_stride: usize,
    /// Offset of the state ID table, as returned when the table was appended
    /// to the overall layout.
    state_id_table_pos: usize,
    /// Stride between consecutive entries of the state ID table.
    state_id_table_stride: usize,
}
85
/// Placement of the 256-entry state ID table emitted for a dense state.
#[derive(Debug, Clone, Copy)]
struct DenseTable {
    /// Offset of the table, as returned when the table was appended to the
    /// overall layout.
    table_pos: usize,
    /// Stride between consecutive state IDs in the table; the generated code
    /// multiplies the input byte by this value.
    table_stride: usize,
}
91
92impl TransitionLayout {
93    /// Creates a new `TransitionLayout` by calculating the memory offsets for
94    /// transition lookup tables.
95    pub fn new(
96        ctx: &mut CompileContext,
97        mut overall: Layout,
98    ) -> Result<(Layout, Self), LayoutError> {
99        let mut lookup_table_offsets = HashMap::new();
100
101        let states = ctx.nfa.states();
102        let state_id_layout = *ctx.state_id_layout();
103        for for_sid in (0..states.len()).map(StateID::new).map(Result::unwrap) {
104            let state = &states[for_sid.as_usize()];
105            match state {
106                State::ByteRange { .. } => {
107                    // no lookup table
108                },
109                State::Sparse(SparseTransitions { transitions }) => {
110                    let (range_data, state_data) =
111                        flatten_sparse_transition(transitions, &state_id_layout);
112
113                    // (start, end) tuples arranged together
114                    let (range_lookup_table, range_lookup_table_stride) =
115                        repeat(&SPARSE_RANGE_LOOKUP_TABLE_ELEM, transitions.len())?;
116                    let (new_overall, range_table_pos) = overall.extend(range_lookup_table)?;
117                    overall = new_overall;
118
119                    ctx.sections.add_active_data_segment(ActiveDataSegment {
120                        name: format!("sparse_range_table_{}", for_sid.as_u32()),
121                        position: range_table_pos,
122                        data: range_data,
123                    });
124
125                    // state IDs all packed together
126                    let (state_id_table, state_id_table_stride) =
127                        repeat(&state_id_layout, transitions.len())?;
128                    let (new_overall, state_id_table_pos) = overall.extend(state_id_table)?;
129                    overall = new_overall;
130
131                    ctx.sections.add_active_data_segment(ActiveDataSegment {
132                        name: format!("sparse_state_id_table_{}", for_sid.as_u32()),
133                        position: state_id_table_pos,
134                        data: state_data,
135                    });
136
137                    lookup_table_offsets.insert(
138                        for_sid,
139                        LookupTable::Sparse(SparseTable {
140                            range_table_pos,
141                            range_lookup_table_stride,
142                            range_table_len: transitions.len(),
143                            state_id_table_pos,
144                            state_id_table_stride,
145                        }),
146                    );
147                },
148                State::Dense(DenseTransitions { transitions }) => {
149                    let (lookup_table_layout, table_stride) = repeat(&state_id_layout, 256)?;
150                    let (new_overall, table_pos) = overall.extend(lookup_table_layout)?;
151                    overall = new_overall;
152                    lookup_table_offsets.insert(
153                        for_sid,
154                        LookupTable::Dense(DenseTable {
155                            table_pos,
156                            table_stride,
157                        }),
158                    );
159
160                    ctx.sections.add_active_data_segment(ActiveDataSegment {
161                        name: format!("dense_table_{}", for_sid.as_u32()),
162                        position: table_pos,
163                        data: flatten_dense_transition(transitions, &state_id_layout),
164                    });
165                },
166                _ => {
167                    // no lookup table
168                },
169            }
170        }
171
172        Ok((
173            overall,
174            Self {
175                lookup_tables: lookup_table_offsets,
176            },
177        ))
178    }
179
180    fn get(&self, sid: StateID) -> Option<LookupTable> {
181        self.lookup_tables.get(&sid).copied()
182    }
183}
184
185fn flatten_sparse_transition(
186    sparse: &[Transition],
187    state_id_layout: &Layout,
188) -> (Vec<u8>, Vec<u8>) {
189    let mut range_output = Vec::with_capacity(mem::size_of::<u8>() * sparse.len() * 2);
190
191    for transition in sparse {
192        range_output.push(transition.start);
193        range_output.push(transition.end);
194    }
195
196    let mut state_output = Vec::with_capacity(state_id_layout.size() * sparse.len());
197
198    for transition in sparse {
199        // WASM assumes little endian byte ordering: https://webassembly.org/docs/portability/
200        let bytes = transition.next.as_u32().to_le_bytes();
201        state_output.extend_from_slice(&bytes[..state_id_layout.size()]);
202    }
203
204    (range_output, state_output)
205}
206
207fn flatten_dense_transition(dense: &[StateID], state_id_layout: &Layout) -> Vec<u8> {
208    assert_eq!(dense.len(), 256);
209
210    let mut output = Vec::with_capacity(state_id_layout.size() * dense.len());
211
212    for state_id in dense {
213        // WASM assumes little endian byte ordering: https://webassembly.org/docs/portability/
214        output.extend_from_slice(&state_id.as_u32().to_le_bytes());
215    }
216
217    output
218}
219
/// This struct contains a map of functions that are the transitions
/// for each NFA state.
///
/// This corresponds to the `next` function in
/// [`PikeVM`][regex_automata::nfa::thompson::pikevm::PikeVM].
#[derive(Debug)]
pub struct TransitionFunctions {
    /// Per-state transition function, for every state that needs one.
    #[expect(dead_code)]
    state_transitions: HashMap<StateID, FunctionIdx>,
    /// Dispatcher that calls the transition function matching a state ID.
    #[expect(dead_code)]
    branch_to_transition: FunctionIdx,
    /// Function that runs the transitions for every state in the current set.
    pub make_current_transitions: FunctionIdx,
}
233
234impl TransitionFunctions {
235    /// Creates and registers all WebAssembly functions required for handling
236    /// NFA state transitions.
237    pub fn new(
238        ctx: &mut CompileContext,
239        epsilon_closures: &EpsilonClosureFunctions,
240        transition_layout: &TransitionLayout,
241    ) -> Self {
242        // NOTE: The indexes of the `states` array correspond to the `StateID` value.
243        let mut state_transitions = HashMap::new();
244
245        let states = ctx.nfa.states();
246        for for_sid in (0..states.len()).map(StateID::new).map(Result::unwrap) {
247            if !Self::needs_transition_fn(&ctx.nfa, for_sid) {
248                continue;
249            }
250
251            let transition_fn = Self::transition_fn(
252                for_sid,
253                ctx.nfa.states(),
254                epsilon_closures.branch_to_epsilon_closure,
255                transition_layout.get(for_sid),
256                ctx.state_id_layout(),
257            );
258            let transition_idx = ctx.add_function(transition_fn);
259            state_transitions.insert(for_sid, transition_idx);
260        }
261
262        let branch_to_transition =
263            ctx.add_function(Self::branch_to_transition_fn(&state_transitions));
264
265        let branch_to_transition_is_match_block_sig = ctx.add_block_signature(BlockSignature {
266            name: "branch_to_transition_is_match",
267            params_ty: &[ValType::I32],
268            results_ty: &[],
269        });
270
271        let make_current_transitions = ctx.add_function(Self::make_current_transitions_fn(
272            branch_to_transition,
273            branch_to_transition_is_match_block_sig,
274            ctx.state_id_layout(),
275        ));
276
277        Self {
278            state_transitions,
279            branch_to_transition,
280            make_current_transitions,
281        }
282    }
283
    /// Builds the `make_current_transitions` WASM function.
    ///
    /// The generated function iterates over the first `current_set_len` state
    /// IDs stored at `current_set_ptr` and calls `branch_to_transition` for
    /// each one, threading the running length of the next set through every
    /// call. It returns `[new_next_set_len, is_match]`, and returns early with
    /// `is_match` set as soon as one call reports a match.
    ///
    /// - `branch_to_transition`: index of the dispatcher function to call for
    ///   every state ID.
    /// - `branch_to_transition_is_match_block_sig`: block type used for the
    ///   `if is_match` block, which receives `new_next_set_len` as a block
    ///   parameter.
    /// - `state_id_layout`: layout of a single state ID; its alignment is used
    ///   as the element stride and it is forwarded to `state_id_load`.
    fn make_current_transitions_fn(
        branch_to_transition: FunctionIdx,
        branch_to_transition_is_match_block_sig: TypeIdx,
        state_id_layout: &Layout,
    ) -> Function {
        let mut locals_name_map = NameMap::new();
        // Parameters
        locals_name_map.append(0, "haystack_ptr");
        locals_name_map.append(1, "haystack_len");
        locals_name_map.append(2, "at_offset");
        locals_name_map.append(3, "current_set_ptr");
        locals_name_map.append(4, "current_set_len");
        locals_name_map.append(5, "next_set_ptr");
        locals_name_map.append(6, "next_set_len");
        // Locals
        locals_name_map.append(7, "loop_index");
        locals_name_map.append(8, "state_id");
        locals_name_map.append(9, "new_next_set_len");

        let mut labels_name_map = NameMap::new();
        labels_name_map.append(0, "set_iter_loop");

        // Three extra i32 locals: loop_index (7), state_id (8) and
        // new_next_set_len (9).
        let mut body = wasm_encoder::Function::new([(3, ValType::I32)]);
        let mut instructions = body.instructions();

        // loop_index = 0 // local 7
        // new_next_set_len = next_set_len // local 9
        // loop {
        //     if loop_index >= current_set_len {
        //         return (new_next_set_len, false);
        //     }
        //
        //     state_id = current_set_ptr.dense[loop_index]; // local 8
        //     is_match, new_next_set_len = branch_to_transition(
        //         haystack_ptr,
        //         haystack_len,
        //         at_offset,
        //         next_set_ptr,
        //         new_next_set_len,
        //         state_id
        //     );
        //     if is_match {
        //         return (new_next_set_len, true);
        //     }
        //
        //     loop_index = loop_index + 1;
        // }
        // return (new_next_set_len, false); // just in case

        instructions
            // loop_index = 0 // local 7
            .i32_const(0)
            .local_set(7)
            // new_next_set_len = next_set_len // local 9
            .local_get(6)
            .local_set(9)
            .loop_(BlockType::Empty)
            // if loop_index >= current_set_len {
            .local_get(7)
            .local_get(4)
            .i32_ge_u()
            .if_(BlockType::Empty)
            // return (new_next_set_len, false);
            .local_get(9)
            .i32_const(false as i32)
            .return_()
            .end()
            // state_id = current_set_ptr.dense[loop_index]; // local 8
            // address = loop_index * stride + current_set_ptr
            // NOTE(review): the stride used here is the state ID alignment;
            // presumably it equals the state ID size - confirm.
            .local_get(7)
            .i64_extend_i32_u()
            .i64_const(i64::from_ne_bytes(
                u64::try_from(state_id_layout.align())
                    .unwrap()
                    .to_ne_bytes(),
            ))
            .i64_mul()
            .local_get(3)
            .i64_add()
            .state_id_load(0, state_id_layout)
            .local_set(8)
            // is_match, new_next_set_len = branch_to_transition(..)
            .local_get(0)
            .local_get(1)
            .local_get(2)
            .local_get(5)
            .local_get(9)
            .local_get(8)
            .call(branch_to_transition.into())
            // if is_match {
            // The call leaves [new_next_set_len, is_match] on the stack; `if`
            // consumes is_match and the block signature passes
            // new_next_set_len into whichever branch runs.
            .if_(BlockType::FunctionType(
                branch_to_transition_is_match_block_sig.into(),
            ))
            // return (new_next_set_len, true);
            .i32_const(true as i32)
            .return_()
            .else_()
            .local_set(9) // need to update new_next_set_len on non-match
            .end()
            // loop_index = loop_index + 1;
            .local_get(7)
            .i32_const(1)
            .i32_add()
            .local_set(7)
            .br(0) // continue loop
            .end() // end loop
            // return (new_next_set_len, false);
            // The loop only exits through the `return`s above; these values
            // give the function body its declared result types.
            .local_get(9)
            .i32_const(false as i32)
            .end();

        Function {
            sig: FunctionSignature {
                name: "make_current_transitions".into(),
                // [haystack_ptr, haystack_len, at_offset, current_set_ptr, current_set_len,
                // next_set_ptr, next_set_len]
                params_ty: &[
                    ValType::I64,
                    ValType::I64,
                    ValType::I64,
                    ValType::I64,
                    ValType::I32,
                    ValType::I64,
                    ValType::I32,
                ],
                // current_set is not modified by this function, so we don't return a new length
                // [new_next_set_len, is_match]
                results_ty: &[ValType::I32, ValType::I32],
                export: false,
            },
            def: FunctionDefinition {
                body,
                locals_name_map,
                labels_name_map: Some(labels_name_map),
                branch_hints: None,
            },
        }
    }
421
422    fn branch_to_transition_fn(state_transitions: &HashMap<StateID, FunctionIdx>) -> Function {
423        let mut locals_name_map = NameMap::new();
424        // Parameters
425        locals_name_map.append(0, "haystack_ptr");
426        locals_name_map.append(1, "haystack_len");
427        locals_name_map.append(2, "at_offset");
428        locals_name_map.append(3, "next_set_ptr");
429        locals_name_map.append(4, "next_set_len");
430        locals_name_map.append(5, "state_id");
431
432        let mut body = wasm_encoder::Function::new([]);
433        let mut instructions = body.instructions();
434
435        let mut states = state_transitions.keys().copied().collect::<Vec<_>>();
436        states.sort();
437
438        for sid in states {
439            let transition_fn = state_transitions.get(&sid).copied().unwrap();
440            instructions
441                .local_get(5) // state_id
442                .i32_const(i32::from_ne_bytes(sid.as_u32().to_ne_bytes()))
443                .i32_eq()
444                .if_(BlockType::Empty)
445                .local_get(0) // haystack_ptr
446                .local_get(1) // haystack_len
447                .local_get(2) // at_offset
448                .local_get(3) // next_set_ptr
449                .local_get(4) // next_set_len
450                .call(transition_fn.into())
451                .return_()
452                .end();
453        }
454
455        // If it falls through to this point, then we must assume thats its a state
456        // which has no transition. In which case, we need to return the current
457        // `next_set_len` and `false` for is_match.
458        instructions.local_get(4).i32_const(false as i32).end();
459
460        Function {
461            sig: FunctionSignature {
462                name: "branch_to_transition".into(),
463                // [haystack_ptr, haystack_len, at_offset, next_set_ptr, next_set_len, state_id]
464                params_ty: &[
465                    // TODO(opt): Remove haystack_ptr and assume that haystack always starts at
466                    // offset 0 in memory 0
467                    ValType::I64,
468                    ValType::I64,
469                    ValType::I64,
470                    ValType::I64,
471                    ValType::I32,
472                    ValType::I32,
473                ],
474                // [new_next_set_len, is_match]
475                results_ty: &[ValType::I32, ValType::I32],
476                export: false,
477            },
478            def: FunctionDefinition {
479                body,
480                locals_name_map,
481                labels_name_map: None,
482                branch_hints: None,
483            },
484        }
485    }
486
487    fn transition_fn(
488        for_sid: StateID,
489        states: &[State],
490        branch_to_epsilon_closure: FunctionIdx,
491        lookup_table: Option<LookupTable>,
492        state_id_layout: &Layout,
493    ) -> Function {
494        let mut locals_name_map = NameMap::new();
495        // Parameters
496        locals_name_map.append(0, "haystack_ptr");
497        locals_name_map.append(1, "haystack_len");
498        locals_name_map.append(2, "at_offset");
499        locals_name_map.append(3, "next_set_ptr");
500        locals_name_map.append(4, "next_set_len");
501        // Locals added in transition_fn_need_locals
502
503        fn transition_fn_need_locals(
504            for_sid: StateID,
505            state: &State,
506            locals_name_map: &mut NameMap,
507        ) -> Vec<(u32, ValType)> {
508            match state {
509                State::Fail
510                | State::Look { .. }
511                | State::Union { .. }
512                | State::BinaryUnion { .. }
513                | State::Capture { .. } => {
514                    // return None
515                    unreachable!(
516                        "We should never generate transitions for state [{for_sid:?}] since \
517                         they're excluded by `needs_transition_fn`."
518                    );
519                },
520                State::ByteRange { .. } | State::Sparse { .. } | State::Dense { .. } => {
521                    let mut num_i32s = 2;
522
523                    locals_name_map.append(5, "byte");
524                    locals_name_map.append(6, "next_state");
525
526                    if matches!(state, State::Sparse { .. } | State::Dense { .. }) {
527                        num_i32s += 1;
528                        locals_name_map.append(7, "loop_index");
529
530                        if matches!(state, State::Sparse { .. }) {
531                            num_i32s += 2;
532                            locals_name_map.append(8, "transition_start");
533                            locals_name_map.append(9, "transition_end");
534                        }
535                    }
536
537                    vec![(num_i32s, ValType::I32)]
538                },
539                State::Match { .. } => vec![],
540            }
541        }
542
543        let mut labels_name_map = NameMap::new();
544
545        let mut body = wasm_encoder::Function::new(transition_fn_need_locals(
546            for_sid,
547            &states[for_sid.as_usize()],
548            &mut locals_name_map,
549        ));
550        let mut instructions = body.instructions();
551        match &states[for_sid.as_usize()] {
552            State::Fail
553            | State::Look { .. }
554            | State::Union { .. }
555            | State::BinaryUnion { .. }
556            | State::Capture { .. } => {
557                // return None
558                unreachable!(
559                    "We should never generate transitions for state [{for_sid:?}] since they're \
560                     excluded by `needs_transition_fn`."
561                );
562            },
563            State::ByteRange { trans } => {
564                Self::non_terminal_transition_prefix(&mut instructions);
565                Self::byte_range_transition_body(&mut instructions, trans);
566                Self::non_terminal_transition_suffix(&mut instructions, branch_to_epsilon_closure);
567            },
568            State::Sparse(_) => {
569                // We don't need the transition data here, since we've already emitted the
570                // lookup tables
571                let sparse_table = lookup_table.unwrap().unwrap_sparse();
572                Self::non_terminal_transition_prefix(&mut instructions);
573                Self::sparse_transition_body(
574                    &mut instructions,
575                    sparse_table,
576                    &mut labels_name_map,
577                    state_id_layout,
578                );
579                Self::non_terminal_transition_suffix(&mut instructions, branch_to_epsilon_closure);
580            },
581            State::Dense(_) => {
582                // We don't need the transition data here, since we've already emitted the
583                // lookup tables
584                let dense_table = lookup_table.unwrap().unwrap_dense();
585                Self::non_terminal_transition_prefix(&mut instructions);
586                Self::dense_transition_body(&mut instructions, dense_table, state_id_layout);
587                Self::non_terminal_transition_suffix(&mut instructions, branch_to_epsilon_closure);
588            },
589            State::Match { .. } => {
590                // TODO: Need to update for pattern matches
591                // return Some(...)
592                instructions.local_get(4).i32_const(true as i32);
593            },
594        }
595        instructions.end();
596
597        Function {
598            sig: FunctionSignature {
599                name: format!("transition_s{}", for_sid.as_usize()),
600                // [haystack_ptr, haystack_len, at_offset, next_set_ptr, next_set_len]
601                params_ty: &[
602                    ValType::I64,
603                    ValType::I64,
604                    ValType::I64,
605                    ValType::I64,
606                    ValType::I32,
607                ],
608                // [new_next_set_len, is_match]
609                results_ty: &[ValType::I32, ValType::I32],
610                export: false,
611            },
612            def: FunctionDefinition {
613                body,
614                locals_name_map,
615                labels_name_map: Some(labels_name_map),
616                branch_hints: None,
617            },
618        }
619    }
620
621    /// Return true if the given state needs a transition function.
622    fn needs_transition_fn(nfa: &NFA, sid: StateID) -> bool {
623        matches!(
624            nfa.state(sid),
625            State::ByteRange { .. }
626                | State::Sparse { .. }
627                | State::Dense { .. }
628                | State::Match { .. }
629        )
630    }
631
632    fn non_terminal_transition_prefix(instructions: &mut InstructionSink<'_>) {
633        instructions // check haystack length load haystack byte
634            // if at_offset >= haystack_len
635            .local_get(2) // at_offset
636            .local_get(1) // haystack_len
637            .i64_ge_u()
638            .if_(BlockType::Empty)
639            // return None
640            .local_get(4) // next_set_len
641            .i32_const(false as i32)
642            .return_()
643            .end() // end if at_offset >= haystack_len
644            // TODO(opt): We can make haystack_ptr a constant if we thread through the input layout
645            // and expose it as a field instead. Since we always expect to place the input haystack
646            // in the same spot, while the haystack_len may vary
647            .local_get(0) // haystack_ptr
648            .local_get(2) // at_offset
649            .i64_add()
650            .i32_load8_u(MemArg {
651                offset: 0,
652                align: 0,
653                memory_index: 0,
654            })
655            .local_set(5); // byte
656    }
657
658    fn non_terminal_transition_suffix(
659        instructions: &mut InstructionSink<'_>,
660        branch_to_epsilon_closure: FunctionIdx,
661    ) {
662        instructions
663            // let at = at.wrapping_add(1);
664            .local_get(2) // at_offset
665            .i64_const(1)
666            .i64_add()
667            .local_set(2) // at_offset
668            // self.epsilon_closure(
669            //     stack, slots, next, input, at, trans.next,
670            // );
671            .local_get(0) // haystack_ptr
672            .local_get(1) // haystack_len
673            .local_get(2) // at_offset
674            .local_get(3) // next_set_ptr
675            .local_get(4) // next_set_len
676            .local_get(6) // next_state
677            // TODO(opt): Instead of calling indirectly to find get the right epsilon closure
678            // function, we could pipe through what the expected next state is and branch directly
679            // to that epsilon closure function if present (or add next state to the set if not
680            // present)
681            .call(branch_to_epsilon_closure.into()) // returns new_next_set_len
682            // return None
683            .i32_const(false as i32);
684    }
685
    /// Emits the code that picks the next state for a sparse state.
    ///
    /// The emitted code scans the range table entry by entry, comparing the
    /// `byte` local (5) against each `(start, end)` pair. On a hit it stores
    /// the state ID at the same index of the state ID table into the
    /// `next_state` local (6) and falls through to whatever is emitted next.
    /// If the table is exhausted, or an entry's `start` is greater than the
    /// byte, it returns `(next_set_len, false)` from the function.
    ///
    /// - `sparse_table`: offsets, strides and length of the two tables.
    /// - `labels_name_map`: receives the names of the `block` and `loop`
    ///   emitted here.
    /// - `state_id_layout`: layout of a single state ID, forwarded to
    ///   `state_id_load`.
    fn sparse_transition_body(
        instructions: &mut InstructionSink<'_>,
        sparse_table: SparseTable,
        labels_name_map: &mut NameMap,
        state_id_layout: &Layout,
    ) {
        // Range table is laid out as `[(start: u8, end: u8), ...]`
        // Need to iterate and find the index where the `start` <= byte && byte
        // <= end`.
        //
        // State table is just a flat array of state IDs like
        // `[state ID, ...]` If we find a match in the range table, the
        // corresponding index in the state table is the next transition
        // state.

        // NOTE(review): label 0 is presumably the `if` emitted by the shared
        // prefix before this body - confirm against the call site.
        labels_name_map.append(1, "table_break_block");
        labels_name_map.append(2, "table_iter_loop");

        // loop_index = 0 // local 7
        // loop {
        //     if loop_index >= sparse_table.range_table_len {
        //         return None;
        //     }
        //
        //     start = range_table[loop_index].0 // local 8
        //     if start > byte {
        //         return None;
        //     } else { // start <= byte
        //         end = range_table[loop_index].1
        //         if byte <= end {
        //             next_state = state_table[loop_index]
        //             break;
        //         }
        //     }
        //     loop_index = loop_index + 1
        // }
        // ... continue to epsilon

        instructions
            .i32_const(0)
            .local_set(7) // loop_index
            // This block is needed so that we can break out of the loop
            .block(BlockType::Empty)
            .loop_(BlockType::Empty)
            // if loop_index >= sparse_table.range_table_len {
            .local_get(7) // loop_index
            .i32_const(i32::from_ne_bytes(
                u32::try_from(sparse_table.range_table_len)
                    .unwrap()
                    .to_ne_bytes(),
            ))
            .i32_ge_u()
            .if_(BlockType::Empty)
            // return None
            .local_get(4) // next_set_len
            .i32_const(false as i32)
            .return_()
            .end() // end if loop_index >= sparse_table.range_table_len {
            // start = range_table[loop_index].0
            // The dynamic address is loop_index * stride; the table position
            // is supplied as the static offset of the load from memory 1.
            .local_get(7) // loop_index
            .i64_extend_i32_u()
            .i64_const(i64::from_ne_bytes(
                u64::try_from(sparse_table.range_lookup_table_stride)
                    .unwrap()
                    .to_ne_bytes(),
            ))
            .i64_mul()
            .i32_load8_u(MemArg {
                offset: u64::try_from(sparse_table.range_table_pos).unwrap(), // start is at offset 0
                align: 0,
                memory_index: 1,
            })
            .local_tee(8) // transition_start
            // if start > byte {
            // NOTE(review): giving up here assumes the ranges are sorted in
            // ascending order of `start` - confirm against regex-automata.
            .local_get(5) // byte
            .i32_gt_u()
            .if_(BlockType::Empty)
            // return None
            .local_get(4) // next_set_len
            .i32_const(false as i32)
            .return_()
            // } else { // start <= byte
            .else_()
            // end = range_table[loop_index].1
            .local_get(7) // loop_index
            .i64_extend_i32_u()
            .i64_const(i64::from_ne_bytes(
                u64::try_from(sparse_table.range_lookup_table_stride)
                    .unwrap()
                    .to_ne_bytes(),
            ))
            .i64_mul()
            .i32_load8_u(MemArg {
                offset: u64::try_from(sparse_table.range_table_pos).unwrap() + 1, // end is at offset 1
                align: 0,
                memory_index: 1,
            })
            .local_set(9) // transition_end
            // if byte <= end {
            .local_get(5) // byte
            .local_get(9) // transition_end
            .i32_le_u()
            .if_(BlockType::Empty)
            // next_state = state_table[loop_index]
            .local_get(7) // loop_index
            .i64_extend_i32_u()
            .i64_const(i64::from_ne_bytes(
                u64::try_from(sparse_table.state_id_table_stride)
                    .unwrap()
                    .to_ne_bytes(),
            ))
            .i64_mul()
            .state_id_load(u64::try_from(sparse_table.state_id_table_pos).unwrap(),  state_id_layout)
            .local_set(6) // next_state
            // break;
            // jump to the end of the block outside of loop
            // Depth: 0=inner if, 1=outer if, 2=`loop`, 3=enclosing `block`
            .br(3)
            .end() // end if byte <= end {
            .end() // end } else { // start <= byte
            // loop_index = loop_index + 1
            .local_get(7) // loop_index
            .i32_const(1)
            .i32_add()
            .local_set(7) // loop_index
            .br(0) // depth 0 is the `loop` here, so this continues iterating
            .end() // end loop
            .end(); // end block
    }
815
816    fn dense_transition_body(
817        instructions: &mut InstructionSink<'_>,
818        table: DenseTable,
819        state_id_layout: &Layout,
820    ) {
821        // Dense transition table is laid out as a 256-length array of state
822        // IDs. To lookup the next state, just use the byte as an index. If the
823        // state is non-zero, then the transition is present.
824
825        instructions
826            .local_get(5) // byte
827            .i64_extend_i32_u()
828            .i64_const(i64::from_ne_bytes(
829                u64::try_from(table.table_stride).unwrap().to_ne_bytes(),
830            ))
831            .i64_mul() // offset in table
832            .state_id_load(u64::try_from(table.table_pos).unwrap(), state_id_layout)
833            .local_tee(6) // next_state
834            // if next == StateID::ZERO
835            .i32_eqz()
836            .if_(BlockType::Empty)
837            // return None
838            .local_get(4) // next_set_len
839            .i32_const(false as i32)
840            .return_()
841            .end();
842    }
843
844    fn byte_range_transition_body(instructions: &mut InstructionSink<'_>, trans: &Transition) {
845        instructions
846            // self.start <= byte
847            .i32_const(trans.start.into()) // self.start
848            .local_get(5) // byte
849            // we invert the condition here, since we're testing the failure
850            .i32_gt_u() // >
851            .if_(BlockType::Empty)
852            .local_get(4) // next_set_len
853            .i32_const(false as i32)
854            .return_()
855            .end()
856            // byte <= self.end
857            .local_get(5) // byte
858            .i32_const(trans.end.into()) // self.end
859            // we invert the condition here, since we're testing the failure
860            .i32_gt_u() // >
861            .if_(BlockType::Empty)
862            .local_get(4) // next_set_len
863            .i32_const(false as i32)
864            .return_()
865            .end()
866            .i32_const(i32::from_ne_bytes(trans.next.as_u32().to_ne_bytes()))
867            .local_set(6); // next_state
868    }
869}
870
871#[cfg(test)]
872mod tests {
873    use crate::compile::{
874        lookaround::{LookFunctions, LookLayout},
875        sparse_set::{tests::get_sparse_set_fns, SparseSetFunctions, SparseSetLayout},
876        tests::setup_interpreter,
877    };
878
879    use super::*;
880
881    fn branch_to_transition_test_closure(
882        nfa: NFA,
883        haystack: &[u8],
884    ) -> impl FnMut(i32, usize, &[u8], bool) + '_ {
885        let mut ctx = CompileContext::new(
886            nfa,
887            crate::Config::new()
888                .export_all_functions(true)
889                .export_state(true),
890        );
891
892        // We're going to assume all states use less then u8::MAX states
893        assert_eq!(*ctx.state_id_layout(), Layout::new::<u8>());
894
895        let overall = Layout::new::<()>();
896        let (overall, sparse_set_layout) = SparseSetLayout::new(&mut ctx, overall).unwrap();
897        let sparse_set_functions = SparseSetFunctions::new(&mut ctx, &sparse_set_layout);
898        let (overall, look_layout) = LookLayout::new(&mut ctx, overall).unwrap();
899        let look_funcs = LookFunctions::new(&mut ctx, &look_layout).unwrap();
900        let epsilon_closures =
901            EpsilonClosureFunctions::new(&mut ctx, sparse_set_functions.insert, &look_funcs)
902                .unwrap();
903        let (overall, transition_layout) = TransitionLayout::new(&mut ctx, overall).unwrap();
904        let _transition_functions =
905            TransitionFunctions::new(&mut ctx, &epsilon_closures, &transition_layout);
906
907        let module_bytes = ctx.compile(&overall).finish();
908        let (_engine, _module, mut store, instance) = setup_interpreter(&module_bytes);
909
910        let branch_to_transition = instance
911            .get_typed_func::<(i64, i64, i64, i64, i32, i32), (i32, i32)>(
912                &store,
913                "branch_to_transition",
914            )
915            .unwrap();
916
917        let haystack_memory = instance.get_memory(&store, "haystack").unwrap();
918        let state_memory = instance.get_memory(&store, "state").unwrap();
919
920        // Write haystack byte into memory ahead of transition call
921        haystack_memory.data_mut(&mut store)[0..haystack.len()].copy_from_slice(haystack);
922
923        move |state_id: i32,
924              at_offset: usize,
925              expected_next_states: &[u8],
926              expected_is_match: bool| {
927            let haystack_ptr = 0;
928            let haystack_len = haystack.len() as i64;
929            let set_ptr = 0;
930            let set_len = 0;
931
932            let (new_set_len, is_match) = branch_to_transition
933                .call(
934                    &mut store,
935                    (
936                        haystack_ptr,
937                        haystack_len,
938                        at_offset as i64,
939                        set_ptr,
940                        set_len,
941                        state_id,
942                    ),
943                )
944                .unwrap();
945
946            let byte = haystack.get(at_offset).copied().unwrap_or(u8::MAX);
947
948            assert_eq!(
949                is_match, expected_is_match as i32,
950                "{state_id} @ {at_offset} => {byte}/{}",
951                byte as char
952            );
953
954            let states = &unsafe { state_memory.data(&store).align_to::<u8>().1 }
955                [0..usize::try_from(new_set_len).unwrap()];
956            assert_eq!(
957                states, expected_next_states,
958                "{state_id} @ {at_offset} => {byte}/{}",
959                byte as char
960            );
961        }
962    }
963
964    #[test]
965    fn branch_to_normal_transition() {
966        // thompson::NFA(
967        //     >000000: binary-union(2, 1)
968        //      000001: \x00-\xFF => 0
969        //     ^000002: capture(pid=0, group=0, slot=0) => 3
970        //      000003: a => 4
971        //      000004: b => 5
972        //      000005: c => 6
973        //      000006: binary-union(3, 7)
974        //      000007: capture(pid=0, group=0, slot=1) => 8
975        //      000008: MATCH(0)
976        let nfa = NFA::new("(?:abc)+").unwrap();
977
978        let mut test = branch_to_transition_test_closure(nfa, b"abc");
979
980        // State 0:
981        test(0, 0, &[], false);
982        test(0, 1, &[], false);
983        test(0, 2, &[], false);
984
985        // State 1: \x00-\xFF => 0
986        test(1, 0, &[0, 1, 2, 3], false);
987        test(1, 1, &[0, 1, 2, 3], false);
988        test(1, 2, &[0, 1, 2, 3], false);
989
990        // State 2: capture(pid=0, group=0, slot=0) => 3
991        test(2, 0, &[], false);
992        test(2, 1, &[], false);
993        test(2, 2, &[], false);
994
995        // State 3: a => 4
996        test(3, 0, &[4], false);
997        test(3, 1, &[], false);
998        test(3, 2, &[], false);
999
1000        // State 4: b => 5
1001        test(4, 0, &[], false);
1002        test(4, 1, &[5], false);
1003        test(4, 2, &[], false);
1004
1005        // State 5: c => 6 + epsilon transitions
1006        test(5, 0, &[], false);
1007        test(5, 1, &[], false);
1008        test(5, 2, &[3, 6, 7, 8], false);
1009
1010        // State 6: binary-union(3, 7)
1011        test(6, 0, &[], false);
1012        test(6, 1, &[], false);
1013        test(6, 2, &[], false);
1014
1015        // State 7: capture(pid=0, group=0, slot=1) => 8
1016        test(7, 0, &[], false);
1017        test(7, 1, &[], false);
1018        test(7, 2, &[], false);
1019
1020        // State 8: MATCH(0)
1021        test(8, 0, &[], true);
1022    }
1023
1024    #[test]
1025    fn branch_to_sparse_transition() {
1026        // thompson::NFA(
1027        //     >000000: binary-union(2, 1)
1028        //      000001: \x00-\xFF => 0
1029        //     ^000002: capture(pid=0, group=0, slot=0) => 6
1030        //      000003: c => 7
1031        //      000004: c => 7
1032        //      000005: c => 7
1033        //      000006: sparse(a => 3, b => 4, d => 5, e => 7, g => 7)
1034        //      000007: capture(pid=0, group=0, slot=1) => 8
1035        //      000008: MATCH(0)
1036        let nfa = NFA::new("ac|bc|dc|e|g").unwrap();
1037
1038        let mut test = branch_to_transition_test_closure(nfa, b"acbcdceg");
1039
1040        // State 0: binary-union(2, 1)
1041        for offset in [0, 2, 4, 6, 7] {
1042            test(0, offset, &[], false);
1043        }
1044
1045        // State 1: \x00-\xFF => 0
1046        for offset in [0, 2, 4, 6, 7] {
1047            test(1, offset, &[0, 1, 2, 6], false);
1048        }
1049
1050        // State 2: capture(pid=0, group=0, slot=0) => 8
1051        for offset in [0, 2, 4, 6, 7] {
1052            test(2, offset, &[], false);
1053        }
1054
1055        for state in [3, 4, 5] {
1056            // State 3/4/5: c => 7
1057            test(state, 0, &[], false);
1058            test(state, 2, &[], false);
1059            test(state, 1, &[7, 8], false);
1060        }
1061
1062        // State 6: sparse(a => 3, b => 4, d => 5, e => 7, g => 7)
1063        test(6, 0, &[3], false);
1064        test(6, 2, &[4], false);
1065        test(6, 4, &[5], false);
1066        test(6, 6, &[7, 8], false);
1067        test(6, 7, &[7, 8], false);
1068
1069        // State 7: capture(pid=0, group=0, slot=1) => 8
1070        for offset in [0, 2, 4, 6, 7] {
1071            test(7, offset, &[], false);
1072        }
1073
1074        // State 8: MATCH(0)
1075        for offset in [0, 2, 4, 6, 7] {
1076            test(8, offset, &[], true);
1077        }
1078    }
1079
1080    #[test]
1081    fn branch_to_simple_lookaround_transitions() {
1082        // thompson::NFA(
1083        // ^000000: capture(pid=0, group=0, slot=0) => 1
1084        //  000001: Start => 2
1085        //  000002: h => 3
1086        //  000003: e => 4
1087        //  000004: l => 5
1088        //  000005: l => 6
1089        //  000006: ' ' => 7
1090        //  000007: w => 8
1091        //  000008: o => 9
1092        //  000009: r => 10
1093        //  000010: m => 11
1094        //  000011: End => 12
1095        //  000012: capture(pid=0, group=0, slot=1) => 13
1096        //  000013: MATCH(0)
1097        let nfa = NFA::new("^hell worm$").unwrap();
1098
1099        let mut test = branch_to_transition_test_closure(nfa, b"hell worm");
1100
1101        test(0, 0, &[], false);
1102        test(1, 0, &[], false);
1103        test(2, 0, &[3], false);
1104        test(3, 1, &[4], false);
1105        test(4, 2, &[5], false);
1106        test(5, 3, &[6], false);
1107        test(6, 4, &[7], false);
1108        test(7, 5, &[8], false);
1109        test(8, 6, &[9], false);
1110        test(9, 7, &[10], false);
1111        test(10, 8, &[11, 12, 13], false);
1112        test(11, 9, &[], false);
1113        test(12, 0, &[], false);
1114        test(13, 0, &[], true);
1115    }
1116
1117    // It seems like `DenseTransitions` are not constructed in the internal
1118    // `regex-automata` code
1119
1120    fn make_current_transitions_test_closure(
1121        nfa: NFA,
1122    ) -> impl FnMut(&[i32], u8, Option<&[u8]>, bool) {
1123        let mut ctx = CompileContext::new(
1124            nfa,
1125            crate::Config::new()
1126                .export_all_functions(true)
1127                .export_state(true),
1128        );
1129
1130        // Assume all tests use less than 255 states
1131        assert_eq!(ctx.state_id_layout(), &Layout::new::<u8>());
1132
1133        let overall = Layout::new::<()>();
1134
1135        let (overall, current_set_layout) = SparseSetLayout::new(&mut ctx, overall).unwrap();
1136        let (overall, next_set_layout) = SparseSetLayout::new(&mut ctx, overall).unwrap();
1137        let (overall, look_layout) = LookLayout::new(&mut ctx, overall).unwrap();
1138
1139        let sparse_set_functions = SparseSetFunctions::new(&mut ctx, &current_set_layout);
1140        let look_funcs = LookFunctions::new(&mut ctx, &look_layout).unwrap();
1141
1142        let epsilon_closures =
1143            EpsilonClosureFunctions::new(&mut ctx, sparse_set_functions.insert, &look_funcs)
1144                .unwrap();
1145
1146        let (overall, transition_layout) = TransitionLayout::new(&mut ctx, overall).unwrap();
1147        let _transition_functions =
1148            TransitionFunctions::new(&mut ctx, &epsilon_closures, &transition_layout);
1149
1150        let module_bytes = ctx.compile(&overall).finish();
1151        let (_engine, _module, mut store, instance) = setup_interpreter(&module_bytes);
1152
1153        let make_current_transitions = instance
1154            // [haystack_ptr, haystack_len, at_offset, current_set_ptr, current_set_len,
1155            // next_set_ptr, next_set_len]
1156            .get_typed_func::<(i64, i64, i64, i64, i32, i64, i32), (i32, i32)>(
1157                &store,
1158                "make_current_transitions",
1159            )
1160            .unwrap();
1161
1162        let (_, set_insert) = get_sparse_set_fns(&instance, &store);
1163
1164        let haystack_memory = instance.get_memory(&store, "haystack").unwrap();
1165        let state_memory = instance.get_memory(&store, "state").unwrap();
1166
1167        move |current_states: &[i32],
1168              byte: u8,
1169              expected_next_states: Option<&[u8]>,
1170              expected_is_match: bool| {
1171            let haystack_ptr = 0;
1172            let haystack_len = 1;
1173            let at_offset = 0;
1174            let current_set_ptr = current_set_layout.set_start_pos as i64;
1175            let mut current_set_len = 0;
1176            let next_set_ptr = next_set_layout.set_start_pos as i64;
1177            let next_set_len = 0;
1178
1179            // Write haystack byte into memory ahead of transition call
1180            haystack_memory.data_mut(&mut store)[haystack_ptr as usize + at_offset as usize] = byte;
1181            // Write all current states into set
1182            for state in current_states {
1183                current_set_len = set_insert
1184                    .call(&mut store, (current_set_len, *state, current_set_ptr))
1185                    .unwrap();
1186            }
1187
1188            let (new_next_set_len, is_match) = make_current_transitions
1189                .call(
1190                    &mut store,
1191                    (
1192                        haystack_ptr,
1193                        haystack_len,
1194                        at_offset,
1195                        current_set_ptr,
1196                        current_set_len,
1197                        next_set_ptr,
1198                        next_set_len,
1199                    ),
1200                )
1201                .unwrap();
1202
1203            assert_eq!(
1204                is_match, expected_is_match as i32,
1205                "{current_states:?} => {byte}/{}",
1206                byte as char
1207            );
1208
1209            if let Some(expected_next_states) = expected_next_states {
1210                assert_eq!(
1211                    new_next_set_len,
1212                    expected_next_states.len() as i32,
1213                    "{current_states:?} => {byte}/{}",
1214                    byte as char
1215                );
1216                let states = &unsafe {
1217                    state_memory.data(&store)[next_set_layout.set_start_pos
1218                        ..(next_set_layout.set_start_pos + next_set_layout.set_overall.size())]
1219                        .align_to::<u8>()
1220                        .1
1221                }[0..expected_next_states.len()];
1222                assert_eq!(
1223                    states, expected_next_states,
1224                    "{current_states:?} => {byte}/{}",
1225                    byte as char
1226                );
1227            } else {
1228                assert_eq!(
1229                    new_next_set_len, 0,
1230                    "{current_states:?} => {byte}/{}",
1231                    byte as char
1232                );
1233            }
1234        }
1235    }
1236
1237    #[test]
1238    fn make_current_transitions_normal() {
1239        // thompson::NFA(
1240        //     >000000: binary-union(2, 1)
1241        //      000001: \x00-\xFF => 0
1242        //     ^000002: capture(pid=0, group=0, slot=0) => 3
1243        //      000003: a => 4
1244        //      000004: b => 5
1245        //      000005: c => 6
1246        //      000006: binary-union(3, 7)
1247        //      000007: capture(pid=0, group=0, slot=1) => 8
1248        //      000008: MATCH(0)
1249        let nfa = NFA::new("(?:abc)+").unwrap();
1250
1251        let mut test = make_current_transitions_test_closure(nfa);
1252
1253        // No states:
1254        test(&[], b'a', None, false);
1255        test(&[], b'b', None, false);
1256        test(&[], b'c', None, false);
1257
1258        // Initial states
1259        test(&[0, 1, 2, 3], b'a', Some(&[0, 1, 2, 3, 4]), false);
1260        test(&[0, 1, 2, 3], b'b', Some(&[0, 1, 2, 3]), false);
1261        test(&[0, 1, 2, 3], b'c', Some(&[0, 1, 2, 3]), false);
1262
1263        // Normal transition states
1264        test(&[3, 4, 5], b'a', Some(&[4]), false);
1265        test(&[3, 4, 5], b'b', Some(&[5]), false);
1266        test(&[3, 4, 5], b'c', Some(&[3, 6, 7, 8]), false);
1267
1268        // Terminal transition states
1269        test(&[6, 7], b'a', None, false);
1270        test(&[6, 7], b'b', None, false);
1271        test(&[6, 7], b'c', None, false);
1272
1273        // Success transition states
1274        test(&[6, 7, 8], b'a', None, true);
1275        test(&[6, 7, 8], b'b', None, true);
1276        test(&[6, 7, 8], b'c', None, true);
1277
1278        // Mixed states
1279        test(&[3, 4, 5, 8], b'a', Some(&[4]), true);
1280        test(&[3, 4, 5, 8], b'b', Some(&[5]), true);
1281        test(&[3, 4, 5, 8], b'c', Some(&[3, 6, 7, 8]), true);
1282    }
1283}