1use std::alloc::{Layout, LayoutError};
22
23use wasm_encoder::{BlockType, NameMap, ValType};
24
25use crate::util::repeat;
26
27use super::{
28 context::{Function, FunctionDefinition, FunctionIdx, FunctionSignature},
29 instructions::InstructionSinkExt,
30 CompileContext,
31};
32
/// Describes where a sparse set of state IDs lives in memory.
///
/// The set is made of two arrays placed back to back: a dense array followed
/// by a sparse array. Both are sized by `SparseSetLayout::with_num_states`.
#[derive(Debug)]
pub struct SparseSetLayout {
    /// Layout of the dense array.
    #[cfg_attr(not(test), expect(dead_code))]
    dense_layout: Layout,
    /// Stride reported by `repeat` for the dense array.
    #[cfg_attr(not(test), expect(dead_code))]
    dense_stride: usize,
    /// Layout of the sparse array.
    #[cfg_attr(not(test), expect(dead_code))]
    sparse_layout: Layout,
    /// Stride reported by `repeat` for the sparse array.
    #[cfg_attr(not(test), expect(dead_code))]
    sparse_stride: usize,

    /// Layout of the whole set: the dense array extended by the sparse array.
    #[cfg_attr(not(test), expect(dead_code))]
    pub set_overall: Layout,
    /// Byte offset of the set inside the overall layout it was appended to.
    pub set_start_pos: usize,
    /// Byte offset of the sparse array, relative to the start of the set.
    sparse_array_offset: usize,

