// blart/raw/representation/inner_node_256.rs

1use crate::raw::{
2    Header, InnerNode, InnerNode48, Node, NodeType, OpaqueNodePtr, RestrictedNodeIndex,
3};
4use std::{
5    fmt,
6    iter::{Enumerate, FusedIterator},
7    mem::MaybeUninit,
8    ops::Bound,
9    slice::Iter,
10};
11
12#[cfg(feature = "nightly")]
13use std::{
14    iter::FilterMap,
15    simd::{cmp::SimdPartialEq, usizex64},
16};
17
/// Node that references between 49 and 256 children
///
/// Children are stored in a 256-entry array that is indexed directly by the
/// key byte, so a lookup is a single array access: slot `b` holds the child
/// registered under key byte `b`, or `None` when no such child exists.
#[repr(C, align(8))]
pub struct InnerNode256<K, V, const PREFIX_LEN: usize> {
    /// The common node fields.
    pub header: Header<PREFIX_LEN>,
    /// An array that directly maps a key byte (as index) to a child node.
    pub child_pointers: [Option<OpaqueNodePtr<K, V, PREFIX_LEN>>; 256],
}
26
27impl<K, V, const PREFIX_LEN: usize> fmt::Debug for InnerNode256<K, V, PREFIX_LEN> {
28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29        f.debug_struct("InnerNode256")
30            .field("header", &self.header)
31            .field("child_pointers", &self.child_pointers)
32            .finish()
33    }
34}
35
36impl<K, V, const PREFIX_LEN: usize> Clone for InnerNode256<K, V, PREFIX_LEN> {
37    fn clone(&self) -> Self {
38        Self {
39            header: self.header.clone(),
40            child_pointers: self.child_pointers,
41        }
42    }
43}
44
45impl<K, V, const PREFIX_LEN: usize> Node<PREFIX_LEN> for InnerNode256<K, V, PREFIX_LEN> {
46    type Key = K;
47    type Value = V;
48
49    const TYPE: NodeType = NodeType::Node256;
50}
51
52impl<K, V, const PREFIX_LEN: usize> InnerNode<PREFIX_LEN> for InnerNode256<K, V, PREFIX_LEN> {
53    type GrownNode = Self;
54    #[cfg(not(feature = "nightly"))]
55    type Iter<'a>
56        = Node256Iter<'a, K, V, PREFIX_LEN>
57    where
58        Self: 'a;
59    #[cfg(feature = "nightly")]
60    type Iter<'a>
61        = FilterMap<
62        Enumerate<Iter<'a, Option<OpaqueNodePtr<K, V, PREFIX_LEN>>>>,
63        impl FnMut(
64            (usize, &'a Option<OpaqueNodePtr<K, V, PREFIX_LEN>>),
65        ) -> Option<(u8, OpaqueNodePtr<K, V, PREFIX_LEN>)>,
66    >
67    where
68        Self: 'a;
69    type ShrunkNode = InnerNode48<K, V, PREFIX_LEN>;
70
71    fn header(&self) -> &Header<PREFIX_LEN> {
72        &self.header
73    }
74
75    fn from_header(header: Header<PREFIX_LEN>) -> Self {
76        InnerNode256 {
77            header,
78            child_pointers: [None; 256],
79        }
80    }
81
82    fn lookup_child(&self, key_fragment: u8) -> Option<OpaqueNodePtr<K, V, PREFIX_LEN>> {
83        self.child_pointers[usize::from(key_fragment)]
84    }
85
86    fn write_child(&mut self, key_fragment: u8, child_pointer: OpaqueNodePtr<K, V, PREFIX_LEN>) {
87        let key_fragment_idx = usize::from(key_fragment);
88        let existing_pointer = self.child_pointers[key_fragment_idx];
89        self.child_pointers[key_fragment_idx] = Some(child_pointer);
90        if existing_pointer.is_none() {
91            self.header.inc_num_children();
92        }
93    }
94
95    fn remove_child(&mut self, key_fragment: u8) -> Option<OpaqueNodePtr<K, V, PREFIX_LEN>> {
96        let removed_child = self.child_pointers[usize::from(key_fragment)].take();
97
98        if removed_child.is_some() {
99            self.header.dec_num_children();
100        }
101
102        removed_child
103    }
104
105    fn grow(&self) -> Self::GrownNode {
106        panic!("unable to grow a Node256, something went wrong!")
107    }
108
109    fn shrink(&self) -> Self::ShrunkNode {
110        assert!(
111            self.header.num_children() <= 48,
112            "Cannot shrink a Node256 when it has more than 48 children. Currently has [{}] \
113             children.",
114            self.header.num_children()
115        );
116
117        let header = self.header.clone();
118        let mut child_indices = [RestrictedNodeIndex::<48>::EMPTY; 256];
119        let mut child_pointers = [MaybeUninit::uninit(); 48];
120
121        for (child_index, (key_byte, child_ptr)) in self.iter().enumerate() {
122            // PANIC SAFETY: This `try_from` will not panic because the `next_index` value
123            // is guaranteed to be 48 or less by the `assert!(num_children < 48)` at the
124            // start of the function.
125            let key_byte = usize::from(key_byte);
126            child_indices[key_byte] = RestrictedNodeIndex::try_from(child_index).unwrap();
127            child_pointers[child_index].write(child_ptr);
128        }
129
130        InnerNode48 {
131            header,
132            child_indices,
133            child_pointers,
134        }
135    }
136
137    fn iter(&self) -> Self::Iter<'_> {
138        #[cfg(not(feature = "nightly"))]
139        {
140            Node256Iter {
141                it: self.child_pointers.iter().enumerate(),
142            }
143        }
144
145        #[cfg(feature = "nightly")]
146        {
147            self.child_pointers
148                .iter()
149                .enumerate()
150                .filter_map(|(key, node)| node.map(|node| (key as u8, node)))
151        }
152    }
153
154    fn range(
155        &self,
156        bound: impl std::ops::RangeBounds<u8>,
157    ) -> impl DoubleEndedIterator<Item = (u8, OpaqueNodePtr<Self::Key, Self::Value, PREFIX_LEN>)>
158           + FusedIterator {
159        {
160            match (bound.start_bound(), bound.end_bound()) {
161                (Bound::Excluded(s), Bound::Excluded(e)) if s == e => {
162                    panic!("range start and end are equal and excluded: ({s:?})")
163                },
164                (
165                    Bound::Included(s) | Bound::Excluded(s),
166                    Bound::Included(e) | Bound::Excluded(e),
167                ) if s > e => {
168                    panic!("range start ({s:?}) is greater than range end ({e:?})")
169                },
170                _ => {},
171            }
172        }
173
174        let start = bound.start_bound().map(|val| usize::from(*val));
175        let key_offset = match bound.start_bound() {
176            std::ops::Bound::Included(val) => *val,
177            std::ops::Bound::Excluded(val) => val.saturating_add(1),
178            std::ops::Bound::Unbounded => 0,
179        };
180        let end = bound.end_bound().map(|val| usize::from(*val));
181
182        self.child_pointers[(start, end)]
183            .iter()
184            .enumerate()
185            .filter_map(move |(key, child)| {
186                child.map(|child| ((key as u8).saturating_add(key_offset), child))
187            })
188    }
189
190    #[cfg(feature = "nightly")]
191    #[cfg_attr(test, mutants::skip)]
192    fn min(&self) -> (u8, OpaqueNodePtr<K, V, PREFIX_LEN>) {
193        use crate::rust_nightly_apis::assume;
194
195        // SAFETY: Due to niche optimization Option<NonNull> has the same
196        // size as NonNull and NonNull has the same size as usize
197        // so it's safe to transmute
198        let child_pointers: &[usize; 256] = unsafe { std::mem::transmute(&self.child_pointers) };
199        let empty = usizex64::splat(0);
200        let r0 = usizex64::from_array(child_pointers[0..64].try_into().unwrap())
201            .simd_eq(empty)
202            .to_bitmask();
203        let r1 = usizex64::from_array(child_pointers[64..128].try_into().unwrap())
204            .simd_eq(empty)
205            .to_bitmask();
206        let r2 = usizex64::from_array(child_pointers[128..192].try_into().unwrap())
207            .simd_eq(empty)
208            .to_bitmask();
209        let r3 = usizex64::from_array(child_pointers[192..256].try_into().unwrap())
210            .simd_eq(empty)
211            .to_bitmask();
212
213        let key = if r0 != u64::MAX {
214            r0.trailing_ones()
215        } else if r1 != u64::MAX {
216            r1.trailing_ones() + 64
217        } else if r2 != u64::MAX {
218            r2.trailing_ones() + 128
219        } else {
220            r3.trailing_ones() + 192
221        } as usize;
222
223        unsafe {
224            // SAFETY: key can be at up to 256, but we know that we have
225            // at least one inner child, it's guarantee to be in bounds
226            assume!(key < self.child_pointers.len());
227        }
228
229        // SAFETY: Covered by the containing function
230        (key as u8, unsafe {
231            self.child_pointers[key].unwrap_unchecked()
232        })
233    }
234
235    #[cfg(not(feature = "nightly"))]
236    fn min(&self) -> (u8, OpaqueNodePtr<K, V, PREFIX_LEN>) {
237        for (key, child_pointer) in self.child_pointers.iter().enumerate() {
238            match child_pointer {
239                Some(child_pointer) => return (key as u8, *child_pointer),
240                None => continue,
241            }
242        }
243        unreachable!("inner node must have non-zero number of children");
244    }
245
246    #[cfg(feature = "nightly")]
247    #[cfg_attr(test, mutants::skip)]
248    fn max(&self) -> (u8, OpaqueNodePtr<K, V, PREFIX_LEN>) {
249        use crate::rust_nightly_apis::assume;
250
251        // SAFETY: Due to niche optimization Option<NonNull> has the same
252        // size as NonNull and NonNull has the same size as usize
253        // so it's safe to transmute
254        let child_pointers: &[usize; 256] = unsafe { std::mem::transmute(&self.child_pointers) };
255        let empty = usizex64::splat(0);
256        let r0 = usizex64::from_array(child_pointers[0..64].try_into().unwrap())
257            .simd_eq(empty)
258            .to_bitmask();
259        let r1 = usizex64::from_array(child_pointers[64..128].try_into().unwrap())
260            .simd_eq(empty)
261            .to_bitmask();
262        let r2 = usizex64::from_array(child_pointers[128..192].try_into().unwrap())
263            .simd_eq(empty)
264            .to_bitmask();
265        let r3 = usizex64::from_array(child_pointers[192..256].try_into().unwrap())
266            .simd_eq(empty)
267            .to_bitmask();
268
269        let key = if r3 != u64::MAX {
270            255 - r3.leading_ones()
271        } else if r2 != u64::MAX {
272            191 - r2.leading_ones()
273        } else if r1 != u64::MAX {
274            127 - r1.leading_ones()
275        } else {
276            // SAFETY: This subtraction can't fail, because we know that
277            // we have at least one child, so the number of leading ones
278            // in this last case is <= 63
279            63 - r0.leading_ones()
280        } as usize;
281
282        unsafe {
283            // SAFETY: idx can be at up to 255, so it's in bounds
284            assume!(key < self.child_pointers.len());
285        }
286
287        // SAFETY: covered by the containing function
288        (key as u8, unsafe {
289            self.child_pointers[key].unwrap_unchecked()
290        })
291    }
292
293    #[cfg(not(feature = "nightly"))]
294    fn max(&self) -> (u8, OpaqueNodePtr<K, V, PREFIX_LEN>) {
295        for (key, child_pointer) in self.child_pointers.iter().enumerate().rev() {
296            match child_pointer {
297                Some(child_pointer) => return (key as u8, *child_pointer),
298                None => continue,
299            }
300        }
301        unreachable!("inner node must have non-zero number of children");
302    }
303}
304
/// This struct is an iterator over the children of a [`InnerNode256`].
///
/// It walks the node's child pointer array slot by slot and yields a
/// `(key byte, child pointer)` pair for every occupied slot, skipping the
/// empty ones.
#[cfg(not(feature = "nightly"))]
pub struct Node256Iter<'a, K, V, const PREFIX_LEN: usize> {
    /// Enumerated view of the child slots; the enumeration index is the key
    /// byte of the slot.
    pub(crate) it: Enumerate<Iter<'a, Option<OpaqueNodePtr<K, V, PREFIX_LEN>>>>,
}
310
311#[cfg(not(feature = "nightly"))]
312impl<K, V, const PREFIX_LEN: usize> Iterator for Node256Iter<'_, K, V, PREFIX_LEN> {
313    type Item = (u8, OpaqueNodePtr<K, V, PREFIX_LEN>);
314
315    fn next(&mut self) -> Option<Self::Item> {
316        for (key, node) in self.it.by_ref() {
317            match node {
318                Some(node) => return Some((key as u8, *node)),
319                None => continue,
320            }
321        }
322        None
323    }
324}
325
326#[cfg(not(feature = "nightly"))]
327impl<K, V, const PREFIX_LEN: usize> DoubleEndedIterator for Node256Iter<'_, K, V, PREFIX_LEN> {
328    fn next_back(&mut self) -> Option<Self::Item> {
329        while let Some((key, node)) = self.it.next_back() {
330            match node {
331                Some(node) => return Some((key as u8, *node)),
332                None => continue,
333            }
334        }
335        None
336    }
337}
338
// Marker impl: `next` only ever forwards to the wrapped enumerated slice
// iterator, so once that is exhausted this iterator keeps returning `None`.
#[cfg(not(feature = "nightly"))]
impl<K, V, const PREFIX_LEN: usize> FusedIterator for Node256Iter<'_, K, V, PREFIX_LEN> {}
341
#[cfg(test)]
mod tests {
    use std::ops::{Bound, RangeBounds};

