Reimplement AVL trees
[monolithium.git] / sdk / avltree.h
1 /*
2  * avltree.h
3  *
4  * Copyright (C) 2018 Aleksandar Andrejevic <theflash@sdf.lonestar.org>
5  *
6  * This program is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Affero General Public License as
8  * published by the Free Software Foundation, either version 3 of the
9  * License, or (at your option) any later version.
10  *
11  * This program is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  * GNU Affero General Public License for more details.
15  *
16  * You should have received a copy of the GNU Affero General Public License
17  * along with this program.  If not, see <http://www.gnu.org/licenses/>.
18  */
19
20 #ifndef __MONOLITHIUM_AVLTREE_H__
21 #define __MONOLITHIUM_AVLTREE_H__
22
23 #include "defs.h"
24
25 #define AVL_TREE_INIT(t, s, n, k, c) avl_tree_init((t), (ptrdiff_t)&((s*)NULL)->k - (ptrdiff_t)&((s*)NULL)->n, sizeof(((s*)NULL)->k), c)
26
27 typedef int (*avl_compare_proc_t)(const void *key1, const void *key2);
28
29 typedef struct avl_node
30 {
31     struct avl_node *parent;
32     struct avl_node *left;
33     struct avl_node *right;
34     struct avl_node *next_equal;
35     struct avl_node *prev_equal;
36     int balance;
37 } avl_node_t;
38
39 typedef struct avl_tree
40 {
41     avl_node_t *root;
42     ptrdiff_t key_offset;
43     size_t key_size;
44     avl_compare_proc_t compare;
45 } avl_tree_t;
46
47 static inline void avl_tree_init(avl_tree_t *tree, ptrdiff_t key_offset, size_t key_size, avl_compare_proc_t compare)
48 {
49     tree->root = NULL;
50     tree->key_offset = key_offset;
51     tree->key_size = key_size;
52     tree->compare = compare;
53 }
54
55 static inline void *avl_get_keyptr(const avl_tree_t *tree, const avl_node_t *node)
56 {
57     return (void*)((ptrdiff_t)node + tree->key_offset);
58 }
59
60 static inline avl_node_t *avl_tree_lookup(const avl_tree_t *tree, const void *key)
61 {
62     avl_node_t *node = tree->root;
63
64     while (node)
65     {
66         const void *node_key = avl_get_keyptr(tree, node);
67         int comparison = tree->compare(key, node_key);
68
69         if (comparison == 0) return node;
70         else if (comparison < 0) node = node->left;
71         else node = node->right;
72     }
73
74     return NULL;
75 }
76
77 static inline avl_node_t *avl_tree_lower_bound(const avl_tree_t *tree, const void *key)
78 {
79     avl_node_t *node = tree->root;
80
81     while (node && tree->compare(avl_get_keyptr(tree, node), key) > 0) node = node->left;
82     if (!node) return NULL;
83
84     while (node->right && tree->compare(avl_get_keyptr(tree, node->right), key) <= 0) node = node->right;
85     return node;
86 }
87
88 static inline avl_node_t *avl_tree_upper_bound(const avl_tree_t *tree, const void *key)
89 {
90     avl_node_t *node = tree->root;
91
92     while (node && tree->compare(avl_get_keyptr(tree, node), key) < 0) node = node->right;
93     if (!node) return NULL;
94
95     while (node->left && tree->compare(avl_get_keyptr(tree, node->left), key) >= 0) node = node->left;
96     return node;
97 }
98
99 static inline avl_node_t *avl_get_next_node(const avl_node_t *node)
100 {
101     while (node->prev_equal) node = node->prev_equal;
102
103     if (node->right)
104     {
105         node = node->right;
106         while (node->left) node = node->left;
107     }
108     else
109     {
110         while (node->parent && node->parent->right == node) node = node->parent;
111         node = node->parent;
112     }
113
114     return (avl_node_t*)node;
115 }
116
117 static inline avl_node_t *avl_get_previous_node(const avl_node_t *node)
118 {
119     while (node->prev_equal) node = node->prev_equal;
120
121     if (node->left)
122     {
123         node = node->left;
124         while (node->right) node = node->right;
125     }
126     else
127     {
128         while (node->parent && node->parent->left == node) node = node->parent;
129         node = node->parent;
130     }
131
132     return (avl_node_t*)node;
133 }
134
135 static inline avl_node_t *avl_rotate_left(avl_tree_t *tree, avl_node_t *root)
136 {
137     avl_node_t *pivot = root->right;
138     root->right = pivot->left;
139     if (root->right) root->right->parent = root;
140
141     pivot->parent = root->parent;
142     pivot->left = root;
143     root->parent = pivot;
144
145     if (pivot->parent)
146     {
147         if (pivot->parent->left == root) pivot->parent->left = pivot;
148         else if (pivot->parent->right == root) pivot->parent->right = pivot;
149     }
150     else
151     {
152         tree->root = pivot;
153     }
154
155     root->balance -= pivot->balance > 0 ? pivot->balance + 1 : 1;
156     pivot->balance += root->balance < 0 ? root->balance - 1 : -1;
157     return pivot;
158 }
159
160 static inline avl_node_t *avl_rotate_right(avl_tree_t *tree, avl_node_t *root)
161 {
162     avl_node_t *pivot = root->left;
163     root->left = pivot->right;
164     if (root->left) root->left->parent = root;
165
166     pivot->parent = root->parent;
167     pivot->right = root;
168     root->parent = pivot;
169
170     if (pivot->parent)
171     {
172         if (pivot->parent->left == root) pivot->parent->left = pivot;
173         else if (pivot->parent->right == root) pivot->parent->right = pivot;
174     }
175     else
176     {
177         tree->root = pivot;
178     }
179
180     root->balance -= pivot->balance < 0 ? pivot->balance - 1 : -1;
181     pivot->balance += root->balance > 0 ? root->balance + 1 : 1;
182     return pivot;
183 }
184
185 static void avl_tree_insert(avl_tree_t *tree, avl_node_t *node)
186 {
187     node->left = node->right = node->parent = node->next_equal = node->prev_equal = NULL;
188     node->balance = 0;
189
190     if (!tree->root)
191     {
192         tree->root = node;
193         return;
194     }
195
196     avl_node_t *current = tree->root;
197     const void *node_key = avl_get_keyptr(tree, node);
198
199     while (TRUE)
200     {
201         const void *key = avl_get_keyptr(tree, current);
202         int comparison = tree->compare(node_key, key);
203
204         if (comparison == 0)
205         {
206             while (current->next_equal) current = current->next_equal;
207             current->next_equal = node;
208             node->prev_equal = current;
209             return;
210         }
211         else if (comparison < 0)
212         {
213             if (!current->left)
214             {
215                 node->parent = current;
216                 current->left = node;
217                 break;
218             }
219             else
220             {
221                 current = current->left;
222             }
223         }
224         else
225         {
226             if (!current->right)
227             {
228                 node->parent = current;
229                 current->right = node;
230                 break;
231             }
232             else
233             {
234                 current = current->right;
235             }
236         }
237     }
238
239     while (current)
240     {
241         if (node == current->left) current->balance--;
242         else current->balance++;
243
244         if (current->balance == 0) break;
245
246         if (current->balance < -1)
247         {
248             if (node->balance > 0) avl_rotate_left(tree, current->left);
249             current = avl_rotate_right(tree, current);
250             break;
251         }
252         else if (current->balance > 1)
253         {
254             if (node->balance < 0) avl_rotate_right(tree, current->right);
255             current = avl_rotate_left(tree, current);
256             break;
257         }
258
259         node = current;
260         current = current->parent;
261     }
262 }
263
264 static void avl_tree_remove(avl_tree_t *tree, avl_node_t *node)
265 {
266     if (node->prev_equal)
267     {
268         node->prev_equal->next_equal = node->next_equal;
269         if (node->next_equal) node->next_equal->prev_equal = node->prev_equal;
270         node->next_equal = node->prev_equal = NULL;
271         return;
272     }
273     else if (node->next_equal)
274     {
275         node->next_equal->parent = node->parent;
276         node->next_equal->left = node->left;
277         node->next_equal->right = node->right;
278         node->next_equal->prev_equal = NULL;
279
280         if (node->parent)
281         {
282             if (node->parent->left == node) node->parent->left = node->next_equal;
283             else node->parent->right = node->next_equal;
284         }
285         else
286         {
287             tree->root = node->next_equal;
288         }
289
290         if (node->left) node->left->parent = node->next_equal;
291         if (node->right) node->right->parent = node->next_equal;
292
293         node->parent = node->left = node->right = node->next_equal = NULL;
294         node->balance = 0;
295         return;
296     }
297
298     if (node->left && node->right)
299     {
300         avl_node_t *replacement = node->right;
301
302         if (replacement->left)
303         {
304             while (replacement->left) replacement = replacement->left;
305
306             avl_node_t *temp_parent = replacement->parent;
307             avl_node_t *temp_right = replacement->right;
308             int temp_balance = replacement->balance;
309
310             replacement->parent = node->parent;
311             replacement->left = node->left;
312             replacement->right = node->right;
313             replacement->balance = node->balance;
314
315             if (replacement->parent)
316             {
317                 if (replacement->parent->left == node) replacement->parent->left = replacement;
318                 else replacement->parent->right = replacement;
319             }
320             else
321             {
322                 tree->root = replacement;
323             }
324
325             if (replacement->left) replacement->left->parent = replacement;
326             if (replacement->right) replacement->right->parent = replacement;
327
328             node->parent = temp_parent;
329             node->left = NULL;
330             node->right = temp_right;
331             node->balance = temp_balance;
332
333             if (node->parent->left == replacement) node->parent->left = node;
334             else node->parent->right = node;
335
336             if (node->right) node->right->parent = node;
337         }
338         else
339         {
340             avl_node_t *temp_right = replacement->right;
341             int temp_balance = replacement->balance;
342
343             replacement->parent = node->parent;
344             replacement->left = node->left;
345             replacement->right = node;
346             replacement->balance = node->balance;
347
348             if (replacement->parent)
349             {
350                 if (replacement->parent->left == node) replacement->parent->left = replacement;
351                 else replacement->parent->right = replacement;
352             }
353             else
354             {
355                 tree->root = replacement;
356             }
357
358             if (replacement->left) replacement->left->parent = replacement;
359
360             node->parent = replacement;
361             node->left = NULL;
362             node->right = temp_right;
363             node->balance = temp_balance;
364
365             if (node->right) node->right->parent = node;
366         }
367     }
368
369     avl_node_t *current = node->parent;
370     bool_t left_child;
371
372     if (current)
373     {
374         left_child = current->left == node;
375
376         if (left_child)
377         {
378             current->left = node->left ? node->left : node->right;
379             if (current->left) current->left->parent = current;
380         }
381         else
382         {
383             current->right = node->left ? node->left : node->right;
384             if (current->right) current->right->parent = current;
385         }
386     }
387     else
388     {
389         tree->root = node->left ? node->left : node->right;
390         if (tree->root) tree->root->parent = NULL;
391     }
392
393     node->parent = node->left = node->right = NULL;
394     node->balance = 0;
395
396     while (current)
397     {
398         if (left_child) current->balance++;
399         else current->balance--;
400
401         if (current->balance == 1 || current->balance == -1) break;
402
403         if (current->balance < -1)
404         {
405             int balance = current->left->balance;
406             if (balance > 0) avl_rotate_left(tree, current->left);
407             current = avl_rotate_right(tree, current);
408             if (balance == 0) break;
409         }
410         else if (current->balance > 1)
411         {
412             int balance = current->right->balance;
413             if (balance < 0) avl_rotate_right(tree, current->right);
414             current = avl_rotate_left(tree, current);
415             if (balance == 0) break;
416         }
417
418         node = current;
419         current = current->parent;
420         if (current) left_child = current->left == node;
421     }
422 }
423
424 static inline void avl_tree_change_key(avl_tree_t *tree, avl_node_t *node, const void *new_key)
425 {
426     avl_tree_remove(tree, node);
427     __builtin_memcpy(avl_get_keyptr(tree, node), new_key, tree->key_size);
428     avl_tree_insert(tree, node);
429 }
430
431 #endif