    /// Layout of a single state ID element.
    state_id_layout: Layout,
}
56
57impl SparseSetLayout {
58 pub fn new(ctx: &mut CompileContext, overall: Layout) -> Result<(Layout, Self), LayoutError> {
62 let num_states = ctx.nfa.states().len();
63
64 Self::with_num_states(num_states, overall, ctx.state_id_layout())
65 }
66
67 fn with_num_states(
68 num_states: usize,
69 overall: Layout,
70 state_id_layout: &Layout,
71 ) -> Result<(Layout, Self), LayoutError> {
72 let (dense_layout, dense_stride) = repeat(state_id_layout, num_states)?;
75
76 let (sparse_layout, sparse_stride) = repeat(state_id_layout, num_states)?;
79
80 let (set_overall, sparse_array_offset) = dense_layout.extend(sparse_layout)?;
81 let (overall, set_start_pos) = overall.extend(set_overall)?;
82
83 let state_id_layout = *state_id_layout;
87
88 Ok((
89 overall,
90 Self {
91 set_overall,
92 dense_layout,
93 dense_stride,
94 set_start_pos,
95 sparse_layout,
96 sparse_stride,
97 sparse_array_offset,
98 state_id_layout,
99 },
100 ))
101 }
102}
103
104#[derive(Debug)]
106pub struct SparseSetFunctions {
107 #[expect(dead_code)]
108 pub contains: FunctionIdx,
109 pub insert: FunctionIdx,
110}
111
112impl SparseSetFunctions {
113 pub fn new(ctx: &mut CompileContext, layout: &SparseSetLayout) -> Self {
116 let contains = ctx.add_function(Self::contains_fn(layout));
117 let insert = ctx.add_function(Self::insert_fn(layout, contains));
118
119 Self { contains, insert }
120 }
121
122 fn contains_fn(layout: &SparseSetLayout) -> Function {
127 let mut locals_name_map = NameMap::new();
128 locals_name_map.append(0, "set_ptr");
130 locals_name_map.append(1, "set_len");
131 locals_name_map.append(2, "state_id");
132 locals_name_map.append(3, "index");
134
135 let mut body = wasm_encoder::Function::new([(1, ValType::I32)]);
136 body.instructions()
137 .local_get(2)
139 .i64_extend_i32_u()
140 .i64_const(layout.state_id_layout.size().try_into().unwrap())
141 .i64_mul() .local_get(0)
143 .i64_add()
144 .state_id_load(
145 layout.sparse_array_offset.try_into().unwrap(),
147 &layout.state_id_layout,
148 )
149 .local_tee(3)
150 .local_get(1)
152 .i32_ge_u()
153 .if_(BlockType::Empty)
155 .i32_const(0)
157 .return_()
158 .end()
159 .local_get(3)
161 .i64_extend_i32_u()
162 .i64_const(layout.state_id_layout.size().try_into().unwrap())
163 .i64_mul() .local_get(0)
165 .i64_add()
166 .state_id_load(
167 0,
169 &layout.state_id_layout,
170 )
171 .local_get(2)
172 .i32_eq()
173 .end();
174
175 Function {
176 sig: FunctionSignature {
177 name: "sparse_set_contains".into(),
178 params_ty: &[ValType::I64, ValType::I32, ValType::I32],
179 results_ty: &[ValType::I32],
180 export: false,
181 },
182 def: FunctionDefinition {
183 body,
184 locals_name_map,
185 labels_name_map: None,
186 branch_hints: None,
187 },
188 }
189 }
190
191 fn insert_fn(layout: &SparseSetLayout, contains: FunctionIdx) -> Function {
198 let mut locals_name_map = NameMap::new();
199 locals_name_map.append(0, "set_len");
201 locals_name_map.append(1, "state_id");
202 locals_name_map.append(2, "set_ptr");
203
204 let mut body = wasm_encoder::Function::new([]);
205 body.instructions()
206 .local_get(2) .local_get(0) .local_get(1) .call(contains.into())
213 .i32_const(true as i32)
214 .i32_eq()
215 .if_(BlockType::Empty)
216 .local_get(0) .return_()
218 .end()
219 .local_get(0) .i64_extend_i32_u()
222 .i64_const(layout.state_id_layout.size().try_into().unwrap())
223 .i64_mul() .local_get(2) .i64_add()
226 .local_get(1) .state_id_store(
228 0,
230 &layout.state_id_layout,
231 )
232 .local_get(1) .i64_extend_i32_u()
235 .i64_const(layout.state_id_layout.size().try_into().unwrap())
236 .i64_mul() .local_get(2) .i64_add()
239 .local_get(0) .state_id_store(
241 layout.sparse_array_offset.try_into().unwrap(),
243 &layout.state_id_layout,
244 )
245 .local_get(0) .i32_const(1)
247 .i32_add()
248 .end();
249
250 Function {
251 sig: FunctionSignature {
252 name: "sparse_set_insert".into(),
253 params_ty: &[ValType::I32, ValType::I32, ValType::I64],
255 results_ty: &[ValType::I32],
257 export: false,
258 },
259 def: FunctionDefinition {
260 body,
261 locals_name_map,
262 labels_name_map: None,
263 branch_hints: None,
264 },
265 }
266 }
267}
268
#[cfg(test)]
pub mod tests {
    use regex_automata::nfa::thompson::NFA;

    use crate::compile::tests::setup_interpreter;

    use super::*;

    type SetContainsFn = wasmi::TypedFunc<(i64, i32, i32), i32>;
    type SetInsertFn = wasmi::TypedFunc<(i32, i32, i64), i32>;

    /// Compiles a module that contains only the sparse set functions, with all
    /// functions and the state memory exported.
    fn compile_test_module(layout: &SparseSetLayout) -> Vec<u8> {
        let mut ctx = CompileContext::new(
            NFA::never_match(),
            crate::Config::new()
                .export_all_functions(true)
                .export_state(true),
        );

        let _funcs = SparseSetFunctions::new(&mut ctx, layout);

        ctx.compile(&layout.set_overall).finish()
    }

    /// Looks up the exported `sparse_set_contains` and `sparse_set_insert`
    /// functions of `instance`, in that order.
    pub fn get_sparse_set_fns(
        instance: &wasmi::Instance,
        store: &wasmi::Store<()>,
    ) -> (SetContainsFn, SetInsertFn) {
        let sparse_set_contains: SetContainsFn = instance
            .get_typed_func(store, "sparse_set_contains")
            .unwrap();
        let sparse_set_insert: SetInsertFn = instance
            .get_typed_func(store, "sparse_set_insert")
            .unwrap();

        (sparse_set_contains, sparse_set_insert)
    }

