wahgex_core/compile/
input.rs

1//! This module contains types and functions related to laying out the input
2//! options and haystack in the WASM memory.
3
4use std::alloc::{Layout, LayoutError};
5
6use regex_automata::{nfa::thompson::NFA, Anchored};
7use wasm_encoder::{BlockType, NameMap, ValType};
8
9use super::context::{
10    BlockSignature, CompileContext, Function, FunctionDefinition, FunctionIdx, FunctionSignature,
11    TypeIdx,
12};
13
/// Flattened counterpart of [`regex_automata::Input`] for passing search
/// options into the WASM module.
///
/// The struct is `#[repr(C)]` and consists solely of `i32` fields, so its
/// layout is fixed and it contains no nested structs.
#[derive(Debug)]
#[repr(C)]
pub struct InputOpts {
    /// `1` if an "earliest" search was requested, `0` otherwise.
    pub earliest: i32,
    /// Numeric encoding of the search's anchor mode:
    ///
    /// | [`Anchored`] variant    | value |
    /// |-------------------------|-------|
    /// | [`Anchored::No`]        | `0`   |
    /// | [`Anchored::Yes`]       | `1`   |
    /// | [`Anchored::Pattern`]   | `2`   |
    pub anchored: i32,
    /// The [`PatternID`][regex_automata::util::primitives::PatternID] the
    /// search is anchored to when `anchored` encodes [`Anchored::Pattern`].
    ///
    /// For every other anchor mode this is `0`.
    pub anchored_pattern: i32,
}
34
35impl InputOpts {
36    /// Creates a new `InputOpts` from a [`regex_automata::Input`].
37    ///
38    /// This translates the anchor mode and earliest flag into i32 values
39    /// suitable for WASM.
40    pub fn new(input: &regex_automata::Input<'_>) -> InputOpts {
41        let (anchored, anchored_pattern) = match input.get_anchored() {
42            Anchored::No => (0, 0),
43            Anchored::Yes => (1, 0),
44            Anchored::Pattern(id) => (2, i32::from_ne_bytes(id.to_ne_bytes())),
45        };
46
47        InputOpts {
48            earliest: input.get_earliest() as i32,
49            anchored,
50            anchored_pattern,
51        }
52    }
53}
54
/// Describes where input-related data lives inside the WebAssembly module's
/// memory.
///
/// At present the only piece of information recorded is the offset at which
/// the haystack begins.
#[derive(Debug)]
pub struct InputLayout {
    /// Byte offset of the first haystack byte.
    pub haystack_start_pos: usize,
    /// The combined layout produced while computing the offsets above.
    _overall: Layout,
}
64
65impl InputLayout {
66    /// Creates a new [`InputLayout`].
67    ///
68    /// Currently, this primarily determines the starting offset for the
69    /// haystack.
70    pub fn new(_ctx: &mut CompileContext) -> Result<Self, LayoutError> {
71        let overall = Layout::new::<()>();
72
73        // We use a zero-size array here to mark the start of the haystack, since we
74        // don't know the length of it until runtime.
75        let (overall, haystack_start_pos) = overall.extend(Layout::array::<u8>(0)?)?;
76
77        Ok(Self {
78            _overall: overall,
79            haystack_start_pos,
80        })
81    }
82}
83
84/// Holds indices to WebAssembly functions related to input processing.
85///
86/// These functions are used by the compiled regex to manage and interpret the
87/// input haystack.
88#[derive(Debug)]
89pub struct InputFunctions {
90    #[expect(dead_code)]
91    prepare_input: FunctionIdx,
92    pub assert_input_args_wf: FunctionIdx,
93    pub utf8_is_boundary: Option<FunctionIdx>,
94    pub start_config: FunctionIdx,
95}
96
97impl InputFunctions {
98    /// Creates and registers the necessary WebAssembly functions for input
99    /// handling.
100    ///
101    /// This includes functions for preparing input memory, asserting argument
102    /// well-formedness, checking UTF-8 boundaries, and configuring start
103    /// conditions.
104    pub fn new(
105        ctx: &mut CompileContext,
106        input_layout: &InputLayout,
107        pattern_lookup_start: FunctionIdx,
108    ) -> Self {
109        let prepare_input = ctx.add_function(Self::prepare_input_fn(
110            ctx.config.get_page_size(),
111            input_layout,
112        ));
113        let assert_input_args_wf = ctx.add_function(Self::assert_input_args_wf_fn(
114            ctx.config.get_page_size(),
115            input_layout,
116        ));
117
118        let utf8_is_boundary = (ctx.nfa.has_empty() && ctx.nfa.is_utf8())
119            .then(|| ctx.add_function(Self::utf8_is_boundary_fn()));
120
121        let pattern_lookup_start_result_block_sig = ctx.add_block_signature(BlockSignature {
122            name: "pattern_lookup_start_result",
123            params_ty: &[ValType::I32],
124            results_ty: &[],
125        });
126
127        let start_config = ctx.add_function(Self::start_config_fn(
128            &ctx.nfa,
129            pattern_lookup_start,
130            pattern_lookup_start_result_block_sig,
131        ));
132
133        Self {
134            prepare_input,
135            assert_input_args_wf,
136            utf8_is_boundary,
137            start_config,
138        }
139    }
140
141    fn start_config_fn(
142        nfa: &NFA,
143        pattern_lookup_start: FunctionIdx,
144        pattern_lookup_start_result_block_sig: TypeIdx,
145    ) -> Function {
146        // Copied from https://github.com/rust-lang/regex/blob/1a069b9232c607b34c4937122361aa075ef573fa/regex-automata/src/nfa/thompson/pikevm.rs#L1751-L1785
147
148        let mut locals_name_map = NameMap::new();
149        // Parameters
150        locals_name_map.append(0, "anchored");
151        locals_name_map.append(1, "anchored_pattern");
152        // Locals
153        locals_name_map.append(2, "pattern_start");
154
155        // Sketch:
156        // ```rust
157        // if anchored == Anchored::No {
158        //     return (nfa.start_anchored(), nfa.is_always_start_anchored(), true);
159        // }
160        // if anchored == Anchored::Yes {
161        //     return (nfa.start_anchored(), true, true);
162        // }
163        // if anchored == Anchored::Pattern {
164        //     (pattern_start, is_some) = pattern_lookup_start(anchored_pattern);
165        //     if is_some {
166        //         return (pattern_start, true, true);
167        //     }
168        // }
169        // return (0, 0, false);
170        // ```
171
172        let mut body = wasm_encoder::Function::new([(1, ValType::I32)]);
173        body.instructions()
174            // if anchored == Anchored::No {
175            .local_get(0)
176            .i32_const(0) // Anchored::No
177            .i32_eq()
178            .if_(BlockType::Empty)
179            //  return (nfa.start_anchored(), nfa.is_always_start_anchored(), true);
180            .i32_const(i32::from_ne_bytes(nfa.start_anchored().to_ne_bytes()))
181            .i32_const(nfa.is_always_start_anchored() as i32)
182            .i32_const(true as i32)
183            .return_()
184            .end()
185            // if anchored == Anchored::Yes {
186            .local_get(0)
187            .i32_const(1) // Anchored::Yes
188            .i32_eq()
189            .if_(BlockType::Empty)
190            //  return (nfa.start_anchored(), true, true);
191            .i32_const(i32::from_ne_bytes(nfa.start_anchored().to_ne_bytes()))
192            .i32_const(true as i32)
193            .i32_const(true as i32)
194            .return_()
195            .end()
196            // if anchored == Anchored::Pattern {
197            .local_get(0)
198            .i32_const(2) // Anchored::Pattern
199            .i32_eq()
200            .if_(BlockType::Empty)
201            // (pattern_start, is_some) = pattern_lookup_start(anchored_pattern);
202            .local_get(1)
203            .call(pattern_lookup_start.into())
204            .if_(BlockType::FunctionType(
205                pattern_lookup_start_result_block_sig.into(),
206            ))
207            //  return (pattern_start, true, true);
208            .i32_const(true as i32)
209            .i32_const(true as i32)
210            .return_()
211            .else_()
212            .drop()
213            .end()
214            .end()
215            // return (0, 0, false);
216            .i32_const(0)
217            .i32_const(0)
218            .i32_const(false as i32)
219            .end();
220
221        Function {
222            sig: FunctionSignature {
223                name: "start_config".into(),
224                // [anchored, anchored_pattern]
225                params_ty: &[ValType::I32, ValType::I32],
226                // [start_state_id, is_anchored, is_some]
227                results_ty: &[ValType::I32, ValType::I32, ValType::I32],
228                export: false,
229            },
230            def: FunctionDefinition {
231                body,
232                locals_name_map,
233                labels_name_map: None,
234                branch_hints: None,
235            },
236        }
237    }
238
239    fn utf8_is_boundary_fn() -> Function {
240        // Copied from https://github.com/rust-lang/regex/blob/1a069b9232c607b34c4937122361aa075ef573fa/regex-automata/src/util/utf8.rs#L117-L137
241
242        // Sketch:
243        // ```rust
244        // if at_offset >= haystack_len {
245        //     return at_offset == haystack_len;
246        // }
247        //
248        // byte = haystack_ptr[at_offset];
249        // return (byte <= 0b0111_1111 || byte >= 0b1100_0000);
250        // ```
251
252        let mut locals_name_map = NameMap::new();
253        // Parameters
254        locals_name_map.append(0, "haystack_ptr");
255        locals_name_map.append(1, "haystack_len");
256        locals_name_map.append(2, "at_offset");
257        // Locals
258        locals_name_map.append(3, "byte");
259
260        let mut body = wasm_encoder::Function::new([(1, ValType::I32)]);
261        body.instructions()
262            // if at_offset >= haystack_len {
263            .local_get(2)
264            .local_get(1)
265            .i64_ge_u()
266            .if_(BlockType::Empty)
267            // return at_offset == haystack_len
268            .local_get(2)
269            .local_get(1)
270            .i64_eq() // returns either 0 or 1 as i32
271            .return_()
272            .end()
273            // byte = haystack_ptr[at_offset];
274            .local_get(0)
275            .local_get(2)
276            .i64_add()
277            .i32_load8_u(wasm_encoder::MemArg {
278                offset: 0,       // no compile-time offset
279                align: 0,        // align of 1 since we're loading a byte
280                memory_index: 0, // loading from haystack
281            })
282            .local_set(3)
283            // return (byte <= 0b0111_1111 || byte >= 0b1100_0000);
284            .local_get(3)
285            .i32_const(0b0111_1111)
286            .i32_le_u()
287            .local_get(3)
288            .i32_const(0b1100_0000)
289            .i32_ge_u()
290            .i32_or()
291            .end();
292
293        Function {
294            sig: FunctionSignature {
295                name: "utf8_is_boundary".into(),
296                // [haystack_ptr, haystack_len, at_offset]
297                params_ty: &[ValType::I64, ValType::I64, ValType::I64],
298                // [is_boundary]
299                results_ty: &[ValType::I32],
300                export: false,
301            },
302            def: FunctionDefinition {
303                body,
304                locals_name_map,
305                labels_name_map: None,
306                branch_hints: None,
307            },
308        }
309    }
310
311    fn assert_input_args_wf_fn(page_size: usize, input_layout: &InputLayout) -> Function {
312        let mut locals_name_map = NameMap::new();
313        // Parameters
314        locals_name_map.append(0, "earliest");
315        locals_name_map.append(1, "anchored");
316        locals_name_map.append(2, "anchored_pattern");
317        locals_name_map.append(3, "span_start");
318        locals_name_map.append(4, "span_end");
319        locals_name_map.append(5, "haystack_len");
320
321        let mut labels_name_map = NameMap::new();
322        labels_name_map.append(0, "check_earliest_wf");
323        labels_name_map.append(1, "check_anchored_wf");
324        labels_name_map.append(2, "check_anchored_pattern_wf");
325        labels_name_map.append(3, "check_span_start_wf");
326        labels_name_map.append(4, "check_span_end_wf");
327
328        let mut body = wasm_encoder::Function::new([]);
329        body.instructions()
330            // if earliest != true && earliest != false {
331            .local_get(0)
332            .i32_const(true as i32)
333            .i32_ne()
334            .local_get(0)
335            .i32_const(false as i32)
336            .i32_ne()
337            .i32_and()
338            .if_(BlockType::Empty)
339            .unreachable()
340            .end()
341            // if anchored != 0 && anchored != 1 && anchored != 2 {
342            .local_get(1)
343            .i32_const(0)
344            .i32_ne()
345            .local_get(1)
346            .i32_const(1)
347            .i32_ne()
348            .local_get(1)
349            .i32_const(2)
350            .i32_ne()
351            .i32_and()
352            .i32_and()
353            .if_(BlockType::Empty)
354            .unreachable()
355            .end()
356            // if anchored != 2 && anchored_pattern != 0 {
357            .local_get(1)
358            .i32_const(2)
359            .i32_ne()
360            .local_get(2)
361            .i32_const(0)
362            .i32_ne()
363            .i32_and()
364            .if_(BlockType::Empty)
365            .unreachable()
366            .end()
367            // if span_start > span_end {
368            .local_get(3)
369            .local_get(4)
370            .i64_gt_u()
371            .if_(BlockType::Empty)
372            .unreachable()
373            .end()
374            // if span_end >= haystack_len {
375            .local_get(4)
376            .local_get(5)
377            .i64_gt_u()
378            .if_(BlockType::Empty)
379            .unreachable()
380            .end()
381            // if haystack_start_pos + haystack_len > memory.size * page_size
382            .i64_const(i64::from_ne_bytes(
383                u64::try_from(input_layout.haystack_start_pos)
384                    .unwrap()
385                    .to_ne_bytes(),
386            ))
387            .local_get(5)
388            .i64_add()
389            .memory_size(0)
390            .i64_const(i64::from_ne_bytes(
391                u64::try_from(page_size).unwrap().to_ne_bytes(),
392            ))
393            .i64_mul()
394            .i64_gt_u()
395            .if_(BlockType::Empty)
396            .unreachable()
397            .end()
398            .end();
399
400        Function {
401            sig: FunctionSignature {
402                name: "assert_input_args_wf".into(),
403                params_ty: &[
404                    ValType::I32,
405                    ValType::I32,
406                    ValType::I32,
407                    ValType::I64,
408                    ValType::I64,
409                    ValType::I64,
410                ],
411                results_ty: &[],
412                export: false,
413            },
414            def: FunctionDefinition {
415                body,
416                locals_name_map,
417                labels_name_map: Some(labels_name_map),
418                branch_hints: None,
419            },
420        }
421    }
422
423    fn prepare_input_fn(page_size: usize, input_layout: &InputLayout) -> Function {
424        let mut locals_name_map = NameMap::new();
425        // Parameters
426        locals_name_map.append(0, "haystack_len");
427        // Locals
428        locals_name_map.append(1, "num_new_page_required");
429
430        let mut body = wasm_encoder::Function::new([(1, ValType::I64)]);
431        body.instructions()
432            // if haystack_len == 0 {
433            .local_get(0) // haystack_len
434            .i64_const(i64::from_ne_bytes(
435                u64::try_from(input_layout.haystack_start_pos)
436                    .unwrap()
437                    .to_ne_bytes(),
438            ))
439            .i64_add()
440            .i64_const(0)
441            .i64_eq()
442            .if_(BlockType::Empty)
443            // return SuccessNoGrowth
444            .i32_const(PrepareInputResult::SuccessNoGrowth as i32)
445            .return_()
446            .end()
447            // memory_grow = ((haystack_len + haystack_start_pos - 1) / page_size) + 1 - memory_size
448            .local_get(0) // haystack_len
449            .i64_const(i64::from_ne_bytes(
450                u64::try_from(input_layout.haystack_start_pos)
451                    .unwrap()
452                    .to_ne_bytes(),
453            ))
454            .i64_add()
455            .i64_const(1)
456            .i64_sub()
457            .i64_const(i64::from_ne_bytes(
458                u64::try_from(page_size).unwrap().to_ne_bytes(),
459            ))
460            .i64_div_u()
461            .i64_const(1)
462            .i64_add()
463            .memory_size(0)
464            .i64_sub()
465            .local_tee(1)
466            .i64_const(0)
467            // Use signed comparison: num_new_page_required > 0
468            // otherwise negative values of num_new_page_required would register as very large
469            // positive numbers
470            .i64_gt_s()
471            .if_(BlockType::Result(ValType::I32))
472            .local_get(1)
473            .memory_grow(0)
474            .i64_const(-1)
475            .i64_eq()
476            .if_(BlockType::Empty)
477            // If the memory.grow returns -1, then trap since I don't want to handle this
478            .i32_const(PrepareInputResult::Failure as i32)
479            .return_()
480            .end()
481            .i32_const(PrepareInputResult::SuccessGrowth as i32)
482            .else_()
483            .i32_const(PrepareInputResult::SuccessNoGrowth as i32)
484            .end()
485            .end();
486
487        Function {
488            sig: FunctionSignature {
489                name: "prepare_input".into(),
490                // [haystack_len]
491                params_ty: &[ValType::I64],
492                // [prepare_input_result]
493                results_ty: &[ValType::I32],
494                export: true,
495            },
496            def: FunctionDefinition {
497                body,
498                locals_name_map,
499                labels_name_map: None,
500                branch_hints: None,
501            },
502        }
503    }
504}
505
/// Outcome reported by the generated `prepare_input` function.
///
/// The discriminants are the `i32` values the WASM function returns.
#[derive(Debug)]
pub enum PrepareInputResult {
    /// Preparation succeeded and memory did not have to grow.
    SuccessNoGrowth = 0,
    /// Preparation succeeded after growing memory to make room for the
    /// haystack.
    SuccessGrowth = 1,
    /// Preparation failed, likely because memory could not be grown.
    Failure = 2,
}
519
#[cfg(test)]
mod tests {
    use regex_automata::nfa::thompson::NFA;

