wahgex_core/compile/
sparse_set.rs

/*!
This module defines a sparse set data structure. Its most interesting
properties are:

* It preserves insertion order.
* Set membership testing is done in constant time.
* Set insertion is done in constant time.
* Clearing the set is done in constant time.

The cost for doing this is that the capacity of the set needs to be known up
front, and the elements in the set are limited to state identifiers.

These sets are principally used when traversing an NFA state graph. This
happens at search time, for example, in the PikeVM. It also happens during DFA
determinization.

The documentation above was copied from <https://github.com/rust-lang/regex/blob/master/regex-automata/src/util/sparse_set.rs>,
and this implementation is based on the same file.
*/
20
21use std::alloc::{Layout, LayoutError};
22
23use wasm_encoder::{BlockType, NameMap, ValType};
24
25use crate::util::repeat;
26
27use super::{
28    context::{Function, FunctionDefinition, FunctionIdx, FunctionSignature},
29    instructions::InstructionSinkExt,
30    CompileContext,
31};
32
33/// This struct describes the layout of a "sparse set", which is used to
34/// track NFA state ID membership.
35///
36/// This type has methods which will generate WASM functions that operate on the
37/// set in WASM memory.
38#[derive(Debug)]
39pub struct SparseSetLayout {
40    #[cfg_attr(not(test), expect(dead_code))]
41    dense_layout: Layout,
42    #[cfg_attr(not(test), expect(dead_code))]
43    dense_stride: usize,
44    #[cfg_attr(not(test), expect(dead_code))]
45    sparse_layout: Layout,
46    #[cfg_attr(not(test), expect(dead_code))]
47    sparse_stride: usize,
48
49    #[cfg_attr(not(test), expect(dead_code))]
50    pub set_overall: Layout,
51    pub set_start_pos: usize,
52    sparse_array_offset: usize,
53
54    state_id_layout: Layout,
55}
56
57impl SparseSetLayout {
58    /// Create a new sparse set layout for the given `NFA`.
59    ///
60    /// The sparse set will be scaled to the number of states in the NFA.
61    pub fn new(ctx: &mut CompileContext, overall: Layout) -> Result<(Layout, Self), LayoutError> {
62        let num_states = ctx.nfa.states().len();
63
64        Self::with_num_states(num_states, overall, ctx.state_id_layout())
65    }
66
67    fn with_num_states(
68        num_states: usize,
69        overall: Layout,
70        state_id_layout: &Layout,
71    ) -> Result<(Layout, Self), LayoutError> {
72        // First field: `dense` - an array of length `num_state`, that contains the
73        // state IDs in the order they were inserted
74        let (dense_layout, dense_stride) = repeat(state_id_layout, num_states)?;
75
76        // Second field: `sparse` - an array of length `num_state`, that contains the
77        // state IDs in the order they were inserted
78        let (sparse_layout, sparse_stride) = repeat(state_id_layout, num_states)?;
79
80        let (set_overall, sparse_array_offset) = dense_layout.extend(sparse_layout)?;
81        let (overall, set_start_pos) = overall.extend(set_overall)?;
82
83        // The `len` field, which would normally be first, is going to be passed around
84        // by function parameter
85
86        let state_id_layout = *state_id_layout;
87
88        Ok((
89            overall,
90            Self {
91                set_overall,
92                dense_layout,
93                dense_stride,
94                set_start_pos,
95                sparse_layout,
96                sparse_stride,
97                sparse_array_offset,
98                state_id_layout,
99            },
100        ))
101    }
102}
103
104/// This struct contains the sparse set functions
105#[derive(Debug)]
106pub struct SparseSetFunctions {
107    #[expect(dead_code)]
108    pub contains: FunctionIdx,
109    pub insert: FunctionIdx,
110}
111
112impl SparseSetFunctions {
113    /// Register all the sparse set functions and save their
114    /// [`FunctionIdx`]s.
115    pub fn new(ctx: &mut CompileContext, layout: &SparseSetLayout) -> Self {
116        let contains = ctx.add_function(Self::contains_fn(layout));
117        let insert = ctx.add_function(Self::insert_fn(layout, contains));
118
119        Self { contains, insert }
120    }
121
122    /// Returns a WASM function that will check whether a given state ID is
123    /// present in the set or not.
124    ///
125    /// If it is present, it returns `1`, else it returns `0`
126    fn contains_fn(layout: &SparseSetLayout) -> Function {
127        let mut locals_name_map = NameMap::new();
128        // Parameters
129        locals_name_map.append(0, "set_ptr");
130        locals_name_map.append(1, "set_len");
131        locals_name_map.append(2, "state_id");
132        // Locals
133        locals_name_map.append(3, "index");
134
135        let mut body = wasm_encoder::Function::new([(1, ValType::I32)]);
136        body.instructions()
137            // let index = self.sparse[id];
138            .local_get(2)
139            .i64_extend_i32_u()
140            .i64_const(layout.state_id_layout.size().try_into().unwrap())
141            .i64_mul() // need to scale the `state_id` index by the size of the elements of the array
142            .local_get(0)
143            .i64_add()
144            .state_id_load(
145                // sparse array is after dense
146                layout.sparse_array_offset.try_into().unwrap(),
147                &layout.state_id_layout,
148            )
149            .local_tee(3)
150            // index.as_usize() < self.len()
151            .local_get(1)
152            .i32_ge_u()
153            // returns `1` if `index.as_usize() >= self.len()`, meaning we should early exit
154            .if_(BlockType::Empty)
155            // `false` as an i32
156            .i32_const(0)
157            .return_()
158            .end()
159            // && self.dense[index] == id
160            .local_get(3)
161            .i64_extend_i32_u()
162            .i64_const(layout.state_id_layout.size().try_into().unwrap())
163            .i64_mul() // need to scale the `state_id` index by the size of the elements of the array
164            .local_get(0)
165            .i64_add()
166            .state_id_load(
167                // dense array is at offset 0
168                0,
169                &layout.state_id_layout,
170            )
171            .local_get(2)
172            .i32_eq()
173            .end();
174
175        Function {
176            sig: FunctionSignature {
177                name: "sparse_set_contains".into(),
178                params_ty: &[ValType::I64, ValType::I32, ValType::I32],
179                results_ty: &[ValType::I32],
180                export: false,
181            },
182            def: FunctionDefinition {
183                body,
184                locals_name_map,
185                labels_name_map: None,
186                branch_hints: None,
187            },
188        }
189    }
190
191    /// Returns a WASM function that will insert the state ID value into this
192    /// set and return `1` if the given state ID was not previously in this
193    /// set.
194    ///
195    /// This operation is idempotent. If the given value is already in this
196    /// set, then this is a no-op.
197    fn insert_fn(layout: &SparseSetLayout, contains: FunctionIdx) -> Function {
198        let mut locals_name_map = NameMap::new();
199        // Parameters
200        locals_name_map.append(0, "set_len");
201        locals_name_map.append(1, "state_id");
202        locals_name_map.append(2, "set_ptr");
203
204        let mut body = wasm_encoder::Function::new([]);
205        body.instructions()
206            // if self.contains(id) {
207            //     return set_len;
208            // }
209            .local_get(2) // set_ptr for contains
210            .local_get(0) // set_len for contains
211            .local_get(1) // state_id for contains
212            .call(contains.into())
213            .i32_const(true as i32)
214            .i32_eq()
215            .if_(BlockType::Empty)
216            .local_get(0) // return current set_len
217            .return_()
218            .end()
219            // self.dense[index] = id;
220            .local_get(0) // set_len as index
221            .i64_extend_i32_u()
222            .i64_const(layout.state_id_layout.size().try_into().unwrap())
223            .i64_mul() // need to scale the `state_id` index by the size of the elements of the array
224            .local_get(2) // set_ptr
225            .i64_add()
226            .local_get(1) // state_id
227            .state_id_store(
228                // dense is at offset 0
229                0,
230                &layout.state_id_layout,
231            )
232            // self.sparse[id] = index;
233            .local_get(1) // state_id
234            .i64_extend_i32_u()
235            .i64_const(layout.state_id_layout.size().try_into().unwrap())
236            .i64_mul() // need to scale the `state_id` index by the size of the elements of the array
237            .local_get(2) // set_ptr
238            .i64_add()
239            .local_get(0) // set_len as index
240            .state_id_store(
241                // sparse is after dense
242                layout.sparse_array_offset.try_into().unwrap(),
243                &layout.state_id_layout,
244            )
245            .local_get(0) // current set_len
246            .i32_const(1)
247            .i32_add()
248            .end();
249
250        Function {
251            sig: FunctionSignature {
252                name: "sparse_set_insert".into(),
253                // [set_len, state_id, set_ptr]
254                params_ty: &[ValType::I32, ValType::I32, ValType::I64],
255                // [new_set_len]
256                results_ty: &[ValType::I32],
257                export: false,
258            },
259            def: FunctionDefinition {
260                body,
261                locals_name_map,
262                labels_name_map: None,
263                branch_hints: None,
264            },
265        }
266    }
267}
268
269#[cfg(test)]
270pub mod tests {
271    use regex_automata::nfa::thompson::NFA;
272
273    use crate::compile::tests::setup_interpreter;
274
275    use super::*;
276
277    fn compile_test_module(layout: &SparseSetLayout) -> Vec<u8> {
278        let mut ctx = CompileContext::new(
279            NFA::never_match(),
280            crate::Config::new()
281                .export_all_functions(true)
282                .export_state(true),
283        );
284
285        let _funcs = SparseSetFunctions::new(&mut ctx, layout);
286
287        let module = ctx.compile(&layout.set_overall);
288        module.finish()
289    }
290
291    type SetContainsFn = wasmi::TypedFunc<(i64, i32, i32), i32>;
292    type SetInsertFn = wasmi::TypedFunc<(i32, i32, i64), i32>;
293
294    pub fn get_sparse_set_fns(
295        instance: &wasmi::Instance,
296        store: &wasmi::Store<()>,
297    ) -> (
298        SetContainsFn, // contains: (ptr, len, id) -> bool
299        SetInsertFn,   // insert: (len, id, ptr) -> new_len
300    ) {
301        let sparse_set_contains = instance
302            .get_typed_func::<(i64, i32, i32), i32>(&store, "sparse_set_contains")
303            .unwrap();
304
305        let sparse_set_insert = instance
306            .get_typed_func::<(i32, i32, i64), i32>(&store, "sparse_set_insert")
307            .unwrap();
308
309        (sparse_set_contains, sparse_set_insert)
310    }
311
312    #[test]
313    fn test_sparse_set_layout() {
314        // Lets do a non-standard layout to start
315        let overall = Layout::new::<u16>();
316        let state_id_layout = Layout::new::<u32>();
317
318        // Layout one sparse array at offset 0 with 5 states capacity.
319        let (overall, sparse_set_layout) =
320            SparseSetLayout::with_num_states(7, overall, &state_id_layout).unwrap();
321
322        assert_eq!(overall, Layout::from_size_align(60, 4).unwrap());
323
324        assert_eq!(
325            sparse_set_layout.dense_layout,
326            Layout::from_size_align(28, 4).unwrap()
327        );
328        assert_eq!(sparse_set_layout.dense_stride, 4);
329        assert_eq!(
330            sparse_set_layout.sparse_layout,
331            Layout::from_size_align(28, 4).unwrap()
332        );
333        assert_eq!(sparse_set_layout.sparse_stride, 4);
334
335        assert_eq!(
336            sparse_set_layout.set_overall,
337            Layout::from_size_align(56, 4).unwrap()
338        );
339        assert_eq!(sparse_set_layout.set_start_pos, 4);
340        assert_eq!(sparse_set_layout.sparse_array_offset, 28);
341    }
342
343    #[test]
344    fn test_init_insert_contains() {
345        let overall = Layout::new::<()>();
346        let state_id_layout = Layout::new::<u32>();
347
348        // Layout one sparse array at offset 0 with 5 states capacity.
349        let (_overall, sparse_set_layout) =
350            SparseSetLayout::with_num_states(5, overall, &state_id_layout).unwrap();
351        let module_bytes = compile_test_module(&sparse_set_layout);
352        let (_engine, _module, mut store, instance) = setup_interpreter(module_bytes);
353        let (contains, insert) = get_sparse_set_fns(&instance, &store);
354
355        let state_memory = instance.get_memory(&store, "state").unwrap();
356
357        let set_ptr = i64::from_ne_bytes(
358            u64::try_from(sparse_set_layout.set_start_pos)
359                .unwrap()
360                .to_ne_bytes(),
361        );
362        let set_len = 0;
363
364        let res = contains.call(&mut store, (set_ptr, set_len, 0)).unwrap();
365        // true because 0 was not present in the set
366        assert_eq!(res, false as i32);
367
368        let set_len = insert.call(&mut store, (set_len, 0, set_ptr)).unwrap();
369        assert_eq!(set_len, 1);
370
371        let res = contains.call(&mut store, (set_ptr, set_len, 0)).unwrap();
372        // true because 0 is already present in the set
373        assert_eq!(res, true as i32);
374
375        let set_len = insert.call(&mut store, (set_len, 0, set_ptr)).unwrap();
376        assert_eq!(set_len, 1);
377
378        let res = contains.call(&mut store, (set_ptr, set_len, 0)).unwrap();
379        // true because 0 is already present in the set
380        assert_eq!(res, true as i32);
381
382        let mut set_len = set_len;
383        for state_id in 1..5 {
384            let new_set_len = insert
385                .call(&mut store, (set_len, state_id, set_ptr))
386                .unwrap();
387            assert_eq!(new_set_len, set_len + 1);
388            set_len = new_set_len;
389        }
390
391        for state_id in 0..5 {
392            let res = contains
393                .call(&mut store, (set_ptr, set_len, state_id))
394                .unwrap();
395            // true because state is already present in the set
396            assert_eq!(res, true as i32, "{state_id} should be present");
397        }
398
399        #[rustfmt::skip]
400        assert_eq!(
401            &state_memory.data(&store)[..(state_id_layout.size() * 5 * 2)],
402            &[
403                0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0,
404                0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0,
405            ]
406        );
407
408        // Reset length, now set is empty
409        for state_id in 0..5 {
410            let res = contains.call(&mut store, (set_ptr, 0, state_id)).unwrap();
411            // true because state is already present in the set
412            assert_eq!(res, false as i32, "{state_id} should not be present");
413        }
414    }
415
416    #[test]
417    fn test_init_insert_reverse_contains() {
418        let overall = Layout::new::<()>();
419        let state_id_layout = Layout::new::<u32>();
420
421        // Layout one sparse array at offset 0 with 5 states capacity.
422        let (_overall, sparse_set_layout) =
423            SparseSetLayout::with_num_states(5, overall, &state_id_layout).unwrap();
424        let module_bytes = compile_test_module(&sparse_set_layout);
425        let (_engine, _module, mut store, instance) = setup_interpreter(module_bytes);
426        let (contains, insert) = get_sparse_set_fns(&instance, &store);
427
428        let state_memory = instance.get_memory(&store, "state").unwrap();
429
430        let set_ptr = 0;
431        let mut set_len = 0;
432
433        for state_id in [4, 1, 0, 2, 3] {
434            let res = contains
435                .call(&mut store, (set_ptr, set_len, state_id))
436                .unwrap();
437            // true because state is not present in the set
438            assert_eq!(res, false as i32, "{state_id} should not be present");
439        }
440
441        // inserting in weird order doesn't affect function
442
443        for state_id in [4, 1, 0, 2, 3] {
444            let new_set_len = insert
445                .call(&mut store, (set_len, state_id, set_ptr))
446                .unwrap();
447            // true because state_id was not present in the set
448            assert_eq!(new_set_len, set_len + 1);
449            set_len = new_set_len;
450        }
451
452        for state_id_check in [4, 1, 0, 2, 3] {
453            let res = contains
454                .call(&mut store, (set_ptr, set_len, state_id_check))
455                .unwrap();
456            // true because state is already present in the set
457            assert_eq!(res, true as i32, "{state_id_check} should be present");
458        }
459
460        #[rustfmt::skip]
461        assert_eq!(
462            &state_memory.data(&store)[..(state_id_layout.size() * 5 * 2)],
463            &[
464                4, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0,
465                2, 0, 0, 0, 1, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0
466            ]
467        );
468    }
469
470    #[test]
471    fn test_state_id_larger_than_one_byte() {
472        let overall = Layout::new::<()>();
473        let state_id_layout = Layout::new::<u32>();
474
475        // Layout one sparse array at offset 0 with 512 states capacity.
476        let (_overall, sparse_set_layout) =
477            SparseSetLayout::with_num_states(512, overall, &state_id_layout).unwrap();
478        let module_bytes = compile_test_module(&sparse_set_layout);
479        let (_engine, _module, mut store, instance) = setup_interpreter(module_bytes);
480        let (contains, insert) = get_sparse_set_fns(&instance, &store);
481
482        let state_memory = instance.get_memory(&store, "state").unwrap();
483
484        let set_ptr = 0;
485        let set_len = 0;
486
487        let res = contains.call(&mut store, (set_ptr, set_len, 511)).unwrap();
488        assert_eq!(res, false as i32);
489
490        let set_len = insert.call(&mut store, (set_len, 256, set_ptr)).unwrap();
491        assert_eq!(set_len, 1);
492
493        let set_len = insert.call(&mut store, (set_len, 511, set_ptr)).unwrap();
494        assert_eq!(set_len, 2);
495
496        let res = contains.call(&mut store, (set_ptr, set_len, 511)).unwrap();
497        assert_eq!(res, true as i32);
498        let res = contains.call(&mut store, (set_ptr, set_len, 256)).unwrap();
499        assert_eq!(res, true as i32);
500
501        // dense entries
502        assert_eq!(
503            &state_memory.data(&store)[0..state_id_layout.size()],
504            &[0, 1, 0, 0]
505        );
506        assert_eq!(
507            &state_memory.data(&store)[state_id_layout.size()..(2 * state_id_layout.size())],
508            &[255, 1, 0, 0]
509        );
510
511        // sparse entries
512        assert_eq!(
513            &state_memory.data(&store)[(sparse_set_layout.sparse_array_offset
514                + 256 * state_id_layout.size())
515                ..(sparse_set_layout.sparse_array_offset + 257 * state_id_layout.size())],
516            &[0, 0, 0, 0]
517        );
518        assert_eq!(
519            &state_memory.data(&store)[(sparse_set_layout.sparse_array_offset
520                + 511 * state_id_layout.size())
521                ..(sparse_set_layout.sparse_array_offset + 512 * state_id_layout.size())],
522            &[1, 0, 0, 0]
523        );
524    }
525}