    use crate::raw::{
        representation::tests::{
            inner_node_min_max_test, inner_node_remove_child_test, inner_node_shrink_test,
            inner_node_write_child_test, FixtureReturn,
        },
        LeafNode, NodePtr,
    };

    use super::*;

    /// Concrete node type used by the hand-written tests in this module.
    type TestNode = InnerNode256<Box<[u8]>, (), 16>;

    /// Builds a node with four fresh leaves, the `i`-th leaf being written
    /// under `keys[i]`. Returns the node, the leaves and their pointers (in
    /// leaf order).
    fn fixture_with_keys(keys: [u8; 4]) -> FixtureReturn<TestNode, 4> {
        let mut node = TestNode::empty();
        let mut l1 = LeafNode::with_no_siblings(vec![].into(), ());
        let mut l2 = LeafNode::with_no_siblings(vec![].into(), ());
        let mut l3 = LeafNode::with_no_siblings(vec![].into(), ());
        let mut l4 = LeafNode::with_no_siblings(vec![].into(), ());
        let leaf_ptrs = [
            NodePtr::from(&mut l1).to_opaque(),
            NodePtr::from(&mut l2).to_opaque(),
            NodePtr::from(&mut l3).to_opaque(),
            NodePtr::from(&mut l4).to_opaque(),
        ];

        for (key, leaf_ptr) in keys.into_iter().zip(leaf_ptrs) {
            node.write_child(key, leaf_ptr);
        }

        (node, [l1, l2, l3, l4], leaf_ptrs)
    }

