wahgex_core/compile/
epsilon_closure.rs

1//! This module contains types and functions related to computing the epsilon
2//! closure of a given NFA state.
3
4use std::collections::{HashMap, HashSet};
5
6use regex_automata::{
7    nfa::thompson::State,
8    util::{look::Look, primitives::StateID},
9};
10use wasm_encoder::{BlockType, NameMap, ValType};
11
12use super::{
13    context::{Function, FunctionDefinition, FunctionIdx, FunctionSignature},
14    lookaround::LookFunctions,
15    BuildError, CompileContext,
16};
17
/// This struct contains a map of functions that are the pre-computed epsilon
/// closure for each NFA state.
#[derive(Debug)]
pub struct EpsilonClosureFunctions {
    /// Epsilon closure function for each state that needs one.
    ///
    /// States for which `can_omit_epsilon_closure` returned true have no entry
    /// here.
    state_closures: HashMap<StateID, FunctionIdx>,
    /// Function that takes a `state_id` argument and forwards to the matching
    /// entry of `state_closures`, or inserts just that state into the next set
    /// when the state has no dedicated closure function.
    pub branch_to_epsilon_closure: FunctionIdx,
}
25
26impl EpsilonClosureFunctions {
27    /// Create a new set of epsilon closure functions for the given input.
28    pub fn new(
29        ctx: &mut CompileContext,
30        sparse_set_insert: FunctionIdx,
31        look_funcs: &LookFunctions,
32    ) -> Result<Self, BuildError> {
33        let state_closures = Self::all_epsilon_closure_fns(ctx, sparse_set_insert, look_funcs)?;
34        let branch_to_epsilon_closure = ctx.add_function(Self::branch_to_epsilon_closure_fn(
35            &state_closures,
36            sparse_set_insert,
37        ));
38
39        Ok(Self {
40            state_closures,
41            branch_to_epsilon_closure,
42        })
43    }
44
45    fn all_epsilon_closure_fns(
46        ctx: &mut CompileContext,
47        sparse_set_insert: FunctionIdx,
48        look_funcs: &LookFunctions,
49    ) -> Result<HashMap<StateID, FunctionIdx>, BuildError> {
50        // NOTE: The indexes of the `states` array correspond to the `StateID` value.
51        let mut state_to_epsilon_closure_fn = HashMap::new();
52
53        let num_states = ctx.nfa.states().len();
54        for for_sid in (0..num_states).map(StateID::new).map(Result::unwrap) {
55            let states = ctx.nfa.states();
56            let closure = compute_epsilon_closure(for_sid, states)?;
57            if Self::can_omit_epsilon_closure(&closure, for_sid) {
58                continue;
59            }
60
61            let sig = Self::epsilon_closure_fn_sig(for_sid);
62            let func_idx = ctx.declare_function(sig);
63
64            state_to_epsilon_closure_fn.insert(for_sid, func_idx);
65        }
66
67        for (for_sid, func_idx) in &state_to_epsilon_closure_fn {
68            let states = ctx.nfa.states();
69            let closure = compute_epsilon_closure(*for_sid, states)?;
70            let def = Self::epsilon_closure_fn_def(
71                closure,
72                &state_to_epsilon_closure_fn,
73                sparse_set_insert,
74                look_funcs,
75            )?;
76            ctx.define_function(*func_idx, def);
77        }
78
79        Ok(state_to_epsilon_closure_fn)
80    }
81
82    /// Get the epsilon closure function for the given state ID, if present.
83    #[expect(dead_code)]
84    pub fn get(&self, sid: StateID) -> Option<FunctionIdx> {
85        self.state_closures.get(&sid).copied()
86    }
87
88    fn branch_to_epsilon_closure_fn(
89        epsilon_closures: &HashMap<StateID, FunctionIdx>,
90        sparse_set_insert: FunctionIdx,
91    ) -> Function {
92        let mut locals_name_map = NameMap::new();
93        // Parameters
94        locals_name_map.append(0, "haystack_ptr");
95        locals_name_map.append(1, "haystack_len");
96        locals_name_map.append(2, "at_offset");
97        locals_name_map.append(3, "next_set_ptr");
98        locals_name_map.append(4, "next_set_len");
99        locals_name_map.append(5, "state_id");
100
101        let mut body = wasm_encoder::Function::new([]);
102        let mut instructions = body.instructions();
103
104        let mut states = epsilon_closures.keys().copied().collect::<Vec<_>>();
105        states.sort();
106
107        // This loop will cover any states where [`Self::can_omit_epsilon_closure`]
108        // returned false. All other states will fall through to the code below
109        // which inserts only the self-state.
110        for sid in states {
111            let epsilon_closure_fn = epsilon_closures.get(&sid).copied().unwrap();
112            instructions
113                .local_get(5)
114                .i32_const(i32::from_ne_bytes(sid.as_u32().to_ne_bytes()))
115                .i32_eq()
116                .if_(BlockType::Empty)
117                .local_get(0)
118                .local_get(1)
119                .local_get(2)
120                .local_get(3)
121                .local_get(4)
122                .call(epsilon_closure_fn.into())
123                .return_()
124                .end();
125        }
126
127        // If it falls through to this point, then we must assume thats its a state
128        // which has no epsilon transitions. In which case, we need to add the current
129        // state to the next set and return.
130        instructions
131            .local_get(4) // next_set_len
132            .local_get(5) // state_id
133            .local_get(3) // next_set_ptr
134            .call(sparse_set_insert.into())
135            .end();
136
137        Function {
138            sig: FunctionSignature {
139                name: "branch_to_epsilon_closure".into(),
140                // [haystack_ptr, haystack_len, at_offset, next_set_ptr, next_set_len, state_id]
141                params_ty: &[
142                    // TODO(opt): Remove haystack_ptr and assume that haystack always starts at
143                    // offset 0 in memory 0
144                    ValType::I64,
145                    ValType::I64,
146                    ValType::I64,
147                    ValType::I64,
148                    ValType::I32,
149                    ValType::I32,
150                ],
151                // [new_next_set_len]
152                results_ty: &[ValType::I32],
153                export: false,
154            },
155            def: FunctionDefinition {
156                body,
157                locals_name_map,
158                labels_name_map: None,
159                branch_hints: None,
160            },
161        }
162    }
163
164    /// Return true if we can omit the epsilon closure function for the given
165    /// state and closure.
166    ///
167    /// We can omit epsilon closures which only contain the self-state, since
168    /// branch_to_epsilon_closure will always include a default branch to
169    /// populate the singleton set.
170    fn can_omit_epsilon_closure(closure: &EpsilonClosure, for_sid: StateID) -> bool {
171        closure.unconditional.len() == 1
172            && closure.unconditional.contains(&for_sid)
173            // Return false if there are conditional lookaround transitions from for_sid
174            && closure.lookaround.is_empty()
175    }
176
177    fn epsilon_closure_fn_sig(for_sid: StateID) -> FunctionSignature {
178        FunctionSignature {
179            name: format!("epsilon_closure_s{}", for_sid.as_usize()),
180            // [haystack_ptr, haystack_len, at_offset, next_set_ptr, next_set_len]
181            params_ty: &[
182                ValType::I64,
183                ValType::I64,
184                ValType::I64,
185                ValType::I64,
186                ValType::I32,
187            ],
188            // [new_next_set_len]
189            results_ty: &[ValType::I32],
190            export: false,
191        }
192    }
193
    /// Generate the body of the epsilon closure function for one state.
    ///
    /// The emitted function has the parameters `[haystack_ptr, haystack_len,
    /// at_offset, next_set_ptr, next_set_len]` plus one `i32` local
    /// (`new_next_set_len`), and leaves the updated set length on the stack as
    /// its result. It:
    ///
    /// 1. inserts every state of `closure.unconditional` (in sorted order) into
    ///    the next set through `sparse_set_insert`, threading the set length
    ///    through the calls;
    /// 2. if `closure.lookaround` is non-empty and the insertions changed the
    ///    set length, evaluates each look matcher and, on success, either calls
    ///    the closure function of the look's `next` state or, when that state
    ///    has no entry in `state_to_epsilon_closure_fn`, inserts `next` alone.
    ///
    /// # Panics
    ///
    /// Panics if `look_funcs.look_matcher` does not yield a function for one of
    /// the looks in `closure.lookaround`.
    fn epsilon_closure_fn_def(
        closure: EpsilonClosure,
        state_to_epsilon_closure_fn: &HashMap<StateID, FunctionIdx>,
        sparse_set_insert: FunctionIdx,
        look_funcs: &LookFunctions,
    ) -> Result<FunctionDefinition, BuildError> {
        let mut unconditional = closure.unconditional.into_iter().collect::<Vec<_>>();
        // `HashSet` iteration order is unspecified; sorting keeps the emitted code
        // stable, which the snapshot tests rely on.
        unconditional.sort();

        let mut locals_name_map = NameMap::new();
        // Parameters
        locals_name_map.append(0, "haystack_ptr");
        locals_name_map.append(1, "haystack_len");
        locals_name_map.append(2, "at_offset");
        locals_name_map.append(3, "next_set_ptr");
        locals_name_map.append(4, "next_set_len");
        // Locals
        locals_name_map.append(5, "new_next_set_len");

        // One extra i32 local (index 5) holding the running set length.
        let mut body = wasm_encoder::Function::new([(1, ValType::I32)]);
        let mut instructions = body.instructions();
        // NOTE: `haystack_ptr`, `haystack_len`, and `at_offset` are only read by the
        // lookaround section below; the unconditional inserts do not touch them.

        // Prelude: push next_set_len so the first insert finds it on the stack.
        instructions.local_get(4);

        // TODO(opt): Could optimize this by adding a bulk insert method and loading all
        // of these from a memory location initialized by an active data segment
        for closure_sid in unconditional {
            instructions
                // new_next_set_len is already on the stack from the prelude or the previous call to
                // sparse_set_insert
                .i32_const(i32::from_ne_bytes(closure_sid.as_u32().to_ne_bytes()))
                .local_get(3) // next_set_ptr
                // TODO(opt): Instead of creating a separate function for every state's epsilon
                // transition, have some of them be inlined depending on size.
                .call(sparse_set_insert.into());
        }

        // At this point the stack is [new_next_set_len]

        // Implementation strategy for lookaround:
        //  1. For epsilon transitions that include a `Look`, add a conditional block
        //     after inserting all the unconditional states. The block should be keyed
        //     on whether or not new states were added to the next_set.
        //  2. Inside the block, we should have the actual `Look` conditionals, based on
        //     the haystack.
        //  3. If the look conditional passes, then recurse into the epsilon closure
        //     function of the `next` state. If that function was omitted (see
        //     `can_omit_epsilon_closure`) then just emit some code that adds the `next`
        //     state to the `next_set`.

        if !closure.lookaround.is_empty() {
            instructions
                // Save the length after the unconditional inserts into local 5 and
                // compare it with the incoming next_set_len: the lookaround block
                // only runs when the two differ.
                .local_tee(5)
                .local_get(4)
                .i32_ne()
                .if_(BlockType::Empty);
            for look in closure.lookaround {
                instructions
                    // Look matcher args: [haystack_ptr, haystack_len, at_offset]
                    .local_get(0)
                    .local_get(1)
                    .local_get(2)
                    .call(look_funcs.look_matcher(look.look).unwrap().into())
                    .if_(BlockType::Empty);
                // conditional look did match, now call into epsilon transition
                if let Some(epsilon_closure_fn_idx) =
                    state_to_epsilon_closure_fn.get(&look.next).copied()
                {
                    // Recursive call to the next state's epsilon closure fn
                    instructions
                        // Args needed [haystack_ptr, haystack_len, at_offset, next_set_ptr,
                        // new_next_set_len]
                        .local_get(0)
                        .local_get(1)
                        .local_get(2)
                        .local_get(3)
                        .local_get(5)
                        .call(epsilon_closure_fn_idx.into())
                        .local_set(5);
                } else {
                    // Single state insert
                    instructions
                        // Args needed [new_next_set_len, state_id, next_set_ptr]
                        .local_get(5)
                        .i32_const(i32::from_ne_bytes(look.next.as_u32().to_ne_bytes()))
                        .local_get(3)
                        .call(sparse_set_insert.into())
                        .local_set(5);
                }

                // Close the per-look `if`.
                instructions.end();
            }

            // Close the outer `if` and push the final length back as the result.
            instructions.end().local_get(5);
        }

        // End of the function body.
        instructions.end();

        Ok(FunctionDefinition {
            body,
            locals_name_map,
            labels_name_map: None,
            branch_hints: None,
        })
    }
301}
302
/// The epsilon closure of a single NFA state, as computed by
/// [`compute_epsilon_closure`].
#[derive(Debug)]
struct EpsilonClosure {
    /// This is the set of states that are unconditionally epsilon-reachable,
    /// including the starting state itself.
    ///
    /// This is in contrast to those states that are conditionally
    /// epsilon-reachable through a [`State::Look`] (lookaround).
    unconditional: HashSet<StateID>,
    /// This is the list of lookaround transitions leaving the [`State::Look`]
    /// states found in the `unconditional` set, i.e. those reachable without
    /// passing through any other conditional epsilon transition.
    lookaround: Vec<EpsilonLook>,
}
314
/// A conditional epsilon transition taken from a [`State::Look`] state.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct EpsilonLook {
    /// The state the `Look` state transitions to (its `next` field).
    next: StateID,
    /// The look-around assertion carried by the `Look` state.
    look: Look,
}
320
321fn compute_epsilon_closure(sid: StateID, states: &[State]) -> Result<EpsilonClosure, BuildError> {
322    let mut unconditional: HashSet<_> = HashSet::new();
323
324    let mut lookaround = Vec::new();
325
326    let mut stack = vec![sid];
327    'stack: while let Some(mut sid) = stack.pop() {
328        loop {
329            if !unconditional.insert(sid) {
330                continue 'stack;
331            }
332
333            match &states[sid.as_usize()] {
334                State::Fail
335                | State::Match { .. }
336                | State::ByteRange { .. }
337                | State::Sparse { .. }
338                | State::Dense { .. } => {
339                    // TODO: Need to integrate here for slot/matching support
340                    continue 'stack;
341                },
342                State::Look { look, next } => {
343                    lookaround.push(EpsilonLook {
344                        next: *next,
345                        look: *look,
346                    });
347                },
348                State::Union { alternates } => {
349                    sid = match alternates.first() {
350                        None => continue 'stack,
351                        Some(&sid) => sid,
352                    };
353                    stack.extend(alternates[1..].iter().copied().rev());
354                },
355                State::BinaryUnion { alt1, alt2 } => {
356                    sid = *alt1;
357                    stack.push(*alt2);
358                },
359                State::Capture { next, .. } => {
360                    // TODO: Need to integrate here for slot/matching support
361                    sid = *next;
362                },
363            }
364        }
365    }
366
367    Ok(EpsilonClosure {
368        unconditional,
369        lookaround,
370    })
371}
372
373#[cfg(test)]
374mod tests {
375    use std::alloc::Layout;
376
377    use regex_automata::nfa::thompson::NFA;
378
379    use crate::compile::{
380        lookaround::LookLayout,
381        sparse_set::{SparseSetFunctions, SparseSetLayout},
382        tests::setup_interpreter,
383    };
384
385    use super::*;
386
387    #[test]
388    fn test_epsilon_closures() {
389        let re = NFA::new("(Hello)* world").unwrap();
390        // thompson::NFA(
391        //     >000000: binary-union(2, 1)
392        //      000001: \x00-\xFF => 0
393        //     ^000002: capture(pid=0, group=0, slot=0) => 3
394        //      000003: binary-union(4, 11)
395        //      000004: capture(pid=0, group=1, slot=2) => 5
396        //      000005: H => 6
397        //      000006: e => 7
398        //      000007: l => 8
399        //      000008: l => 9
400        //      000009: o => 10
401        //      000010: capture(pid=0, group=1, slot=3) => 3
402        //      000011: ' ' => 12
403        //      000012: w => 13
404        //      000013: o => 14
405        //      000014: r => 15
406        //      000015: l => 16
407        //      000016: d => 17
408        //      000017: capture(pid=0, group=0, slot=1) => 18
409        //      000018: MATCH(0)
410
411        let test = |sid: StateID, expected_states: &[usize]| {
412            let closure = compute_epsilon_closure(sid, re.states()).unwrap();
413            assert_eq!(
414                closure.unconditional,
415                expected_states
416                    .iter()
417                    .copied()
418                    .map(StateID::new)
419                    .map(Result::unwrap)
420                    .collect(),
421                "Closure from state {sid:?} on:\n{re:?}",
422            );
423        };
424
425        test(StateID::ZERO, &[2, 1, 4, 11, 3, 0, 5]);
426        test(StateID::new(3).unwrap(), &[3, 4, 5, 11]);
427        test(StateID::new(4).unwrap(), &[4, 5]);
428        test(StateID::new(5).unwrap(), &[5]);
429    }
430
431    #[test]
432    fn test_large_union_epsilon_closure() {
433        let re = NFA::new("a*|b*|c*|d*|e*").unwrap();
434        // thompson::NFA(
435        //     >000000: binary-union(2, 1)
436        //      000001: \x00-\xFF => 0
437        //     ^000002: capture(pid=0, group=0, slot=0) => 7
438        //      000003: binary-union(4, 14)
439        //      000004: a => 3
440        //      000005: binary-union(6, 14)
441        //      000006: b => 5
442        //      000007: union(3, 5, 8, 10, 12)
443        //      000008: binary-union(9, 14)
444        //      000009: c => 8
445        //      000010: binary-union(11, 14)
446        //      000011: d => 10
447        //      000012: binary-union(13, 14)
448        //      000013: e => 12
449        //      000014: capture(pid=0, group=0, slot=1) => 15
450        //      000015: MATCH(0)
451
452        let test = |sid: StateID, expected_states: &[usize]| {
453            let closure = compute_epsilon_closure(sid, re.states()).unwrap();
454            assert_eq!(
455                closure.unconditional,
456                expected_states
457                    .iter()
458                    .copied()
459                    .map(StateID::new)
460                    .map(Result::unwrap)
461                    .collect(),
462                "Closure from state {sid:?} on:\n{re:?}",
463            );
464        };
465
466        test(
467            StateID::new(7).unwrap(),
468            &[3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
469        );
470        test(StateID::new(3).unwrap(), &[3, 4, 14, 15]);
471        test(StateID::new(4).unwrap(), &[4]);
472        test(StateID::new(14).unwrap(), &[14, 15]);
473    }
474
475    #[test]
476    fn lookaround_epsilon_closure_panic() {
477        let re = NFA::new(r"^hell (?:worm$|world)").unwrap();
478        // thompson::NFA(
479        // ^000000: capture(pid=0, group=0, slot=0) => 1
480        //  000001: Start => 2
481        //  000002: h => 3
482        //  000003: e => 4
483        //  000004: l => 5
484        //  000005: l => 6
485        //  000006: ' ' => 17
486        //  000007: w => 8
487        //  000008: o => 9
488        //  000009: r => 10
489        //  000010: m => 11
490        //  000011: End => 18
491        //  000012: w => 13
492        //  000013: o => 14
493        //  000014: r => 15
494        //  000015: l => 16
495        //  000016: d => 18
496        //  000017: binary-union(7, 12)
497        //  000018: capture(pid=0, group=0, slot=1) => 19
498        //  000019: MATCH(0)
499
500        {
501            let closure = compute_epsilon_closure(StateID::new(0).unwrap(), re.states()).unwrap();
502
503            assert_eq!(
504                closure.unconditional,
505                [0, 1]
506                    .iter()
507                    .copied()
508                    .map(StateID::new)
509                    .map(Result::unwrap)
510                    .collect()
511            );
512            assert_eq!(
513                closure.lookaround,
514                vec![EpsilonLook {
515                    next: StateID::new(2).unwrap(),
516                    look: Look::Start,
517                }]
518            );
519        }
520
521        {
522            let closure = compute_epsilon_closure(StateID::new(11).unwrap(), re.states()).unwrap();
523
524            assert_eq!(
525                closure.unconditional,
526                [11].iter()
527                    .copied()
528                    .map(StateID::new)
529                    .map(Result::unwrap)
530                    .collect()
531            );
532            assert_eq!(
533                closure.lookaround,
534                vec![EpsilonLook {
535                    next: StateID::new(18).unwrap(),
536                    look: Look::End,
537                }]
538            );
539        }
540    }
541
542    fn compile_test_module(nfa: NFA) -> Vec<u8> {
543        let mut ctx = CompileContext::new(
544            nfa,
545            crate::Config::new()
546                .export_all_functions(true)
547                .export_state(true),
548        );
549        // Assume all tests use less than 255 states
550        assert_eq!(ctx.state_id_layout(), &Layout::new::<u8>());
551
552        let overall = Layout::new::<()>();
553        let (overall, sparse_set_layout) = SparseSetLayout::new(&mut ctx, overall).unwrap();
554        let (overall, look_layout) = LookLayout::new(&mut ctx, overall).unwrap();
555        let sparse_set_functions = SparseSetFunctions::new(&mut ctx, &sparse_set_layout);
556        let look_funcs = LookFunctions::new(&mut ctx, &look_layout).unwrap();
557
558        let _epsilon_closure_functions =
559            EpsilonClosureFunctions::new(&mut ctx, sparse_set_functions.insert, &look_funcs);
560
561        let module = ctx.compile(&overall);
562        module.finish()
563    }
564
565    fn setup_epsilon_closure_test(nfa: NFA, haystack: &[u8]) -> impl FnMut(i32, i64, &[u8]) + '_ {
566        let module_bytes = compile_test_module(nfa.clone());
567        let (_engine, _module, mut store, instance) = setup_interpreter(&module_bytes);
568        let branch_to_epsilon_closure = instance
569            .get_typed_func::<(i64, i64, i64, i64, i32, i32), i32>(
570                &store,
571                "branch_to_epsilon_closure",
572            )
573            .unwrap();
574
575        let state_memory = instance.get_memory(&store, "state").unwrap();
576        let haystack_memory = instance.get_memory(&store, "haystack").unwrap();
577
578        // Assuming that haystack starts at 0
579        haystack_memory.data_mut(&mut store)[0..haystack.len()].copy_from_slice(haystack);
580
581        move |state_id, at_offset: i64, expected_states: &[u8]| {
582            let haystack_ptr = 0;
583            let haystack_len = haystack.len() as i64;
584            // Would be safer if we passed the layout through and we read the set start
585            // position instead of assuming its at 0.
586            let set_ptr = 0;
587            let new_set_len = branch_to_epsilon_closure
588                .call(
589                    &mut store,
590                    (
591                        haystack_ptr,
592                        haystack_len,
593                        at_offset,
594                        set_ptr,
595                        0, /* set_len */
596                        state_id,
597                    ),
598                )
599                .unwrap();
600
601            let new_set_len = usize::try_from(new_set_len).unwrap();
602
603            assert_eq!(
604                new_set_len,
605                expected_states.len(),
606                "state [{state_id}] @ {at_offset}"
607            );
608            let epsilon_states = compute_epsilon_closure(
609                StateID::must(usize::try_from(state_id).unwrap()),
610                nfa.states(),
611            )
612            .unwrap();
613            assert!(
614                epsilon_states.unconditional.len() <= expected_states.len(),
615                "state [{state_id}] @ {at_offset}"
616            );
617
618            // Would be safer if we passed the layout through and we read the set start
619            // position instead of assuming its at 0.
620            let states = &unsafe { state_memory.data(&store).align_to::<u8>().1 }[0..new_set_len];
621            assert_eq!(states, expected_states, "state [{state_id}] @ {at_offset}");
622        }
623    }
624
625    #[test]
626    fn basic_epsilon_closure() {
627        // thompson::NFA(
628        // >000000: binary-union(2, 1)
629        //  000001: \x00-\xFF => 0
630        // ^000002: capture(pid=0, group=0, slot=0) => 3
631        //  000003: binary-union(4, 11)
632        //  000004: capture(pid=0, group=1, slot=2) => 5
633        //  000005: H => 6
634        //  000006: e => 7
635        //  000007: l => 8
636        //  000008: l => 9
637        //  000009: o => 10
638        //  000010: capture(pid=0, group=1, slot=3) => 3
639        //  000011: ' ' => 12
640        //  000012: w => 13
641        //  000013: o => 14
642        //  000014: r => 15
643        //  000015: l => 16
644        //  000016: d => 17
645        //  000017: capture(pid=0, group=0, slot=1) => 18
646        //  000018: MATCH(0)
647        let nfa = NFA::new("(Hello)* world").unwrap();
648
649        let mut test = setup_epsilon_closure_test(nfa, b"");
650
651        test(0, 0, &[0, 1, 2, 3, 4, 5, 11]);
652        test(3, 0, &[3, 4, 5, 11]);
653        test(4, 0, &[4, 5]);
654        test(5, 0, &[5]);
655    }
656
657    #[test]
658    fn simple_lookaround_epsilon_closure() {
659        // thompson::NFA(
660        // ^000000: capture(pid=0, group=0, slot=0) => 1
661        //  000001: Start => 2
662        //  000002: h => 3
663        //  000003: e => 4
664        //  000004: l => 5
665        //  000005: l => 6
666        //  000006: ' ' => 7
667        //  000007: w => 8
668        //  000008: o => 9
669        //  000009: r => 10
670        //  000010: m => 11
671        //  000011: End => 12
672        //  000012: capture(pid=0, group=0, slot=1) => 13
673        //  000013: MATCH(0)
674        let nfa = NFA::new("^hell worm$").unwrap();
675        let mut test = setup_epsilon_closure_test(nfa, b"hell worm");
676
677        // 2 state is reachable because we're at position 0 and the `Start` state
678        // matches
679        test(0, 0, &[0, 1, 2]);
680        // It doesn't match for this state
681        test(0, 1, &[0, 1]);
682
683        // Similarly, we get all the end state matches here
684        test(11, 9, &[11, 12, 13]);
685    }
686}