    /// Serializes `ids` as consecutive little-endian 4-byte values, which is
    /// how the `u32` state IDs used by these tests appear in the state memory.
    fn le_bytes(ids: &[u32]) -> Vec<u8> {
        ids.iter().flat_map(|id| id.to_le_bytes()).collect()
    }

    /// Returns the bytes of slot `index` in the array that starts at byte
    /// `base`, where every slot is `id_size` bytes wide.
    fn id_slot(data: &[u8], base: usize, index: usize, id_size: usize) -> &[u8] {
        let start = base + index * id_size;
        &data[start..start + id_size]
    }

    #[test]
    fn test_sparse_set_layout() {
        let layout_of = |size, align| Layout::from_size_align(size, align).unwrap();

        // Start from a non-empty overall layout so the set needs padding.
        let (overall, set) =
            SparseSetLayout::with_num_states(7, Layout::new::<u16>(), &Layout::new::<u32>())
                .unwrap();

        assert_eq!(overall, layout_of(60, 4));

        assert_eq!(set.dense_layout, layout_of(28, 4));
        assert_eq!(set.dense_stride, 4);
        assert_eq!(set.sparse_layout, layout_of(28, 4));
        assert_eq!(set.sparse_stride, 4);

        assert_eq!(set.set_overall, layout_of(56, 4));
        assert_eq!(set.set_start_pos, 4);
        assert_eq!(set.sparse_array_offset, 28);
    }

    #[test]
    fn test_init_insert_contains() {
        let state_id_layout = Layout::new::<u32>();
        let (_overall, sparse_set_layout) =
            SparseSetLayout::with_num_states(5, Layout::new::<()>(), &state_id_layout).unwrap();

        let module_bytes = compile_test_module(&sparse_set_layout);
        let (_engine, _module, mut store, instance) = setup_interpreter(module_bytes);
        let (contains, insert) = get_sparse_set_fns(&instance, &store);
        let state_memory = instance.get_memory(&store, "state").unwrap();

        // Reinterpret the unsigned start offset as the i64 the functions take.
        let start_pos = u64::try_from(sparse_set_layout.set_start_pos).unwrap();
        let set_ptr = i64::from_ne_bytes(start_pos.to_ne_bytes());
        let mut set_len = 0;

        // Nothing is present in an empty set.
        let present = contains.call(&mut store, (set_ptr, set_len, 0)).unwrap();
        assert_eq!(present, false as i32);

        // The first insertion grows the set and makes the ID visible.
        set_len = insert.call(&mut store, (set_len, 0, set_ptr)).unwrap();
        assert_eq!(set_len, 1);
        let present = contains.call(&mut store, (set_ptr, set_len, 0)).unwrap();
        assert_eq!(present, true as i32);

        // Inserting the same ID again leaves the length untouched.
        set_len = insert.call(&mut store, (set_len, 0, set_ptr)).unwrap();
        assert_eq!(set_len, 1);
        let present = contains.call(&mut store, (set_ptr, set_len, 0)).unwrap();
        assert_eq!(present, true as i32);

        // Fill in the remaining IDs; each one grows the set by one.
        for state_id in 1..5 {
            let grown = insert
                .call(&mut store, (set_len, state_id, set_ptr))
                .unwrap();
            assert_eq!(grown, set_len + 1);
            set_len = grown;
        }

        for state_id in 0..5 {
            let present = contains
                .call(&mut store, (set_ptr, set_len, state_id))
                .unwrap();
            assert_eq!(present, true as i32, "{state_id} should be present");
        }

        #[rustfmt::skip]
        let expected = le_bytes(&[
            0, 1, 2, 3, 4, // dense array
            0, 1, 2, 3, 4, // sparse array
        ]);
        let used = state_id_layout.size() * 5 * 2;
        assert_eq!(&state_memory.data(&store)[..used], expected.as_slice());

        // With a length of zero, the stale memory contents must be ignored.
        for state_id in 0..5 {
            let present = contains.call(&mut store, (set_ptr, 0, state_id)).unwrap();
            assert_eq!(present, false as i32, "{state_id} should not be present");
        }
    }