    /// Node with children under key bytes 3, 255, 0 and 85, so both ends of
    /// the key space are occupied.
    fn fixture() -> FixtureReturn<TestNode, 4> {
        fixture_with_keys([3, 255, 0, 85])
    }

    /// Node with children under key bytes 3, 254, 2 and 85, so both ends of
    /// the key space are empty.
    fn fixture_empty_edges() -> FixtureReturn<TestNode, 4> {
        fixture_with_keys([3, 254, 2, 85])
    }

    /// Asserts that `node.range(bound)` yields exactly `expected_pairs`, in
    /// order.
    #[track_caller]
    fn check_range<K, V, const PREFIX_LEN: usize, const N: usize>(
        node: &InnerNode256<K, V, PREFIX_LEN>,
        bound: impl RangeBounds<u8>,
        expected_pairs: [(u8, OpaqueNodePtr<K, V, PREFIX_LEN>); N],
    ) {
        let pairs = node.range(bound).collect::<Vec<_>>();
        assert_eq!(pairs, expected_pairs);
    }

    #[test]
    fn lookup() {
        let mut node = TestNode::empty();
        let mut leaf_a = LeafNode::with_no_siblings(Box::from([]), ());
        let mut leaf_b = LeafNode::with_no_siblings(Box::from([]), ());
        let mut leaf_c = LeafNode::with_no_siblings(Box::from([]), ());
        let ptr_a = NodePtr::from(&mut leaf_a).to_opaque();
        let ptr_b = NodePtr::from(&mut leaf_b).to_opaque();
        let ptr_c = NodePtr::from(&mut leaf_c).to_opaque();

        assert!(node.lookup_child(123).is_none());

        // Populate the node by hand, bypassing `write_child`.
        for _ in 0..3 {
            node.header.inc_num_children();
        }
        for (slot, ptr) in [(1, ptr_a), (123, ptr_b), (3, ptr_c)] {
            node.child_pointers[slot] = Some(ptr);
        }

        assert_eq!(node.lookup_child(123), Some(ptr_b));
    }

