wahgex_core/compile/
context.rs

//! This module defines the `CompileContext` and associated structures
//! used for compiling a regular expression NFA into a WASM module.
3
4use std::{alloc::Layout, collections::BTreeMap};
5
6use regex_automata::nfa::thompson::NFA;
7use wasm_encoder::{
8    BranchHint, BranchHints, CodeSection, ConstExpr, DataSection, ExportKind, ExportSection,
9    FunctionSection, ImportSection, IndirectNameMap, MemorySection, MemoryType, Module, NameMap,
10    NameSection, TypeSection, ValType,
11};
12
13/// This struct contains all the input and intermediate state needed to compile
14/// the WASM module.
15#[derive(Debug)]
16#[non_exhaustive]
17pub struct CompileContext {
18    pub nfa: NFA,
19    pub config: crate::Config,
20    pub sections: Sections,
21    state_id_layout: Layout,
22}
23
24/// Contains the various sections of a WASM module being built.
25/// Declarations are added here, and definitions are stored for later assembly.
26#[derive(Debug, Default)]
27pub struct Sections {
28    types: TypeSection,
29    imports: ImportSection,
30    functions: FunctionSection,
31    memories: MemorySection,
32    exports: ExportSection,
33    data: DataSection,
34    function_names: NameMap,
35    memory_names: NameMap,
36    type_names: NameMap,
37    data_names: NameMap,
38
39    // Stores function definitions, keyed by FunctionIdx.0, to be assembled later.
40    function_definitions: BTreeMap<u32, FunctionDefinition>,
41}
42
43impl Sections {
44    /// Adds an active data segment to the data section.
45    /// These segments are copied into a linear memory at a specified offset
46    /// during instantiation. Currently, all active data segments are
47    /// hardcoded to target memory index 1 (state memory).
48    pub fn add_active_data_segment(&mut self, segment: ActiveDataSegment) {
49        let offset = ConstExpr::i64_const(
50            segment
51                .position
52                .try_into()
53                .expect("Data segment position too large for i64"),
54        );
55        let data_idx = self.data.len();
56        // TODO: Make the memory index configurable or determined dynamically if
57        // multiple memories are used beyond haystack (0) and state (1).
58        self.data.active(1, &offset, segment.data.iter().copied());
59        self.data_names.append(data_idx, &segment.name);
60    }
61}
62
63impl CompileContext {
64    /// Creates a new `CompileContext` with the given NFA and configuration.
65    pub fn new(nfa: NFA, config: crate::Config) -> Self {
66        let state_id_layout = Self::compute_state_id_layout(&nfa);
67
68        Self {
69            nfa,
70            config,
71            sections: Sections::default(),
72            state_id_layout,
73        }
74    }
75
76    /// Declare and define a function.
77    pub fn add_function(&mut self, func: Function) -> FunctionIdx {
78        let func_idx = self.declare_function(func.sig);
79        self.define_function(func_idx, func.def);
80        func_idx
81    }
82
83    /// Declares a function's signature (name, parameters, return types, export
84    /// status).
85    ///
86    /// This adds entries to the Type, Function, and potentially
87    /// Export sections. A `FunctionIdx` is returned, which should be used
88    /// later to provide the definition.
89    pub fn declare_function(&mut self, sig: FunctionSignature) -> FunctionIdx {
90        let func_ty_idx = self.sections.types.len();
91        self.sections.types.ty().function(
92            sig.params_ty.iter().copied(),
93            sig.results_ty.iter().copied(),
94        );
95        self.sections
96            .type_names
97            .append(func_ty_idx, &sig.type_name());
98
99        let func_idx_val = self.sections.functions.len();
100        self.sections.functions.function(func_ty_idx);
101        self.sections.function_names.append(func_idx_val, &sig.name);
102
103        #[cfg(test)]
104        let override_export = self.config.get_export_all_functions();
105        #[cfg(not(test))]
106        let override_export = false;
107
108        if sig.export || override_export {
109            self.sections
110                .exports
111                .export(&sig.name, ExportKind::Func, func_idx_val);
112        }
113        FunctionIdx(func_idx_val)
114    }
115
116    /// Defines a previously declared function.
117    ///
118    /// The `func_idx` must correspond to a function previously returned by
119    /// `declare_function`. The definition includes the body, local names,
120    /// label names, and branch hints.
121    pub fn define_function(&mut self, func_idx: FunctionIdx, def: FunctionDefinition) {
122        if func_idx.0 >= self.sections.functions.len() {
123            panic!(
124                "Defining function with index {} which has not been declared (max declared index: \
125                 {})",
126                func_idx.0,
127                if !self.sections.functions.is_empty() {
128                    self.sections.functions.len() - 1
129                } else {
130                    0
131                }
132            );
133        }
134        if self
135            .sections
136            .function_definitions
137            .insert(func_idx.0, def)
138            .is_some()
139        {
140            panic!("Warning: Redefining function at index {}", func_idx.0);
141        }
142    }
143
144    /// Adds a block signature to the type section.
145    ///
146    /// This is used for block types in control flow instructions.
147    pub fn add_block_signature(&mut self, signature: BlockSignature) -> TypeIdx {
148        let block_ty_idx = self.sections.types.len();
149        self.sections.types.ty().function(
150            signature.params_ty.iter().copied(),
151            signature.results_ty.iter().copied(),
152        );
153        self.sections
154            .type_names
155            .append(block_ty_idx, &signature.type_name());
156        TypeIdx(block_ty_idx)
157    }
158
159    /// Return the minimal layout for a state ID for the current NFA.
160    ///
161    /// This function will minimize the size of the state ID layout based on the
162    /// number of states in the current NFA.
163    pub fn state_id_layout(&self) -> &Layout {
164        &self.state_id_layout
165    }
166
167    fn compute_state_id_layout(nfa: &NFA) -> Layout {
168        let num_states = nfa.states().len();
169
170        if num_states <= u8::MAX as usize {
171            Layout::from_size_align(1, 1).unwrap()
172        } else if num_states <= u16::MAX as usize {
173            Layout::from_size_align(2, 2).unwrap()
174        } else {
175            Layout::from_size_align(4, 4).unwrap()
176        }
177    }
178}
179
180impl CompileContext {
181    /// This function takes all the individual settings/functions/data
182    /// segments/layouts and compiles them into a single WASM [`Module`].
183    pub fn compile(mut self, state_overall: &Layout) -> Module {
184        let mut module = Module::new();
185
186        // Section order
187        //  Types
188        //  Imports
189        //  Functions
190        //  Tables
191        //  Memories
192        //  Globals
193        //  Exports
194        //  Start
195        //  Elements
196        //  Data Count
197        //  Code
198        //  Data
199
200        module.section(&self.sections.types);
201
202        module.section(&self.sections.imports);
203
204        module.section(&self.sections.functions);
205
206        // Determine minimum (and maximum?) size based on data structure layout
207        let haystack_mem_idx = self.sections.memories.len();
208        self.sections.memories.memory(MemoryType {
209            minimum: 1,
210            maximum: None,
211            // TODO: Make state memory64 default false by config
212            memory64: true,
213            shared: false,
214            // TODO: Use custom page size
215            page_size_log2: None,
216        });
217        let state_mem_idx = self.sections.memories.len();
218        let state_mem_size =
219            1 + u64::try_from((state_overall.size() - 1) / self.config.get_page_size()).unwrap();
220        self.sections.memories.memory(MemoryType {
221            minimum: state_mem_size,
222            maximum: Some(state_mem_size),
223            // TODO: Make state memory64 default false by config
224            memory64: true,
225            shared: false,
226            // TODO: Use custom page size
227            page_size_log2: None,
228        });
229        module.section(&self.sections.memories);
230
231        self.sections
232            .exports
233            .export("haystack", ExportKind::Memory, haystack_mem_idx);
234
235        #[cfg(test)]
236        let export_state = self.config.get_export_state();
237        #[cfg(not(test))]
238        let export_state = false;
239        if export_state {
240            self.sections
241                .exports
242                .export("state", ExportKind::Memory, state_mem_idx);
243        }
244        module.section(&self.sections.exports);
245
246        // Build CodeSection, BranchHints, and name maps for locals/labels from
247        // definitions
248        let mut codes = CodeSection::new();
249        let mut hint_section = BranchHints::new();
250        let mut local_names = IndirectNameMap::new();
251        let mut label_names = IndirectNameMap::new();
252
253        let num_declared_functions = self.sections.functions.len();
254
255        // Ensure all declared functions have corresponding definitions.
256        // BTreeMap iteration is ordered by key, which is FunctionIdx.0.
257        // We iterate 0..num_declared_functions to ensure correct order and that all are
258        // present.
259        for func_idx_val in 0..num_declared_functions {
260            match self.sections.function_definitions.get(&func_idx_val) {
261                Some(def) => {
262                    codes.function(&def.body);
263                    local_names.append(func_idx_val, &def.locals_name_map);
264                    if let Some(labels) = &def.labels_name_map {
265                        label_names.append(func_idx_val, labels);
266                    }
267                    if let Some(hints) = &def.branch_hints {
268                        hint_section.function_hints(func_idx_val, hints.iter().copied());
269                    }
270                },
271                None => {
272                    panic!("Function at index {func_idx_val} was declared but not defined.");
273                },
274            }
275        }
276
277        module.section(&hint_section);
278
279        module.section(&codes);
280
281        module.section(&self.sections.data);
282
283        let mut name_section = NameSection::new();
284        {
285            name_section.functions(&self.sections.function_names);
286
287            {
288                self.sections
289                    .memory_names
290                    .append(haystack_mem_idx, "haystack");
291                self.sections.memory_names.append(state_mem_idx, "state"); // Assuming state_mem_idx is valid
292            }
293            name_section.memories(&self.sections.memory_names);
294
295            name_section.locals(&local_names);
296
297            name_section.labels(&label_names);
298
299            name_section.types(&self.sections.type_names);
300
301            name_section.data(&self.sections.data_names);
302        }
303        module.section(&name_section);
304
305        module
306    }
307}
308
/// An active data segment to be emitted into the WASM module.
///
/// During instantiation, `data` is copied into the state memory starting at
/// byte offset `position`.
#[derive(Debug)]
pub struct ActiveDataSegment {
    /// Debug name recorded for the segment in the name section.
    pub name: String,
    /// Byte offset in the target memory where the segment begins.
    pub position: usize,
    /// Bytes copied into memory.
    pub data: Vec<u8>,
}
316
317/// Describes the signature of a function: its name, parameters, results, and
318/// export status.
319#[derive(Debug)]
320pub struct FunctionSignature {
321    pub name: String,
322    pub params_ty: &'static [ValType],
323    pub results_ty: &'static [ValType],
324    pub export: bool,
325}
326
327impl FunctionSignature {
328    /// Generates a unique name for this function's type signature.
329    fn type_name(&self) -> String {
330        format!("{}_fn", self.name)
331    }
332}
333
334/// Contains the definition of a function: its body, local names, label names,
335/// and branch hints.
336///
337/// This is associated with a `FunctionIdx` obtained from
338/// [`CompileContext::declare_function`].
339#[derive(Debug)]
340pub struct FunctionDefinition {
341    pub body: wasm_encoder::Function,
342    pub locals_name_map: NameMap,
343    pub labels_name_map: Option<NameMap>,
344    pub branch_hints: Option<Vec<BranchHint>>,
345}
346
347/// Contains the full definition of a function: signature and definition.
348#[derive(Debug)]
349pub struct Function {
350    pub sig: FunctionSignature,
351    pub def: FunctionDefinition,
352}
353
354/// Describes the signature of a block type (e.g., for `if`, `loop`, `block`).
355/// It includes a descriptive name, parameter types, and result types.
356#[derive(Debug)]
357pub struct BlockSignature {
358    pub name: &'static str,
359    pub params_ty: &'static [ValType],
360    pub results_ty: &'static [ValType],
361}
362
363impl BlockSignature {
364    /// Generates a unique name for this block's type signature.
365    fn type_name(&self) -> String {
366        format!("{}_block_sig", self.name)
367    }
368}
369
/// Handle to a function declared via [`CompileContext::declare_function`].
///
/// The wrapped value is the function's position in the module's function
/// section, in declaration order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FunctionIdx(u32);

impl From<FunctionIdx> for u32 {
    /// Unwraps the raw function index.
    fn from(FunctionIdx(raw): FunctionIdx) -> Self {
        raw
    }
}
379
/// Handle to an entry in the module's type section, whether it describes a
/// function or a block signature.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TypeIdx(u32);

impl From<TypeIdx> for u32 {
    /// Unwraps the raw type index.
    fn from(TypeIdx(raw): TypeIdx) -> Self {
        raw
    }
}