    #[test]
    fn test_init_insert_reverse_contains() {
        let state_id_layout = Layout::new::<u32>();
        let (_overall, sparse_set_layout) =
            SparseSetLayout::with_num_states(5, Layout::new::<()>(), &state_id_layout).unwrap();

        let module_bytes = compile_test_module(&sparse_set_layout);
        let (_engine, _module, mut store, instance) = setup_interpreter(module_bytes);
        let (contains, insert) = get_sparse_set_fns(&instance, &store);
        let state_memory = instance.get_memory(&store, "state").unwrap();

        let insertion_order = [4, 1, 0, 2, 3];
        let set_ptr = 0;
        let mut set_len = 0;

        for state_id in insertion_order {
            let present = contains
                .call(&mut store, (set_ptr, set_len, state_id))
                .unwrap();
            assert_eq!(present, false as i32, "{state_id} should not be present");
        }

        for state_id in insertion_order {
            let grown = insert
                .call(&mut store, (set_len, state_id, set_ptr))
                .unwrap();
            assert_eq!(grown, set_len + 1);
            set_len = grown;
        }

        for state_id_check in insertion_order {
            let present = contains
                .call(&mut store, (set_ptr, set_len, state_id_check))
                .unwrap();
            assert_eq!(present, true as i32, "{state_id_check} should be present");
        }

        #[rustfmt::skip]
        let expected = le_bytes(&[
            4, 1, 0, 2, 3, // dense array: IDs in insertion order
            2, 1, 3, 4, 0, // sparse array: position of each ID in the dense array
        ]);
        let used = state_id_layout.size() * 5 * 2;
        assert_eq!(&state_memory.data(&store)[..used], expected.as_slice());
    }

    #[test]
    fn test_state_id_larger_than_one_byte() {
        let state_id_layout = Layout::new::<u32>();
        let (_overall, sparse_set_layout) =
            SparseSetLayout::with_num_states(512, Layout::new::<()>(), &state_id_layout).unwrap();

        let module_bytes = compile_test_module(&sparse_set_layout);
        let (_engine, _module, mut store, instance) = setup_interpreter(module_bytes);
        let (contains, insert) = get_sparse_set_fns(&instance, &store);
        let state_memory = instance.get_memory(&store, "state").unwrap();

        let set_ptr = 0;
        let set_len = 0;

        let present = contains.call(&mut store, (set_ptr, set_len, 511)).unwrap();
        assert_eq!(present, false as i32);

        let set_len = insert.call(&mut store, (set_len, 256, set_ptr)).unwrap();
        assert_eq!(set_len, 1);

        let set_len = insert.call(&mut store, (set_len, 511, set_ptr)).unwrap();
        assert_eq!(set_len, 2);

        for state_id in [511, 256] {
            let present = contains
                .call(&mut store, (set_ptr, set_len, state_id))
                .unwrap();
            assert_eq!(present, true as i32);
        }

        let id_size = state_id_layout.size();
        let sparse_base = sparse_set_layout.sparse_array_offset;
        let data = state_memory.data(&store);

        // Dense array: the two IDs in insertion order.
        assert_eq!(id_slot(data, 0, 0, id_size), le_bytes(&[256]).as_slice());
        assert_eq!(id_slot(data, 0, 1, id_size), le_bytes(&[511]).as_slice());

        // Sparse array: each ID maps to its position in the dense array.
        assert_eq!(
            id_slot(data, sparse_base, 256, id_size),
            le_bytes(&[0]).as_slice()
        );
        assert_eq!(
            id_slot(data, sparse_base, 511, id_size),
            le_bytes(&[1]).as_slice()
        );
    }
}