    #[test]
    fn write_child() {
        inner_node_write_child_test(InnerNode256::<_, _, 16>::empty(), 256)
    }

    #[test]
    fn remove_child() {
        inner_node_remove_child_test(InnerNode256::<_, _, 16>::empty(), 256)
    }

    #[test]
    #[should_panic = "unable to grow a Node256, something went wrong!"]
    fn grow() {
        let node = TestNode::empty();

        node.grow();
    }

    #[test]
    fn shrink() {
        inner_node_shrink_test(InnerNode256::<_, _, 16>::empty(), 48);
    }

    #[test]
    #[should_panic = "Cannot shrink a Node256 when it has more than 48 children. Currently has \
                      [49] children."]
    fn shrink_too_many_children_panic() {
        inner_node_shrink_test(InnerNode256::<_, _, 16>::empty(), 49);
    }

    #[test]
    fn min_max() {
        inner_node_min_max_test(InnerNode256::<_, _, 16>::empty(), 256);
    }

    #[test]
    fn iterate() {
        let (node, _, [l1_ptr, l2_ptr, l3_ptr, l4_ptr]) = fixture();

        let visited = node.iter().collect::<Vec<_>>();

        assert_eq!(
            visited,
            [(0u8, l3_ptr), (3, l1_ptr), (85, l4_ptr), (255, l2_ptr)]
        );
    }