    use crate::compile::{
        pattern::{PatternFunctions, PatternLayout},
        tests::setup_interpreter,
    };

    use super::*;

    /// Compiles a module containing the input functions and drives the
    /// exported `prepare_input` through a sequence of haystack lengths,
    /// checking both the returned status and the resulting memory size.
    #[test]
    fn prepare_input() {
        let mut ctx = CompileContext::new(
            NFA::always_match(),
            crate::Config::new()
                .export_all_functions(true)
                .export_state(true),
        );

        let state_overall = Layout::new::<()>();
        let (state_overall, pattern_layout) = PatternLayout::new(&mut ctx, state_overall).unwrap();
        let pattern_functions = PatternFunctions::new(&mut ctx, &pattern_layout);

        let input_layout = InputLayout::new(&mut ctx).unwrap();
        let _input_functions =
            InputFunctions::new(&mut ctx, &input_layout, pattern_functions.lookup_start);
        let page_size = ctx.config.get_page_size();

        let module = ctx.compile(&state_overall);
        let module_bytes = module.finish();
        let (_engine, _module, mut store, instance) = setup_interpreter(&module_bytes);
        let haystack_memory = instance.get_memory(&store, "haystack").unwrap();
        let prepare_input = instance
            .get_typed_func::<i64, i32>(&store, "prepare_input")
            .unwrap();

        let no_growth = PrepareInputResult::SuccessNoGrowth as i32;
        let growth = PrepareInputResult::SuccessGrowth as i32;

        // The haystack memory starts out with a single page.
        assert_eq!(haystack_memory.size(&store), 1);

        // Exactly fills the default-sized haystack memory.
        let fills_one_page = i64::try_from(page_size - input_layout._overall.size()).unwrap();
        // Needs one more page than the default.
        let fills_two_pages =
            i64::try_from(page_size - input_layout._overall.size() + page_size).unwrap();

        // (haystack_len, expected result, expected memory size in pages),
        // executed in order since each call may grow the memory.
        let cases = [
            (0_i64, no_growth, 1),
            (1_i64, no_growth, 1),
            (fills_one_page, no_growth, 1),
            (fills_two_pages, growth, 2),
        ];

        for (haystack_len, expected_result, expected_pages) in cases {
            let res = prepare_input.call(&mut store, haystack_len).unwrap();
            assert_eq!(res, expected_result);

            let haystack_size = haystack_memory.size(&store);
            assert_eq!(haystack_size, expected_pages);
        }

        // Negative `num_new_page_required`: memory now has 2 pages, but a
        // 1-byte haystack (haystack_start_pos is 0) only needs
        // ceil_div(1, page_size) = 1 page, so the difference is 1 - 2 = -1.
        // The function must report SuccessNoGrowth and leave memory at 2
        // pages.
        let haystack_len_for_negative_case = 1_i64; // Fits in 1 page
        let res = prepare_input
            .call(&mut store, haystack_len_for_negative_case)
            .unwrap();
        assert_eq!(
            res,
            PrepareInputResult::SuccessNoGrowth as i32,
            "Should be SuccessNoGrowth when current pages > required pages"
        );
        assert_eq!(
            haystack_memory.size(&store),
            2,
            "Memory size should remain 2 pages"
        );
    }
}