wahgex_core/compile/epsilon_closure.rs
1//! This module contains types and functions related to computing the epsilon
2//! closure of a given NFA state.
3
4use std::collections::{HashMap, HashSet};
5
6use regex_automata::{
7 nfa::thompson::State,
8 util::{look::Look, primitives::StateID},
9};
10use wasm_encoder::{BlockType, NameMap, ValType};
11
12use super::{
13 context::{Function, FunctionDefinition, FunctionIdx, FunctionSignature},
14 lookaround::LookFunctions,
15 BuildError, CompileContext,
16};
17
18/// This struct contains a map of functions that are the pre-computed epsilon
19/// closure for each NFA state.
20#[derive(Debug)]
21pub struct EpsilonClosureFunctions {
22 state_closures: HashMap<StateID, FunctionIdx>,
23 pub branch_to_epsilon_closure: FunctionIdx,
24}
25
26impl EpsilonClosureFunctions {
27 /// Create a new set of epsilon closure functions for the given input.
28 pub fn new(
29 ctx: &mut CompileContext,
30 sparse_set_insert: FunctionIdx,
31 look_funcs: &LookFunctions,
32 ) -> Result<Self, BuildError> {
33 let state_closures = Self::all_epsilon_closure_fns(ctx, sparse_set_insert, look_funcs)?;
34 let branch_to_epsilon_closure = ctx.add_function(Self::branch_to_epsilon_closure_fn(
35 &state_closures,
36 sparse_set_insert,
37 ));
38
39 Ok(Self {
40 state_closures,
41 branch_to_epsilon_closure,
42 })
43 }
44
45 fn all_epsilon_closure_fns(
46 ctx: &mut CompileContext,
47 sparse_set_insert: FunctionIdx,
48 look_funcs: &LookFunctions,
49 ) -> Result<HashMap<StateID, FunctionIdx>, BuildError> {
50 // NOTE: The indexes of the `states` array correspond to the `StateID` value.
51 let mut state_to_epsilon_closure_fn = HashMap::new();
52
53 let num_states = ctx.nfa.states().len();
54 for for_sid in (0..num_states).map(StateID::new).map(Result::unwrap) {
55 let states = ctx.nfa.states();
56 let closure = compute_epsilon_closure(for_sid, states)?;
57 if Self::can_omit_epsilon_closure(&closure, for_sid) {
58 continue;
59 }
60
61 let sig = Self::epsilon_closure_fn_sig(for_sid);
62 let func_idx = ctx.declare_function(sig);
63
64 state_to_epsilon_closure_fn.insert(for_sid, func_idx);
65 }
66
67 for (for_sid, func_idx) in &state_to_epsilon_closure_fn {
68 let states = ctx.nfa.states();
69 let closure = compute_epsilon_closure(*for_sid, states)?;
70 let def = Self::epsilon_closure_fn_def(
71 closure,
72 &state_to_epsilon_closure_fn,
73 sparse_set_insert,
74 look_funcs,
75 )?;
76 ctx.define_function(*func_idx, def);
77 }
78
79 Ok(state_to_epsilon_closure_fn)
80 }
81
82 /// Get the epsilon closure function for the given state ID, if present.
83 #[expect(dead_code)]
84 pub fn get(&self, sid: StateID) -> Option<FunctionIdx> {
85 self.state_closures.get(&sid).copied()
86 }
87
88 fn branch_to_epsilon_closure_fn(
89 epsilon_closures: &HashMap<StateID, FunctionIdx>,
90 sparse_set_insert: FunctionIdx,
91 ) -> Function {
92 let mut locals_name_map = NameMap::new();
93 // Parameters
94 locals_name_map.append(0, "haystack_ptr");
95 locals_name_map.append(1, "haystack_len");
96 locals_name_map.append(2, "at_offset");
97 locals_name_map.append(3, "next_set_ptr");
98 locals_name_map.append(4, "next_set_len");
99 locals_name_map.append(5, "state_id");
100
101 let mut body = wasm_encoder::Function::new([]);
102 let mut instructions = body.instructions();
103
104 let mut states = epsilon_closures.keys().copied().collect::<Vec<_>>();
105 states.sort();
106
107 // This loop will cover any states where [`Self::can_omit_epsilon_closure`]
108 // returned false. All other states will fall through to the code below
109 // which inserts only the self-state.
110 for sid in states {
111 let epsilon_closure_fn = epsilon_closures.get(&sid).copied().unwrap();
112 instructions
113 .local_get(5)
114 .i32_const(i32::from_ne_bytes(sid.as_u32().to_ne_bytes()))
115 .i32_eq()
116 .if_(BlockType::Empty)
117 .local_get(0)
118 .local_get(1)
119 .local_get(2)
120 .local_get(3)
121 .local_get(4)
122 .call(epsilon_closure_fn.into())
123 .return_()
124 .end();
125 }
126
127 // If it falls through to this point, then we must assume thats its a state
128 // which has no epsilon transitions. In which case, we need to add the current
129 // state to the next set and return.
130 instructions
131 .local_get(4) // next_set_len
132 .local_get(5) // state_id
133 .local_get(3) // next_set_ptr
134 .call(sparse_set_insert.into())
135 .end();
136
137 Function {
138 sig: FunctionSignature {
139 name: "branch_to_epsilon_closure".into(),
140 // [haystack_ptr, haystack_len, at_offset, next_set_ptr, next_set_len, state_id]
141 params_ty: &[
142 // TODO(opt): Remove haystack_ptr and assume that haystack always starts at
143 // offset 0 in memory 0
144 ValType::I64,
145 ValType::I64,
146 ValType::I64,
147 ValType::I64,
148 ValType::I32,
149 ValType::I32,
150 ],
151 // [new_next_set_len]
152 results_ty: &[ValType::I32],
153 export: false,
154 },
155 def: FunctionDefinition {
156 body,
157 locals_name_map,
158 labels_name_map: None,
159 branch_hints: None,
160 },
161 }
162 }
163
164 /// Return true if we can omit the epsilon closure function for the given
165 /// state and closure.
166 ///
167 /// We can omit epsilon closures which only contain the self-state, since
168 /// branch_to_epsilon_closure will always include a default branch to
169 /// populate the singleton set.
170 fn can_omit_epsilon_closure(closure: &EpsilonClosure, for_sid: StateID) -> bool {
171 closure.unconditional.len() == 1
172 && closure.unconditional.contains(&for_sid)
173 // Return false if there are conditional lookaround transitions from for_sid
174 && closure.lookaround.is_empty()
175 }
176
177 fn epsilon_closure_fn_sig(for_sid: StateID) -> FunctionSignature {
178 FunctionSignature {
179 name: format!("epsilon_closure_s{}", for_sid.as_usize()),
180 // [haystack_ptr, haystack_len, at_offset, next_set_ptr, next_set_len]
181 params_ty: &[
182 ValType::I64,
183 ValType::I64,
184 ValType::I64,
185 ValType::I64,
186 ValType::I32,
187 ],
188 // [new_next_set_len]
189 results_ty: &[ValType::I32],
190 export: false,
191 }
192 }
193
194 fn epsilon_closure_fn_def(
195 closure: EpsilonClosure,
196 state_to_epsilon_closure_fn: &HashMap<StateID, FunctionIdx>,
197 sparse_set_insert: FunctionIdx,
198 look_funcs: &LookFunctions,
199 ) -> Result<FunctionDefinition, BuildError> {
200 let mut unconditional = closure.unconditional.into_iter().collect::<Vec<_>>();
201 // need this to keep consistency of snapshot tests
202 unconditional.sort();
203
204 let mut locals_name_map = NameMap::new();
205 // Parameters
206 locals_name_map.append(0, "haystack_ptr");
207 locals_name_map.append(1, "haystack_len");
208 locals_name_map.append(2, "at_offset");
209 locals_name_map.append(3, "next_set_ptr");
210 locals_name_map.append(4, "next_set_len");
211 // Locals
212 locals_name_map.append(5, "new_next_set_len");
213
214 let mut body = wasm_encoder::Function::new([(1, ValType::I32)]);
215 let mut instructions = body.instructions();
216 // TODO: `haystack_ptr`, `haystack_len`, and `at_offset` will be unused until we
217 // support lookaround and need to check the stack
218
219 instructions.local_get(4);
220
221 // TODO(opt): Could optimize this by adding a bulk insert method and loading all
222 // of these from a memory location initialized by an active data segment
223 for closure_sid in unconditional {
224 instructions
225 // new_next_set_len is already on the stack from the prelude or the previous call to
226 // sparse_set_insert
227 .i32_const(i32::from_ne_bytes(closure_sid.as_u32().to_ne_bytes()))
228 .local_get(3) // next_set_ptr
229 // TODO(opt): Instead of creating a separate function for every state's epsilon
230 // transition, have some of them be inlined depending on size.
231 .call(sparse_set_insert.into());
232 }
233
234 // At this point the stack is [new_next_set_len]
235
236 // Implementation strategy for lookaround:
237 // 1. For epsilon transitions that include a `Look`, add a conditional block
238 // after inserting all the unconditional states. The block should be keyed
239 // on whether or not new states were added to the next_set.
240 // 2. Inside the block, we should have the actual `Look` conditionals, based on
241 // the haystack.
242 // 3. If the look conditional passes, then recurse into the epsilon closure
243 // function of the `next` state. If that function was omitted (see
244 // `can_omit_epsilon_closure`) then just emit some code that adds the `next`
245 // state to the `next_set`.
246
247 if !closure.lookaround.is_empty() {
248 instructions
249 .local_tee(5)
250 .local_get(4)
251 .i32_ne()
252 .if_(BlockType::Empty);
253 for look in closure.lookaround {
254 instructions
255 .local_get(0)
256 .local_get(1)
257 .local_get(2)
258 .call(look_funcs.look_matcher(look.look).unwrap().into())
259 .if_(BlockType::Empty);
260 // conditional look did match, now call into epsilon transition
261 if let Some(epsilon_closure_fn_idx) =
262 state_to_epsilon_closure_fn.get(&look.next).copied()
263 {
264 // Recursive call to the next state's epsilon closure fn
265 instructions
266 // Args needed [haystack_ptr, haystack_len, at_offset, next_set_ptr,
267 // new_next_set_len]
268 .local_get(0)
269 .local_get(1)
270 .local_get(2)
271 .local_get(3)
272 .local_get(5)
273 .call(epsilon_closure_fn_idx.into())
274 .local_set(5);
275 } else {
276 // Single state insert
277 instructions
278 // Args needed [new_next_set_len, state_id, next_set_ptr]
279 .local_get(5)
280 .i32_const(i32::from_ne_bytes(look.next.as_u32().to_ne_bytes()))
281 .local_get(3)
282 .call(sparse_set_insert.into())
283 .local_set(5);
284 }
285
286 instructions.end();
287 }
288
289 instructions.end().local_get(5);
290 }
291
292 instructions.end();
293
294 Ok(FunctionDefinition {
295 body,
296 locals_name_map,
297 labels_name_map: None,
298 branch_hints: None,
299 })
300 }
301}
302
303#[derive(Debug)]
304struct EpsilonClosure {
305 /// This is the set of states that are unconditionally epsilon-reachable.
306 ///
307 /// This is contrast to those states that are conditionally
308 /// epsilon-reachable through a [`State::Look`] (lookaround).
309 unconditional: HashSet<StateID>,
310 /// This is the list of lookaround states that are directly reachable from
311 /// the `pure` set with no conditional epsilon transitions.
312 lookaround: Vec<EpsilonLook>,
313}
314
315#[derive(Debug, Clone, Copy, PartialEq, Eq)]
316struct EpsilonLook {
317 next: StateID,
318 look: Look,
319}
320
321fn compute_epsilon_closure(sid: StateID, states: &[State]) -> Result<EpsilonClosure, BuildError> {
322 let mut unconditional: HashSet<_> = HashSet::new();
323
324 let mut lookaround = Vec::new();
325
326 let mut stack = vec![sid];
327 'stack: while let Some(mut sid) = stack.pop() {
328 loop {
329 if !unconditional.insert(sid) {
330 continue 'stack;
331 }
332
333 match &states[sid.as_usize()] {
334 State::Fail
335 | State::Match { .. }
336 | State::ByteRange { .. }
337 | State::Sparse { .. }
338 | State::Dense { .. } => {
339 // TODO: Need to integrate here for slot/matching support
340 continue 'stack;
341 },
342 State::Look { look, next } => {
343 lookaround.push(EpsilonLook {
344 next: *next,
345 look: *look,
346 });
347 },
348 State::Union { alternates } => {
349 sid = match alternates.first() {
350 None => continue 'stack,
351 Some(&sid) => sid,
352 };
353 stack.extend(alternates[1..].iter().copied().rev());
354 },
355 State::BinaryUnion { alt1, alt2 } => {
356 sid = *alt1;
357 stack.push(*alt2);
358 },
359 State::Capture { next, .. } => {
360 // TODO: Need to integrate here for slot/matching support
361 sid = *next;
362 },
363 }
364 }
365 }
366
367 Ok(EpsilonClosure {
368 unconditional,
369 lookaround,
370 })
371}
372
373#[cfg(test)]
374mod tests {
375 use std::alloc::Layout;
376
377 use regex_automata::nfa::thompson::NFA;
378
379 use crate::compile::{
380 lookaround::LookLayout,
381 sparse_set::{SparseSetFunctions, SparseSetLayout},
382 tests::setup_interpreter,
383 };
384
385 use super::*;
386
/// Checks the unconditional closure of several states of `(Hello)* world`.
#[test]
fn test_epsilon_closures() {
    let re = NFA::new("(Hello)* world").unwrap();
    // thompson::NFA(
    // >000000: binary-union(2, 1)
    //  000001: \x00-\xFF => 0
    // ^000002: capture(pid=0, group=0, slot=0) => 3
    //  000003: binary-union(4, 11)
    //  000004: capture(pid=0, group=1, slot=2) => 5
    //  000005: H => 6
    //  000006: e => 7
    //  000007: l => 8
    //  000008: l => 9
    //  000009: o => 10
    //  000010: capture(pid=0, group=1, slot=3) => 3
    //  000011: ' ' => 12
    //  000012: w => 13
    //  000013: o => 14
    //  000014: r => 15
    //  000015: l => 16
    //  000016: d => 17
    //  000017: capture(pid=0, group=0, slot=1) => 18
    //  000018: MATCH(0)

    let assert_closure = |from: usize, expected_states: &[usize]| {
        let sid = StateID::new(from).unwrap();
        let want: HashSet<StateID> = expected_states
            .iter()
            .map(|&id| StateID::new(id).unwrap())
            .collect();
        let got = compute_epsilon_closure(sid, re.states()).unwrap();
        assert_eq!(
            got.unconditional, want,
            "Closure from state {sid:?} on:\n{re:?}",
        );
    };

    assert_closure(0, &[2, 1, 4, 11, 3, 0, 5]);
    assert_closure(3, &[3, 4, 5, 11]);
    assert_closure(4, &[4, 5]);
    assert_closure(5, &[5]);
}
430
/// Checks closures on an NFA that contains a multi-way `union` state.
#[test]
fn test_large_union_epsilon_closure() {
    let re = NFA::new("a*|b*|c*|d*|e*").unwrap();
    // thompson::NFA(
    // >000000: binary-union(2, 1)
    //  000001: \x00-\xFF => 0
    // ^000002: capture(pid=0, group=0, slot=0) => 7
    //  000003: binary-union(4, 14)
    //  000004: a => 3
    //  000005: binary-union(6, 14)
    //  000006: b => 5
    //  000007: union(3, 5, 8, 10, 12)
    //  000008: binary-union(9, 14)
    //  000009: c => 8
    //  000010: binary-union(11, 14)
    //  000011: d => 10
    //  000012: binary-union(13, 14)
    //  000013: e => 12
    //  000014: capture(pid=0, group=0, slot=1) => 15
    //  000015: MATCH(0)

    let assert_closure = |from: usize, expected_states: &[usize]| {
        let sid = StateID::new(from).unwrap();
        let want: HashSet<StateID> = expected_states
            .iter()
            .map(|&id| StateID::new(id).unwrap())
            .collect();
        let got = compute_epsilon_closure(sid, re.states()).unwrap();
        assert_eq!(
            got.unconditional, want,
            "Closure from state {sid:?} on:\n{re:?}",
        );
    };

    assert_closure(7, &[3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
    assert_closure(3, &[3, 4, 14, 15]);
    assert_closure(4, &[4]);
    assert_closure(14, &[14, 15]);
}
474
/// Checks that a look state ends the unconditional closure and is reported as
/// a single lookaround transition to its target.
#[test]
fn lookaround_epsilon_closure_panic() {
    let re = NFA::new(r"^hell (?:worm$|world)").unwrap();
    // thompson::NFA(
    // ^000000: capture(pid=0, group=0, slot=0) => 1
    //  000001: Start => 2
    //  000002: h => 3
    //  000003: e => 4
    //  000004: l => 5
    //  000005: l => 6
    //  000006: ' ' => 17
    //  000007: w => 8
    //  000008: o => 9
    //  000009: r => 10
    //  000010: m => 11
    //  000011: End => 18
    //  000012: w => 13
    //  000013: o => 14
    //  000014: r => 15
    //  000015: l => 16
    //  000016: d => 18
    //  000017: binary-union(7, 12)
    //  000018: capture(pid=0, group=0, slot=1) => 19
    //  000019: MATCH(0)

    let check = |from: usize, expected_states: &[usize], expected_look: EpsilonLook| {
        let closure = compute_epsilon_closure(StateID::new(from).unwrap(), re.states()).unwrap();
        let want: HashSet<StateID> = expected_states
            .iter()
            .map(|&id| StateID::new(id).unwrap())
            .collect();

        assert_eq!(closure.unconditional, want);
        assert_eq!(closure.lookaround, vec![expected_look]);
    };

    // From the anchored start: the capture and the `Start` look state are
    // unconditional, and state 2 sits behind the look.
    check(
        0,
        &[0, 1],
        EpsilonLook {
            next: StateID::new(2).unwrap(),
            look: Look::Start,
        },
    );

    // From the `End` look state: only the state itself is unconditional.
    check(
        11,
        &[11],
        EpsilonLook {
            next: StateID::new(18).unwrap(),
            look: Look::End,
        },
    );
}
541
542 fn compile_test_module(nfa: NFA) -> Vec<u8> {
543 let mut ctx = CompileContext::new(
544 nfa,
545 crate::Config::new()
546 .export_all_functions(true)
547 .export_state(true),
548 );
549 // Assume all tests use less than 255 states
550 assert_eq!(ctx.state_id_layout(), &Layout::new::<u8>());
551
552 let overall = Layout::new::<()>();
553 let (overall, sparse_set_layout) = SparseSetLayout::new(&mut ctx, overall).unwrap();
554 let (overall, look_layout) = LookLayout::new(&mut ctx, overall).unwrap();
555 let sparse_set_functions = SparseSetFunctions::new(&mut ctx, &sparse_set_layout);
556 let look_funcs = LookFunctions::new(&mut ctx, &look_layout).unwrap();
557
558 let _epsilon_closure_functions =
559 EpsilonClosureFunctions::new(&mut ctx, sparse_set_functions.insert, &look_funcs);
560
561 let module = ctx.compile(&overall);
562 module.finish()
563 }
564
565 fn setup_epsilon_closure_test(nfa: NFA, haystack: &[u8]) -> impl FnMut(i32, i64, &[u8]) + '_ {
566 let module_bytes = compile_test_module(nfa.clone());
567 let (_engine, _module, mut store, instance) = setup_interpreter(&module_bytes);
568 let branch_to_epsilon_closure = instance
569 .get_typed_func::<(i64, i64, i64, i64, i32, i32), i32>(
570 &store,
571 "branch_to_epsilon_closure",
572 )
573 .unwrap();
574
575 let state_memory = instance.get_memory(&store, "state").unwrap();
576 let haystack_memory = instance.get_memory(&store, "haystack").unwrap();
577
578 // Assuming that haystack starts at 0
579 haystack_memory.data_mut(&mut store)[0..haystack.len()].copy_from_slice(haystack);
580
581 move |state_id, at_offset: i64, expected_states: &[u8]| {
582 let haystack_ptr = 0;
583 let haystack_len = haystack.len() as i64;
584 // Would be safer if we passed the layout through and we read the set start
585 // position instead of assuming its at 0.
586 let set_ptr = 0;
587 let new_set_len = branch_to_epsilon_closure
588 .call(
589 &mut store,
590 (
591 haystack_ptr,
592 haystack_len,
593 at_offset,
594 set_ptr,
595 0, /* set_len */
596 state_id,
597 ),
598 )
599 .unwrap();
600
601 let new_set_len = usize::try_from(new_set_len).unwrap();
602
603 assert_eq!(
604 new_set_len,
605 expected_states.len(),
606 "state [{state_id}] @ {at_offset}"
607 );
608 let epsilon_states = compute_epsilon_closure(
609 StateID::must(usize::try_from(state_id).unwrap()),
610 nfa.states(),
611 )
612 .unwrap();
613 assert!(
614 epsilon_states.unconditional.len() <= expected_states.len(),
615 "state [{state_id}] @ {at_offset}"
616 );
617
618 // Would be safer if we passed the layout through and we read the set start
619 // position instead of assuming its at 0.
620 let states = &unsafe { state_memory.data(&store).align_to::<u8>().1 }[0..new_set_len];
621 assert_eq!(states, expected_states, "state [{state_id}] @ {at_offset}");
622 }
623 }
624
/// Runs the compiled `branch_to_epsilon_closure` on an NFA without lookaround
/// and checks the states it writes into the next set.
#[test]
fn basic_epsilon_closure() {
    // thompson::NFA(
    // >000000: binary-union(2, 1)
    //  000001: \x00-\xFF => 0
    // ^000002: capture(pid=0, group=0, slot=0) => 3
    //  000003: binary-union(4, 11)
    //  000004: capture(pid=0, group=1, slot=2) => 5
    //  000005: H => 6
    //  000006: e => 7
    //  000007: l => 8
    //  000008: l => 9
    //  000009: o => 10
    //  000010: capture(pid=0, group=1, slot=3) => 3
    //  000011: ' ' => 12
    //  000012: w => 13
    //  000013: o => 14
    //  000014: r => 15
    //  000015: l => 16
    //  000016: d => 17
    //  000017: capture(pid=0, group=0, slot=1) => 18
    //  000018: MATCH(0)
    let nfa = NFA::new("(Hello)* world").unwrap();
    let mut check = setup_epsilon_closure_test(nfa, b"");

    // (state_id, at_offset, expected next set)
    let cases: [(i32, i64, &[u8]); 4] = [
        (0, 0, &[0, 1, 2, 3, 4, 5, 11]),
        (3, 0, &[3, 4, 5, 11]),
        (4, 0, &[4, 5]),
        (5, 0, &[5]),
    ];
    for (state_id, at_offset, expected_states) in cases {
        check(state_id, at_offset, expected_states);
    }
}
656
/// Runs the compiled `branch_to_epsilon_closure` on an anchored pattern and
/// checks that look targets are added only where the assertion holds.
#[test]
fn simple_lookaround_epsilon_closure() {
    // thompson::NFA(
    // ^000000: capture(pid=0, group=0, slot=0) => 1
    //  000001: Start => 2
    //  000002: h => 3
    //  000003: e => 4
    //  000004: l => 5
    //  000005: l => 6
    //  000006: ' ' => 7
    //  000007: w => 8
    //  000008: o => 9
    //  000009: r => 10
    //  000010: m => 11
    //  000011: End => 12
    //  000012: capture(pid=0, group=0, slot=1) => 13
    //  000013: MATCH(0)
    let nfa = NFA::new("^hell worm$").unwrap();
    let mut check = setup_epsilon_closure_test(nfa, b"hell worm");

    // (state_id, at_offset, expected next set)
    let cases: [(i32, i64, &[u8]); 3] = [
        // At offset 0 the `Start` look matches, so state 2 is reachable.
        (0, 0, &[0, 1, 2]),
        // At offset 1 the `Start` look fails and state 2 is left out.
        (0, 1, &[0, 1]),
        // At the end of the haystack the `End` look matches, so the states
        // behind it are reachable as well.
        (11, 9, &[11, 12, 13]),
    ];
    for (state_id, at_offset, expected_states) in cases {
        check(state_id, at_offset, expected_states);
    }
}
686}