1use std::alloc::{Layout, LayoutError};
5
6use regex_automata::{nfa::thompson::NFA, Anchored};
7use wasm_encoder::{BlockType, NameMap, ValType};
8
9use super::context::{
10 BlockSignature, CompileContext, Function, FunctionDefinition, FunctionIdx, FunctionSignature,
11 TypeIdx,
12};
13
/// Search options handed to the generated code, each encoded as a plain `i32`.
///
/// NOTE(review): `#[repr(C)]` pins field order and layout — presumably because this struct is
/// shared across the host/Wasm boundary; confirm against the caller.
#[derive(Debug)]
#[repr(C)]
pub struct InputOpts {
    /// 1 if `Input::get_earliest()` was true, 0 otherwise.
    pub earliest: i32,
    /// Anchoring mode: 0 = `Anchored::No`, 1 = `Anchored::Yes`, 2 = `Anchored::Pattern`.
    pub anchored: i32,
    /// Bits of the pattern ID when `anchored == 2`; 0 in every other mode.
    pub anchored_pattern: i32,
}
34
35impl InputOpts {
36 pub fn new(input: ®ex_automata::Input<'_>) -> InputOpts {
41 let (anchored, anchored_pattern) = match input.get_anchored() {
42 Anchored::No => (0, 0),
43 Anchored::Yes => (1, 0),
44 Anchored::Pattern(id) => (2, i32::from_ne_bytes(id.to_ne_bytes())),
45 };
46
47 InputOpts {
48 earliest: input.get_earliest() as i32,
49 anchored,
50 anchored_pattern,
51 }
52 }
53}
54
/// Host-side description of where the input data is placed in linear memory.
#[derive(Debug)]
pub struct InputLayout {
    /// Byte offset at which the haystack bytes begin. The generated `prepare_input` and
    /// `assert_input_args_wf` functions add this offset to `haystack_len` when sizing and
    /// bounds-checking memory.
    pub haystack_start_pos: usize,
    /// Overall layout of the input region; currently an empty layout extended by a zero-length
    /// `u8` array, so the haystack starts at offset 0.
    _overall: Layout,
}
64
65impl InputLayout {
66 pub fn new(_ctx: &mut CompileContext) -> Result<Self, LayoutError> {
71 let overall = Layout::new::<()>();
72
73 let (overall, haystack_start_pos) = overall.extend(Layout::array::<u8>(0)?)?;
76
77 Ok(Self {
78 _overall: overall,
79 haystack_start_pos,
80 })
81 }
82}
83
/// Indices of the WebAssembly helper functions generated for input handling.
#[derive(Debug)]
pub struct InputFunctions {
    /// `prepare_input(haystack_len: i64) -> i32`. It is exported from the module and returns a
    /// `PrepareInputResult` code. It is never read on the Rust side, hence the `dead_code`
    /// expectation.
    #[expect(dead_code)]
    prepare_input: FunctionIdx,
    /// `assert_input_args_wf(...)`, which traps via `unreachable` when the search arguments
    /// are not well-formed.
    pub assert_input_args_wf: FunctionIdx,
    /// `utf8_is_boundary(haystack_ptr, haystack_len, at_offset) -> i32`. It is only generated
    /// when the NFA has empty matches and is in UTF-8 mode.
    pub utf8_is_boundary: Option<FunctionIdx>,
    /// `start_config(anchored, anchored_pattern)`, which selects the start state for a search.
    pub start_config: FunctionIdx,
}
96
97impl InputFunctions {
98 pub fn new(
105 ctx: &mut CompileContext,
106 input_layout: &InputLayout,
107 pattern_lookup_start: FunctionIdx,
108 ) -> Self {
109 let prepare_input = ctx.add_function(Self::prepare_input_fn(
110 ctx.config.get_page_size(),
111 input_layout,
112 ));
113 let assert_input_args_wf = ctx.add_function(Self::assert_input_args_wf_fn(
114 ctx.config.get_page_size(),
115 input_layout,
116 ));
117
118 let utf8_is_boundary = (ctx.nfa.has_empty() && ctx.nfa.is_utf8())
119 .then(|| ctx.add_function(Self::utf8_is_boundary_fn()));
120
121 let pattern_lookup_start_result_block_sig = ctx.add_block_signature(BlockSignature {
122 name: "pattern_lookup_start_result",
123 params_ty: &[ValType::I32],
124 results_ty: &[],
125 });
126
127 let start_config = ctx.add_function(Self::start_config_fn(
128 &ctx.nfa,
129 pattern_lookup_start,
130 pattern_lookup_start_result_block_sig,
131 ));
132
133 Self {
134 prepare_input,
135 assert_input_args_wf,
136 utf8_is_boundary,
137 start_config,
138 }
139 }
140
141 fn start_config_fn(
142 nfa: &NFA,
143 pattern_lookup_start: FunctionIdx,
144 pattern_lookup_start_result_block_sig: TypeIdx,
145 ) -> Function {
146 let mut locals_name_map = NameMap::new();
149 locals_name_map.append(0, "anchored");
151 locals_name_map.append(1, "anchored_pattern");
152 locals_name_map.append(2, "pattern_start");
154
155 let mut body = wasm_encoder::Function::new([(1, ValType::I32)]);
173 body.instructions()
174 .local_get(0)
176 .i32_const(0) .i32_eq()
178 .if_(BlockType::Empty)
179 .i32_const(i32::from_ne_bytes(nfa.start_anchored().to_ne_bytes()))
181 .i32_const(nfa.is_always_start_anchored() as i32)
182 .i32_const(true as i32)
183 .return_()
184 .end()
185 .local_get(0)
187 .i32_const(1) .i32_eq()
189 .if_(BlockType::Empty)
190 .i32_const(i32::from_ne_bytes(nfa.start_anchored().to_ne_bytes()))
192 .i32_const(true as i32)
193 .i32_const(true as i32)
194 .return_()
195 .end()
196 .local_get(0)
198 .i32_const(2) .i32_eq()
200 .if_(BlockType::Empty)
201 .local_get(1)
203 .call(pattern_lookup_start.into())
204 .if_(BlockType::FunctionType(
205 pattern_lookup_start_result_block_sig.into(),
206 ))
207 .i32_const(true as i32)
209 .i32_const(true as i32)
210 .return_()
211 .else_()
212 .drop()
213 .end()
214 .end()
215 .i32_const(0)
217 .i32_const(0)
218 .i32_const(false as i32)
219 .end();
220
221 Function {
222 sig: FunctionSignature {
223 name: "start_config".into(),
224 params_ty: &[ValType::I32, ValType::I32],
226 results_ty: &[ValType::I32, ValType::I32, ValType::I32],
228 export: false,
229 },
230 def: FunctionDefinition {
231 body,
232 locals_name_map,
233 labels_name_map: None,
234 branch_hints: None,
235 },
236 }
237 }
238
239 fn utf8_is_boundary_fn() -> Function {
240 let mut locals_name_map = NameMap::new();
253 locals_name_map.append(0, "haystack_ptr");
255 locals_name_map.append(1, "haystack_len");
256 locals_name_map.append(2, "at_offset");
257 locals_name_map.append(3, "byte");
259
260 let mut body = wasm_encoder::Function::new([(1, ValType::I32)]);
261 body.instructions()
262 .local_get(2)
264 .local_get(1)
265 .i64_ge_u()
266 .if_(BlockType::Empty)
267 .local_get(2)
269 .local_get(1)
270 .i64_eq() .return_()
272 .end()
273 .local_get(0)
275 .local_get(2)
276 .i64_add()
277 .i32_load8_u(wasm_encoder::MemArg {
278 offset: 0, align: 0, memory_index: 0, })
282 .local_set(3)
283 .local_get(3)
285 .i32_const(0b0111_1111)
286 .i32_le_u()
287 .local_get(3)
288 .i32_const(0b1100_0000)
289 .i32_ge_u()
290 .i32_or()
291 .end();
292
293 Function {
294 sig: FunctionSignature {
295 name: "utf8_is_boundary".into(),
296 params_ty: &[ValType::I64, ValType::I64, ValType::I64],
298 results_ty: &[ValType::I32],
300 export: false,
301 },
302 def: FunctionDefinition {
303 body,
304 locals_name_map,
305 labels_name_map: None,
306 branch_hints: None,
307 },
308 }
309 }
310
311 fn assert_input_args_wf_fn(page_size: usize, input_layout: &InputLayout) -> Function {
312 let mut locals_name_map = NameMap::new();
313 locals_name_map.append(0, "earliest");
315 locals_name_map.append(1, "anchored");
316 locals_name_map.append(2, "anchored_pattern");
317 locals_name_map.append(3, "span_start");
318 locals_name_map.append(4, "span_end");
319 locals_name_map.append(5, "haystack_len");
320
321 let mut labels_name_map = NameMap::new();
322 labels_name_map.append(0, "check_earliest_wf");
323 labels_name_map.append(1, "check_anchored_wf");
324 labels_name_map.append(2, "check_anchored_pattern_wf");
325 labels_name_map.append(3, "check_span_start_wf");
326 labels_name_map.append(4, "check_span_end_wf");
327
328 let mut body = wasm_encoder::Function::new([]);
329 body.instructions()
330 .local_get(0)
332 .i32_const(true as i32)
333 .i32_ne()
334 .local_get(0)
335 .i32_const(false as i32)
336 .i32_ne()
337 .i32_and()
338 .if_(BlockType::Empty)
339 .unreachable()
340 .end()
341 .local_get(1)
343 .i32_const(0)
344 .i32_ne()
345 .local_get(1)
346 .i32_const(1)
347 .i32_ne()
348 .local_get(1)
349 .i32_const(2)
350 .i32_ne()
351 .i32_and()
352 .i32_and()
353 .if_(BlockType::Empty)
354 .unreachable()
355 .end()
356 .local_get(1)
358 .i32_const(2)
359 .i32_ne()
360 .local_get(2)
361 .i32_const(0)
362 .i32_ne()
363 .i32_and()
364 .if_(BlockType::Empty)
365 .unreachable()
366 .end()
367 .local_get(3)
369 .local_get(4)
370 .i64_gt_u()
371 .if_(BlockType::Empty)
372 .unreachable()
373 .end()
374 .local_get(4)
376 .local_get(5)
377 .i64_gt_u()
378 .if_(BlockType::Empty)
379 .unreachable()
380 .end()
381 .i64_const(i64::from_ne_bytes(
383 u64::try_from(input_layout.haystack_start_pos)
384 .unwrap()
385 .to_ne_bytes(),
386 ))
387 .local_get(5)
388 .i64_add()
389 .memory_size(0)
390 .i64_const(i64::from_ne_bytes(
391 u64::try_from(page_size).unwrap().to_ne_bytes(),
392 ))
393 .i64_mul()
394 .i64_gt_u()
395 .if_(BlockType::Empty)
396 .unreachable()
397 .end()
398 .end();
399
400 Function {
401 sig: FunctionSignature {
402 name: "assert_input_args_wf".into(),
403 params_ty: &[
404 ValType::I32,
405 ValType::I32,
406 ValType::I32,
407 ValType::I64,
408 ValType::I64,
409 ValType::I64,
410 ],
411 results_ty: &[],
412 export: false,
413 },
414 def: FunctionDefinition {
415 body,
416 locals_name_map,
417 labels_name_map: Some(labels_name_map),
418 branch_hints: None,
419 },
420 }
421 }
422
423 fn prepare_input_fn(page_size: usize, input_layout: &InputLayout) -> Function {
424 let mut locals_name_map = NameMap::new();
425 locals_name_map.append(0, "haystack_len");
427 locals_name_map.append(1, "num_new_page_required");
429
430 let mut body = wasm_encoder::Function::new([(1, ValType::I64)]);
431 body.instructions()
432 .local_get(0) .i64_const(i64::from_ne_bytes(
435 u64::try_from(input_layout.haystack_start_pos)
436 .unwrap()
437 .to_ne_bytes(),
438 ))
439 .i64_add()
440 .i64_const(0)
441 .i64_eq()
442 .if_(BlockType::Empty)
443 .i32_const(PrepareInputResult::SuccessNoGrowth as i32)
445 .return_()
446 .end()
447 .local_get(0) .i64_const(i64::from_ne_bytes(
450 u64::try_from(input_layout.haystack_start_pos)
451 .unwrap()
452 .to_ne_bytes(),
453 ))
454 .i64_add()
455 .i64_const(1)
456 .i64_sub()
457 .i64_const(i64::from_ne_bytes(
458 u64::try_from(page_size).unwrap().to_ne_bytes(),
459 ))
460 .i64_div_u()
461 .i64_const(1)
462 .i64_add()
463 .memory_size(0)
464 .i64_sub()
465 .local_tee(1)
466 .i64_const(0)
467 .i64_gt_s()
471 .if_(BlockType::Result(ValType::I32))
472 .local_get(1)
473 .memory_grow(0)
474 .i64_const(-1)
475 .i64_eq()
476 .if_(BlockType::Empty)
477 .i32_const(PrepareInputResult::Failure as i32)
479 .return_()
480 .end()
481 .i32_const(PrepareInputResult::SuccessGrowth as i32)
482 .else_()
483 .i32_const(PrepareInputResult::SuccessNoGrowth as i32)
484 .end()
485 .end();
486
487 Function {
488 sig: FunctionSignature {
489 name: "prepare_input".into(),
490 params_ty: &[ValType::I64],
492 results_ty: &[ValType::I32],
494 export: true,
495 },
496 def: FunctionDefinition {
497 body,
498 locals_name_map,
499 labels_name_map: None,
500 branch_hints: None,
501 },
502 }
503 }
504}
505
/// Result codes returned (as `i32`) by the generated `prepare_input` function.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PrepareInputResult {
    /// Memory was already large enough (or the input region is empty); no pages were added.
    SuccessNoGrowth = 0,
    /// Memory was grown successfully to fit the input.
    SuccessGrowth = 1,
    /// Growing memory failed (`memory.grow` returned -1).
    Failure = 2,
}
519
#[cfg(test)]
mod tests {
    use regex_automata::nfa::thompson::NFA;

