wahgex_core/compile/
matching.rs

//! This module contains types and functions related to the actual PikeVM
//! execution of `is_match`, `find`, `captures`, etc.
3
4use wasm_encoder::{BlockType, NameMap, ValType};
5
6use super::{
7    context::{
8        BlockSignature, CompileContext, Function, FunctionDefinition, FunctionIdx,
9        FunctionSignature, TypeIdx,
10    },
11    input::{InputFunctions, InputLayout},
12    state::{StateFunctions, StateLayout},
13};
14
/// Handles to the functions generated by this module for running a match.
///
/// At the moment the only function recorded here is `is_match`.
#[derive(Debug)]
pub struct MatchingFunctions {
    // Index returned by `CompileContext::add_function` for the generated
    // `is_match` function. NOTE(review): the leading underscore suggests the
    // field is stored but not read anywhere yet - confirm against callers.
    _is_match: FunctionIdx,
}
19
20impl MatchingFunctions {
21    pub fn new(
22        ctx: &mut CompileContext,
23        state_layout: &StateLayout,
24        state_funcs: &StateFunctions,
25        input_layout: &InputLayout,
26        input_funcs: &InputFunctions,
27    ) -> Self {
28        let start_config_is_some_block_sig = ctx.add_block_signature(BlockSignature {
29            name: "start_config_is_some",
30            params_ty: &[ValType::I32, ValType::I32],
31            results_ty: &[ValType::I32, ValType::I32],
32        });
33
34        let is_match_block_sig = ctx.add_block_signature(BlockSignature {
35            name: "make_current_transitions_is_match",
36            params_ty: &[ValType::I32],
37            results_ty: &[],
38        });
39
40        let is_match = ctx.add_function(Self::is_match_fn(
41            state_layout,
42            state_funcs,
43            input_layout,
44            input_funcs,
45            start_config_is_some_block_sig,
46            is_match_block_sig,
47        ));
48
49        Self {
50            _is_match: is_match,
51        }
52    }
53
54    fn is_match_fn(
55        state_layout: &StateLayout,
56        state_funcs: &StateFunctions,
57        input_layout: &InputLayout,
58        input_funcs: &InputFunctions,
59        start_config_is_some_block_sig: TypeIdx,
60        is_match_block_sig: TypeIdx,
61    ) -> Function {
62        let mut locals_name_map = NameMap::new();
63        // Parameters
64        locals_name_map.append(0, "anchored");
65        locals_name_map.append(1, "anchored_pattern");
66        locals_name_map.append(2, "span_start");
67        locals_name_map.append(3, "span_end");
68        locals_name_map.append(4, "haystack_len");
69        // Locals
70        locals_name_map.append(5, "at_offset");
71        locals_name_map.append(6, "curr_set_ptr");
72        locals_name_map.append(7, "next_set_ptr");
73        locals_name_map.append(8, "curr_set_len");
74        locals_name_map.append(9, "next_set_len");
75        locals_name_map.append(10, "start_state_id");
76        locals_name_map.append(11, "is_anchored");
77
78        let mut labels_name_map = NameMap::new();
79        labels_name_map.append(1, "haystack_search_loop");
80
81        // Sketch:
82        // ```
83        // assert_input_args_wf(true, anchored, anchored_pattern, span_start, span_end, haystack_len)
84        // (start_state_id, is_anchored, is_some) = start_config(anchored, anchored_pattern)
85        // if !is_some {
86        //     return false;
87        // }
88        //
89        // curr_set_ptr = first_set_start_pos;
90        // curr_set_len = 0;
91        // next_set_ptr = second_set_start_pos;
92        // next_set_len = 0;
93        // at_offset = span_start;
94        // loop {
95        //     if at_offset > span_end {
96        //         return false;
97        //     }
98        //
99        //     if curr_set_len == 0 && is_anchored && at_offset > span_start {
100        //         return false;
101        //     }
102        //
103        //     if !is_anchored || at_offset == span_start {
104        //         curr_set_len = branch_to_epsilon_closure(haystack_ptr, haystack_len, at_offset, curr_set_ptr, curr_set_len, start_state_id)
105        //     }
106        //
107        //     new_next_set_len, is_match = make_current_transitions(haystack_ptr, haystack_len, at_offset, curr_set_ptr, curr_set_len, next_set_ptr, next_set_len)
108        //     if is_match && utf8_is_boundary(haystack_ptr, haystack_len, at_offset) {
109        //         return true;
110        //     }
111        //     curr_set_ptr, next_set_ptr = next_set_ptr, curr_set_ptr;
112        //     curr_set_len, next_set_len = next_set_len, curr_set_len;
113        //     next_set_len = 0;
114        //     at = at + 1;
115        // }
116        // ```
117
118        let mut body = wasm_encoder::Function::new([(3, ValType::I64), (4, ValType::I32)]);
119        body.instructions()
120            // assert_input_args_wf(true, anchored, anchored_pattern, span_start, span_end,
121            // haystack_len)
122            .i32_const(true as i32) // earliest
123            .local_get(0) // anchored
124            .local_get(1) // anchored_pattern
125            .local_get(2) // span_start
126            .local_get(3) // span_end
127            .local_get(4) // haystack_len
128            .call(input_funcs.assert_input_args_wf.into())
129            // (start_state_id, is_anchored, is_some) = start_config(anchored, anchored_pattern)
130            .local_get(0) // anchored
131            .local_get(1) // anchored_pattern
132            .call(input_funcs.start_config.into())
133            // if !is_some {
134            .i32_const(false as i32)
135            .i32_eq()
136            .if_(BlockType::FunctionType(
137                start_config_is_some_block_sig.into(),
138            ))
139            // return false;
140            .drop()
141            .drop()
142            .i32_const(false as i32)
143            .return_()
144            .end()
145            .local_set(11) // is_anchored
146            .local_set(10) // start_state_id
147            // curr_set_ptr = first_set_start_pos;
148            .i64_const(i64::from_ne_bytes(
149                u64::try_from(state_layout.first_sparse_set.set_start_pos)
150                    .unwrap()
151                    .to_ne_bytes(),
152            ))
153            .local_set(6) // curr_set_ptr
154            // next_set_ptr = second_set_start_pos;
155            .i64_const(i64::from_ne_bytes(
156                u64::try_from(state_layout.second_sparse_set.set_start_pos)
157                    .unwrap()
158                    .to_ne_bytes(),
159            ))
160            .local_set(7) // next_set_ptr
161            // at_offset = span_start
162            .local_get(2) // span_start
163            .local_set(5) // at_offset
164            // loop {
165            .loop_(BlockType::Empty)
166            // if at_offset > span_end {
167            .local_get(5) // at_offset
168            .local_get(3) // span_end
169            .i64_gt_u()
170            .if_(BlockType::Empty)
171            // return false;
172            .i32_const(false as i32)
173            .return_()
174            .end()
175            // if curr_set_len == 0 && is_anchored && at_offset > span_start {
176            .local_get(8) // curr_set_len
177            .i32_const(0)
178            .i32_eq()
179            .local_get(11) // is_anchored
180            .local_get(5) // at_offset
181            .local_get(2) // span_start
182            .i64_gt_u()
183            .i32_and()
184            .i32_and()
185            .if_(BlockType::Empty)
186            // return false;
187            .i32_const(false as i32)
188            .return_()
189            .end()
190            // if !is_anchored || at_offset == span_start {
191            .local_get(11) // is_anchored
192            .i32_const(false as i32)
193            .i32_eq()
194            .local_get(5) // at_offset
195            .local_get(2) // span_start
196            .i64_eq()
197            .i32_or()
198            .if_(BlockType::Empty)
199            // curr_set_len = branch_to_epsilon_closure(haystack_ptr, haystack_len, at_offset,
200            // curr_set_ptr, curr_set_len, start_state_id)
201            .i64_const(i64::from_ne_bytes(
202                u64::try_from(input_layout.haystack_start_pos)
203                    .unwrap()
204                    .to_ne_bytes(),
205            ))
206            .local_get(4) // haystack_len
207            .local_get(5) // at_offset
208            .local_get(6) // curr_set_ptr
209            .local_get(8) // curr_set_len
210            .local_get(10) // start_state_id
211            .call(state_funcs.epsilon_closure.branch_to_epsilon_closure.into())
212            .local_set(8) // curr_set_len
213            .end()
214            // new_next_set_len, is_match = make_current_transitions(haystack_ptr, haystack_len,
215            // at_offset, curr_set_ptr, curr_set_len, next_set_ptr, next_set_len)
216            .i64_const(i64::from_ne_bytes(
217                u64::try_from(input_layout.haystack_start_pos)
218                    .unwrap()
219                    .to_ne_bytes(),
220            ))
221            .local_get(4) // haystack_len
222            .local_get(5) // at_offset
223            .local_get(6) // curr_set_ptr
224            .local_get(8) // curr_set_len
225            .local_get(7) // next_set_ptr
226            .local_get(9) // next_set_len
227            .call(state_funcs.transition.make_current_transitions.into());
228
229        // stack: [new_next_set_len, is_match]
230        // if is_match && utf8_is_boundary(haystack_ptr, haystack_len, at_offset)
231
232        // This should only be `Some` if the input NFA can match the empty string and
233        // UTF-8 is enabled
234        if let Some(utf8_is_boundary) = input_funcs.utf8_is_boundary {
235            body.instructions()
236                // utf8_is_boundary(haystack_ptr, haystack_len, at_offset)
237                .i64_const(i64::from_ne_bytes(
238                    u64::try_from(input_layout.haystack_start_pos)
239                        .unwrap()
240                        .to_ne_bytes(),
241                ))
242                .local_get(4) // haystack_len
243                .local_get(5) // at_offset
244                .call(utf8_is_boundary.into())
245                .i32_and();
246        }
247
248        body.instructions()
249            .if_(BlockType::FunctionType(is_match_block_sig.into()))
250            .drop()
251            .i32_const(true as i32)
252            .return_()
253            .else_()
254            // next_set_len = new_next_set_len;
255            .local_set(9) // next_set_len
256            .end()
257            // curr_set_ptr, next_set_ptr = next_set_ptr, curr_set_ptr;
258            .local_get(6) // curr_set_ptr
259            .local_get(7) // next_set_ptr
260            .local_set(6)
261            .local_set(7)
262            // curr_set_len, next_set_len = next_set_len, curr_set_len;
263            .local_get(8) // curr_set_len
264            .local_get(9) // next_set_len
265            .local_set(8)
266            .local_set(9)
267            // next_set_len = 0;
268            .i32_const(0)
269            .local_set(9)
270            // at = at + 1;
271            .local_get(5) // at_offset
272            .i64_const(1)
273            .i64_add()
274            .local_set(5) // at_offset
275            .br(0) // continue loop
276            .end()
277            // } end loop
278            .i32_const(false as i32)
279            .end();
280
281        Function {
282            sig: FunctionSignature {
283                name: "is_match".into(),
284                // [anchored, anchored_pattern, span_start, span_end, haystack_len]
285                params_ty: &[
286                    ValType::I32,
287                    ValType::I32,
288                    ValType::I64,
289                    ValType::I64,
290                    ValType::I64,
291                ],
292                // [is_match]
293                results_ty: &[ValType::I32],
294                export: true,
295            },
296            def: FunctionDefinition {
297                body,
298                locals_name_map,
299                labels_name_map: Some(labels_name_map),
300                branch_hints: None,
301            },
302        }
303    }
304}