    #[test]
    fn iterate_rev() {
        let (node, _, [l1_ptr, l2_ptr, l3_ptr, l4_ptr]) = fixture();

        let visited = node.iter().rev().collect::<Vec<_>>();

        assert_eq!(
            visited,
            [(255u8, l2_ptr), (85, l4_ptr), (3, l1_ptr), (0, l3_ptr)]
        );
    }

    #[test]
    fn range_iterate() {
        let (node, _, [l1_ptr, l2_ptr, l3_ptr, l4_ptr]) = fixture();
        let every_child = [(0u8, l3_ptr), (3, l1_ptr), (85, l4_ptr), (255, l2_ptr)];

        check_range(
            &node,
            (Bound::Included(0), Bound::Included(3)),
            [(0u8, l3_ptr), (3, l1_ptr)],
        );
        check_range(&node, (Bound::Excluded(0), Bound::Excluded(3)), []);
        check_range(
            &node,
            (Bound::Included(0), Bound::Included(0)),
            [(0u8, l3_ptr)],
        );
        check_range(
            &node,
            (Bound::Included(0), Bound::Included(255)),
            every_child,
        );
        check_range(
            &node,
            (Bound::Included(255), Bound::Included(255)),
            [(255, l2_ptr)],
        );
        check_range(&node, (Bound::Included(255), Bound::Excluded(255)), []);
        check_range(&node, (Bound::Excluded(255), Bound::Included(255)), []);
        check_range(
            &node,
            (Bound::Excluded(0), Bound::Excluded(255)),
            [(3, l1_ptr), (85, l4_ptr)],
        );
        check_range(
            &node,
            (Bound::<u8>::Unbounded, Bound::Unbounded),
            every_child,
        );
        check_range(
            &node,
            (Bound::<u8>::Unbounded, Bound::Included(86)),
            [(0u8, l3_ptr), (3, l1_ptr), (85, l4_ptr)],
        );
    }