    use crate::compile::{
        pattern::{PatternFunctions, PatternLayout},
        tests::setup_interpreter,
    };

    use super::*;

    /// Exercises the generated `prepare_input` function. It must report no growth while the
    /// haystack fits in the current memory, grow memory when it does not, and never shrink it.
    #[test]
    fn prepare_input() {
        let mut ctx = CompileContext::new(
            NFA::always_match(),
            crate::Config::new()
                .export_all_functions(true)
                .export_state(true),
        );

        let (state_overall, pattern_layout) =
            PatternLayout::new(&mut ctx, Layout::new::<()>()).unwrap();
        let pattern_functions = PatternFunctions::new(&mut ctx, &pattern_layout);

        let input_layout = InputLayout::new(&mut ctx).unwrap();
        let _input_functions =
            InputFunctions::new(&mut ctx, &input_layout, pattern_functions.lookup_start);
        let page_size = ctx.config.get_page_size();
        let input_overhead = input_layout._overall.size();

        let module_bytes = ctx.compile(&state_overall).finish();
        let (_engine, _module, mut store, instance) = setup_interpreter(&module_bytes);
        let haystack_memory = instance.get_memory(&store, "haystack").unwrap();
        let prepare_input = instance
            .get_typed_func::<i64, i32>(&store, "prepare_input")
            .unwrap();

        // Memory starts out with a single page.
        assert_eq!(haystack_memory.size(&store), 1);

        // Runs `prepare_input` and reports its result code together with the page count
        // afterwards.
        let mut prepare = |haystack_len: i64| {
            let code = prepare_input.call(&mut store, haystack_len).unwrap();
            (code, haystack_memory.size(&store))
        };

        let no_growth = PrepareInputResult::SuccessNoGrowth as i32;
        let growth = PrepareInputResult::SuccessGrowth as i32;

        // Empty and tiny haystacks fit in the initial page.
        assert_eq!(prepare(0), (no_growth, 1));
        assert_eq!(prepare(1), (no_growth, 1));

        // A haystack that exactly fills the first page still fits.
        let one_full_page = i64::try_from(page_size - input_overhead).unwrap();
        assert_eq!(prepare(one_full_page), (no_growth, 1));

        // One more page worth of bytes forces exactly one page of growth.
        let two_full_pages = i64::try_from(page_size - input_overhead + page_size).unwrap();
        assert_eq!(prepare(two_full_pages), (growth, 2));

        // Memory is never shrunk: a small haystack after growth leaves both pages in place.
        let (code, pages) = prepare(1);
        assert_eq!(
            code,
            PrepareInputResult::SuccessNoGrowth as i32,
            "Should be SuccessNoGrowth when current pages > required pages"
        );
        assert_eq!(pages, 2, "Memory size should remain 2 pages");
    }
}