summaryrefslogtreecommitdiff
path: root/crates/alloc_buddy/src/tree.rs
blob: 787af21f5c520361e88d35b3845cd4566749fc66 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
use crate::bitset::{Bitset, SubregionStatus, UnexpectedBitsetState};
use contracts::requires;
use core::{fmt, ptr::NonNull};

/// A single region of the allocator. See the comment on the `crate::allocators::buddy` module for
/// more information.
///
/// This type is valid when zero-initialized.
#[derive(Debug)]
pub struct Tree<const PAGE_SIZE: usize, const PAGE_SIZE_BITS: usize> {
    /// The base address of the tree, i.e. the address of its first page.
    ///
    /// `None` until the tree has been assigned a region; methods that need the address require
    /// this to be `Some`.
    pub base_ptr: Option<NonNull<[u8; PAGE_SIZE]>>,

    /// The log2 of the number of pages in the region represented by the tree. The tree covers
    /// `PAGE_SIZE << size_class` bytes starting at `base_ptr`.
    pub size_class: usize,

    /// The offset in the bitset to the bits responsible for this tree's pages.
    pub bitset_offset: usize,
}

impl<const PAGE_SIZE: usize, const PAGE_SIZE_BITS: usize> Tree<PAGE_SIZE, PAGE_SIZE_BITS> {
    /// Returns the base address of the tree.
    #[requires(self.base_ptr.is_some())]
    pub fn base_addr(&self) -> usize {
        self.base_ptr.unwrap().as_ptr() as usize
    }

    /// Reads a bit from the bitset.
    #[requires(size_class <= self.size_class)]
    #[requires(offset_bytes < (PAGE_SIZE << (self.size_class + 1)))]
    #[requires(offset_bytes.trailing_zeros() as usize >= PAGE_SIZE_BITS + size_class)]
    pub fn bitset_get(
        &self,
        bitset: &Bitset,
        size_class: usize,
        offset_bytes: usize,
    ) -> SubregionStatus {
        bitset.get(self.bitset_index(size_class, offset_bytes))
    }

    /// Returns the index of a bit in the bitset that corresponds to the given size class and
    /// offset.
    #[requires(size_class <= self.size_class)]
    #[requires(offset_bytes < (PAGE_SIZE << (self.size_class + 1)))]
    #[requires(offset_bytes.trailing_zeros() as usize >= PAGE_SIZE_BITS + size_class)]
    fn bitset_index(&self, size_class: usize, offset_bytes: usize) -> usize {
        // We store the largest size classes first in the bitset. Count how many we are away from
        // the largest.
        let skipped_size_classes = self.size_class - size_class;
        let bits_skipped_for_size_class = (1 << skipped_size_classes) - 1;

        // Next, find our index in the size class.
        let bits_skipped_for_index = offset_bytes >> (PAGE_SIZE_BITS + size_class);

        // The sum of those two with our offset is simply our index.
        self.bitset_offset + bits_skipped_for_size_class + bits_skipped_for_index
    }

    /// Changes a bit in the bitset from `InFreeList` to `NotInFreeList`.
    #[requires(size_class <= self.size_class)]
    #[requires(offset_bytes < (PAGE_SIZE << (self.size_class + 1)))]
    #[requires(offset_bytes.trailing_zeros() as usize >= PAGE_SIZE_BITS + size_class)]
    pub fn bitset_mark_as_absent(
        &self,
        bitset: &mut Bitset,
        size_class: usize,
        offset_bytes: usize,
    ) -> Result<(), UnexpectedBitsetState> {
        bitset.replace(
            self.bitset_index(size_class, offset_bytes),
            SubregionStatus::InFreeList,
            SubregionStatus::NotInFreeList,
        )
    }

    /// Changes a bit in the bitset from `NotInFreeList` to `InFreeList`.
    #[requires(size_class <= self.size_class)]
    #[requires(offset_bytes < (PAGE_SIZE << (self.size_class + 1)))]
    #[requires(offset_bytes.trailing_zeros() as usize >= PAGE_SIZE_BITS + size_class)]
    pub fn bitset_mark_as_present(
        &self,
        bitset: &mut Bitset,
        size_class: usize,
        offset_bytes: usize,
    ) -> Result<(), UnexpectedBitsetState> {
        bitset.replace(
            self.bitset_index(size_class, offset_bytes),
            SubregionStatus::NotInFreeList,
            SubregionStatus::InFreeList,
        )
    }

    /// Returns whether the tree contains an address.
    #[requires(self.base_ptr.is_some())]
    pub fn contains(&self, addr: usize) -> bool {
        let tree_addr_lo = self.base_addr();
        let tree_addr_hi = tree_addr_lo + (PAGE_SIZE << self.size_class);
        (tree_addr_lo..tree_addr_hi).contains(&addr)
    }

    /// Formats the region of the bitset corresponding to this tree.
    pub fn debug_bitset<'a>(&'a self, bitset: &'a Bitset) -> impl 'a + fmt::Debug {
        struct BitsetSlice<'a, const PAGE_SIZE: usize, const PAGE_SIZE_BITS: usize>(
            &'a Tree<PAGE_SIZE, PAGE_SIZE_BITS>,
            &'a Bitset,
            usize,
        );

        impl<'a, const PAGE_SIZE: usize, const PAGE_SIZE_BITS: usize> fmt::Debug
            for BitsetSlice<'a, PAGE_SIZE, PAGE_SIZE_BITS>
        {
            fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
                for i in 0..(1 << (self.0.size_class - self.2)) {
                    let offset_bytes = i << (PAGE_SIZE_BITS + self.2);
                    let bit = match self.0.bitset_get(self.1, self.2, offset_bytes) {
                        SubregionStatus::NotInFreeList => '0',
                        SubregionStatus::InFreeList => '1',
                    };
                    write!(fmt, "{}", bit)?;
                    for _ in 0..(1 << self.2) - 1 {
                        write!(fmt, " ")?;
                    }
                }
                Ok(())
            }
        }

        struct BitsetTree<'a, const PAGE_SIZE: usize, const PAGE_SIZE_BITS: usize>(
            &'a Tree<PAGE_SIZE, PAGE_SIZE_BITS>,
            &'a Bitset,
        );

        impl<'a, const PAGE_SIZE: usize, const PAGE_SIZE_BITS: usize> fmt::Debug
            for BitsetTree<'a, PAGE_SIZE, PAGE_SIZE_BITS>
        {
            fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
                fmt.debug_list()
                    .entries(
                        (0..=self.0.size_class)
                            .rev()
                            .map(|size_class| BitsetSlice(self.0, self.1, size_class)),
                    )
                    .finish()
            }
        }

        BitsetTree(self, bitset)
    }
}