    #[test]
    fn range_iterate_boundary_conditions() {
        let (node, _, [l1_ptr, l2_ptr, l3_ptr, l4_ptr]) = fixture_empty_edges();

        // Open start, end bound moving down towards (and below) the first child.
        check_range(
            &node,
            (Bound::<u8>::Unbounded, Bound::Included(86)),
            [(2u8, l3_ptr), (3, l1_ptr), (85, l4_ptr)],
        );
        check_range(
            &node,
            (Bound::<u8>::Unbounded, Bound::Included(4)),
            [(2u8, l3_ptr), (3, l1_ptr)],
        );
        check_range(
            &node,
            (Bound::<u8>::Unbounded, Bound::Excluded(3)),
            [(2u8, l3_ptr)],
        );
        check_range(
            &node,
            (Bound::<u8>::Unbounded, Bound::Included(2)),
            [(2u8, l3_ptr)],
        );
        check_range(&node, (Bound::<u8>::Unbounded, Bound::Included(1)), []);
        check_range(&node, (Bound::<u8>::Unbounded, Bound::Included(0)), []);

        // Open end, start bound moving up towards (and past) the last child.
        check_range(
            &node,
            (Bound::Included(1), Bound::<u8>::Unbounded),
            [(2u8, l3_ptr), (3, l1_ptr), (85, l4_ptr), (254, l2_ptr)],
        );
        check_range(
            &node,
            (Bound::Included(3), Bound::<u8>::Unbounded),
            [(3, l1_ptr), (85, l4_ptr), (254, l2_ptr)],
        );
        check_range(
            &node,
            (Bound::Excluded(84), Bound::<u8>::Unbounded),
            [(85, l4_ptr), (254, l2_ptr)],
        );
        check_range(
            &node,
            (Bound::Included(253), Bound::<u8>::Unbounded),
            [(254, l2_ptr)],
        );
        check_range(&node, (Bound::Included(255), Bound::<u8>::Unbounded), []);
    }

    #[test]
    #[should_panic = "range start and end are equal and excluded: (80)"]
    fn range_iterate_out_of_bounds_panic_both_excluded() {
        let (node, _, _) = fixture();

        let pairs = node
            .range((Bound::Excluded(80), Bound::Excluded(80)))
            .collect::<Vec<_>>();
        assert_eq!(pairs, &[]);
    }

    #[test]
    #[should_panic = "range start (80) is greater than range end (0)"]
    fn range_iterate_start_greater_than_end() {
        let (node, _, _) = fixture();

        let _pairs = node
            .range((Bound::Excluded(80), Bound::Included(0)))
            .collect::<Vec<_>>();
    }
}