Self-Balancing Tree Structures
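Self-balancing binary search trees keep search, insertion, and deletion close to O(log n) even under adversarial insertion orders: a Size Balanced Tree (SBT) guarantees logarithmic height, a treap achieves it in expectation through random priorities, and a splay tree achieves O(log n) amortized cost per operation. Below are implementations of these three common variants. Each is a separate, self-contained program (they reuse names such as `root` and `main`), so compile them individually.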
Size Balanced Tree (SBT)
#include <iostream>
#include <cstdlib>
#define MAX_NODES 100000
// One node of the Size Balanced Tree.
struct SBTNode {
    int value;       // key stored in this node
    int size;        // number of nodes in the subtree rooted here, itself included
    SBTNode *left;   // left child (nullptr if absent)
    SBTNode *right;  // right child (nullptr if absent)
};
// Backing storage for SBT nodes: insert() hands out slots in order and
// nodeCount is the index of the next free slot. remove() only unlinks nodes,
// so slots are never returned to the pool.
SBTNode nodePool[MAX_NODES];
int nodeCount = 0;
// Root of the tree; nullptr while the tree is empty.
SBTNode *root = nullptr;
// Left rotation: x's right child becomes the subtree root, x becomes its left
// child. Subtree sizes are fixed up and x is re-pointed at the new root.
void leftRotate(SBTNode* &x) {
    SBTNode *pivot = x->right;
    x->right = pivot->left;
    pivot->left = x;
    // The new root covers exactly the nodes the old root covered.
    pivot->size = x->size;
    const int leftSize = x->left ? x->left->size : 0;
    const int rightSize = x->right ? x->right->size : 0;
    x->size = leftSize + rightSize + 1;
    x = pivot;
}
// Right rotation: x's left child becomes the subtree root, x becomes its
// right child. Subtree sizes are fixed up and x is re-pointed at the new root.
void rightRotate(SBTNode* &x) {
    SBTNode *pivot = x->left;
    x->left = pivot->right;
    pivot->right = x;
    // The new root covers exactly the nodes the old root covered.
    pivot->size = x->size;
    const int leftSize = x->left ? x->left->size : 0;
    const int rightSize = x->right ? x->right->size : 0;
    x->size = leftSize + rightSize + 1;
    x = pivot;
}
// Restores the Size Balanced Tree invariant at t: no subtree may be smaller
// than either child of its sibling ("nephew"). Pass flag == false after the
// left subtree grew (checks the left-heavy cases) and flag == true after the
// right subtree grew (checks the right-heavy cases).
void maintain(SBTNode* &t, bool flag) {
    if (!t) return;
    auto sizeOf = [](const SBTNode* n) { return n ? n->size : 0; };
    if (!flag) {
        if (t->left && sizeOf(t->left->left) > sizeOf(t->right)) {
            rightRotate(t);
        } else if (t->left && sizeOf(t->left->right) > sizeOf(t->right)) {
            leftRotate(t->left);
            rightRotate(t);
        } else {
            return;  // invariant already holds
        }
    } else {
        if (t->right && sizeOf(t->right->right) > sizeOf(t->left)) {
            leftRotate(t);
        } else if (t->right && sizeOf(t->right->left) > sizeOf(t->left)) {
            rightRotate(t->right);
            leftRotate(t);
        } else {
            return;  // invariant already holds
        }
    }
    // The rotations can break the invariant in both new children and at t
    // itself, so all of them are re-checked (the four calls of Chen's
    // original maintain procedure).
    maintain(t->left, false);
    maintain(t->right, true);
    maintain(t, true);
    maintain(t, false);
}
// Inserts `value` into the subtree rooted at t (equal values go right) and
// rebalances with maintain() on the way back up.
// Nodes are taken from the fixed-size nodePool. When the pool is exhausted
// the program stops with an error instead of writing past the end of the
// array. remove() never returns slots, so every insert consumes one.
void insert(SBTNode* &t, int value) {
    if (!t) {
        const int capacity = static_cast<int>(sizeof nodePool / sizeof nodePool[0]);
        if (nodeCount >= capacity) {
            std::cerr << "SBT node pool exhausted\n";
            std::exit(EXIT_FAILURE);
        }
        t = &nodePool[nodeCount++];
        t->value = value;
        t->size = 1;
        t->left = t->right = nullptr;
        return;
    }
    t->size++;
    const bool goRight = !(value < t->value);
    insert(goRight ? t->right : t->left, value);
    // Only the side that received the new node can have become too heavy.
    maintain(t, goRight);
}
int remove(SBTNode* &t, int value) {
t->size--;
if ((value == t->value) || (value < t->value && !t->left) || (value > t->value && !t->right)) {
int result = t->value;
if (!t->left || !t->right) {
t = t->left ? t->left : t->right;
} else {
t->value = remove(t->left, t->value + 1);
}
return result;
}
if (value < t->value) {
return remove(t->left, value);
} else {
return remove(t->right, value);
}
}
// Returns the largest value in the subtree t that is strictly smaller than
// `value`, or `value` itself when there is none.
int predecessor(SBTNode* t, int value) {
    int best = value;
    for (SBTNode* cur = t; cur != nullptr; ) {
        if (cur->value < value) {
            best = cur->value;   // candidate; a larger one may lie to the right
            cur = cur->right;
        } else {
            cur = cur->left;
        }
    }
    return best;
}
// Returns the smallest value in the subtree t that is strictly greater than
// `value`, or `value` itself when there is none.
int successor(SBTNode* t, int value) {
    int best = value;
    for (SBTNode* cur = t; cur != nullptr; ) {
        if (value < cur->value) {
            best = cur->value;   // candidate; a smaller one may lie to the left
            cur = cur->left;
        } else {
            cur = cur->right;
        }
    }
    return best;
}
// Returns 1 + the number of values in the subtree t that are strictly
// smaller than `value`.
int getRank(SBTNode* t, int value) {
    int rank = 1;
    while (t != nullptr) {
        if (t->value < value) {
            // t and its whole left subtree are smaller than value.
            rank += 1 + (t->left ? t->left->size : 0);
            t = t->right;
        } else {
            t = t->left;
        }
    }
    return rank;
}
// Returns the k-th smallest value (1-based) in the subtree t, or -1 when k
// is out of range (k < 1 or k > number of nodes). Without this check an
// invalid k would dereference a null child.
int selectKth(SBTNode* t, int k) {
    while (t) {
        const int leftSize = t->left ? t->left->size : 0;
        if (k <= leftSize) {
            t = t->left;
        } else if (k == leftSize + 1) {
            return t->value;
        } else {
            k -= leftSize + 1;
            t = t->right;
        }
    }
    return -1;
}
// Prints the values of the subtree t in ascending order, each followed by a
// space. The left side recurses; the right side is walked iteratively.
void inorder(SBTNode* t) {
    while (t != nullptr) {
        inorder(t->left);
        std::cout << t->value << " ";
        t = t->right;
    }
}
// Reads n operations "op value" from stdin:
//   1 x: insert x   2 x: remove x   3 k: print the k-th smallest value
// then prints all values in order. Stops early on malformed or truncated input.
int main() {
    std::ios_base::sync_with_stdio(false);
    int n = 0;
    if (!(std::cin >> n)) return 0;
    while (n-- > 0) {
        int op = 0;
        int value = 0;
        if (!(std::cin >> op >> value)) break;
        switch (op) {
            case 1: insert(root, value); break;
            case 2: remove(root, value); break;
            case 3: std::cout << selectKth(root, value) << '\n'; break;
            default: break;  // unknown operations are ignored
        }
    }
    inorder(root);
    std::cout << '\n';
    return 0;
}
Treap Implementation
#include <iostream>
#include <cstdlib>
#define MAX_NODES 500000
// One treap node: BST-ordered by key, heap-ordered by priority.
struct TreapNode {
    int key;
    int priority;
    int size;          // number of nodes in this subtree, itself included
    TreapNode *left;
    TreapNode *right;
    TreapNode(int k, int p)
        : key{k}, priority{p}, size{1}, left{nullptr}, right{nullptr} {}
};
// Root of the treap; nullptr while the treap is empty.
TreapNode *root = nullptr;
// NOTE(review): not referenced by the treap code in this section (nodes are
// allocated with new); looks like a leftover from the pool-based SBT program.
int nodeCount = 0;
// Recomputes node->size from its children's sizes; a null node is ignored.
void updateSize(TreapNode* node) {
    if (node == nullptr) return;
    int total = 1;
    if (node->left) total += node->left->size;
    if (node->right) total += node->right->size;
    node->size = total;
}
// Rotates the subtree rooted at `parent` and stores the new subtree root back
// into `parent`.
//   direction == false (0): left rotation, so the RIGHT child is lifted.
//   direction == true  (1): right rotation, so the LEFT child is lifted.
// insert() and remove() rely on this convention: e.g. after inserting into
// the left subtree they call rotate(node, true) to lift the left child. The
// previous body lifted the opposite child and could dereference null.
void rotate(TreapNode* &parent, bool direction) {
    TreapNode *child = direction ? parent->left : parent->right;
    if (direction) {
        parent->left = child->right;
        child->right = parent;
    } else {
        parent->right = child->left;
        child->left = parent;
    }
    // parent is now below child, so its size must be refreshed first.
    updateSize(parent);
    updateSize(child);
    parent = child;
}
// Inserts `key` into the treap rooted at node. Equal keys go right, and the
// new node gets a rand() priority. On the way back up, a child whose priority
// exceeds its parent's is rotated above it, keeping larger priorities nearer
// the root.
void insert(TreapNode* &node, int key) {
    if (node == nullptr) {
        node = new TreapNode(key, rand());
        return;
    }
    ++node->size;
    const bool goLeft = key < node->key;
    TreapNode* &branch = goLeft ? node->left : node->right;
    insert(branch, key);
    if (branch->priority > node->priority) {
        rotate(node, goLeft);  // true lifts the left child, false the right child
    }
}
// Removes one occurrence of `key` from the treap rooted at node; does nothing
// when the key is absent. The target is rotated down, always lifting the
// child with the higher priority, until it is a leaf, and is then deleted.
// Sizes are recomputed bottom-up after the recursive call. A pre-decrement
// would be wrong when the key is absent, and is overwritten by rotate()'s
// own size recomputation.
void remove(TreapNode* &node, int key) {
    if (!node) return;  // key not present
    if (key < node->key) {
        remove(node->left, key);
    } else if (node->key < key) {
        remove(node->right, key);
    } else if (!node->left && !node->right) {
        delete node;
        node = nullptr;
        return;
    } else if (!node->left || (node->right && node->right->priority > node->left->priority)) {
        rotate(node, false);  // left rotation: right child lifted, target now at node->left
        remove(node->left, key);
    } else {
        rotate(node, true);   // right rotation: left child lifted, target now at node->right
        remove(node->right, key);
    }
    updateSize(node);
}
// Returns the k-th smallest key (1-based) in the treap rooted at node, or -1
// when k is out of range. Without the check, an invalid k would dereference
// a null child.
int getKth(TreapNode* node, int k) {
    while (node) {
        const int leftSize = node->left ? node->left->size : 0;
        if (k <= leftSize) {
            node = node->left;
        } else if (k == leftSize + 1) {
            return node->key;
        } else {
            k -= leftSize + 1;
            node = node->right;
        }
    }
    return -1;
}
// Returns the rank of `key`: 1 + the number of stored keys strictly smaller
// than it. When `key` is present this is the 1-based position of its first
// occurrence, even with duplicates. When absent it is the position `key`
// would take (1 for an empty treap). This matches the SBT getRank convention;
// the previous version counted keys <= key instead.
int getRank(TreapNode* node, int key) {
    int rank = 1;
    while (node) {
        if (key <= node->key) {
            node = node->left;
        } else {
            // node and its entire left subtree are smaller than key.
            rank += (node->left ? node->left->size : 0) + 1;
            node = node->right;
        }
    }
    return rank;
}
// Returns the largest key strictly smaller than `key`, or `key` itself when
// no such key exists.
int predecessor(TreapNode* node, int key) {
    int answer = key;
    while (node != nullptr) {
        if (node->key < key) {
            answer = node->key;   // best so far; keep looking right
            node = node->right;
        } else {
            node = node->left;
        }
    }
    return answer;
}
// Returns the smallest key strictly greater than `key`, or `key` itself when
// no such key exists.
int successor(TreapNode* node, int key) {
    int answer = key;
    while (node != nullptr) {
        if (key < node->key) {
            answer = node->key;   // best so far; keep looking left
            node = node->left;
        } else {
            node = node->right;
        }
    }
    return answer;
}
// Reads n operations "op value" from stdin:
//   1 x: insert   2 x: remove   3 x: rank of x   4 k: k-th smallest
//   5 x: predecessor of x       6 x: successor of x
// Stops early on malformed or truncated input.
int main() {
    std::ios_base::sync_with_stdio(false);
    int n = 0;
    if (!(std::cin >> n)) return 0;
    while (n-- > 0) {
        int op = 0;
        int value = 0;
        if (!(std::cin >> op >> value)) break;
        switch (op) {
            case 1: insert(root, value); break;
            case 2: remove(root, value); break;
            case 3: std::cout << getRank(root, value) << '\n'; break;
            case 4: std::cout << getKth(root, value) << '\n'; break;
            case 5: std::cout << predecessor(root, value) << '\n'; break;
            case 6: std::cout << successor(root, value) << '\n'; break;
            default: break;  // unknown operations are ignored
        }
    }
    return 0;
}
Splay Tree Implementation
#include <iostream>
#include <cstdlib>
// Splay tree node. Equal keys share a single node: `count` is the
// multiplicity of `key`, and `size` is the total multiplicity stored in this
// subtree (maintained by update()).
struct SplayNode {
    int key, count, size;
    SplayNode *left, *right, *parent;
    // explicit: an int must not silently convert into a tree node.
    explicit SplayNode(int k) : key(k), count(1), size(1), left(nullptr), right(nullptr), parent(nullptr) {}
};
// Root of the splay tree; nullptr when empty. splay() re-points it at the
// node it brings to the top.
SplayNode *root = nullptr;
// Recomputes node->size as its own count plus the sizes of its children;
// a null node is ignored.
void update(SplayNode *node) {
    if (node == nullptr) return;
    const int leftSize = node->left ? node->left->size : 0;
    const int rightSize = node->right ? node->right->size : 0;
    node->size = node->count + leftSize + rightSize;
}
// Rotates `child` one level up, above its parent, fixing parent pointers and
// the grandparent's child link, then refreshes both affected sizes.
// Does nothing when child is already a root. The global root is NOT updated
// here; splay() does that once the node reaches the top.
void rotate(SplayNode *child) {
    SplayNode *const parent = child->parent;
    if (parent == nullptr) return;
    SplayNode *const grandparent = parent->parent;
    const bool childOnLeft = parent->left == child;

    // `inner` is child's subtree lying between child and parent in key order;
    // it moves from child to parent, and parent takes its place under child.
    SplayNode *&inner = childOnLeft ? child->right : child->left;
    SplayNode *&slotInParent = childOnLeft ? parent->left : parent->right;
    slotInParent = inner;
    if (inner != nullptr) inner->parent = parent;
    inner = parent;

    // Hang child where parent used to be.
    parent->parent = child;
    child->parent = grandparent;
    if (grandparent != nullptr) {
        SplayNode *&slotInGrandparent =
            (grandparent->left == parent) ? grandparent->left : grandparent->right;
        slotInGrandparent = child;
    }

    // parent is now below child, so refresh it first.
    update(parent);
    update(child);
}
// Moves `node` to the top of its tree with zig / zig-zig / zig-zag steps and
// records it as the global root.
void splay(SplayNode *node) {
    for (SplayNode *parent = node->parent; parent != nullptr; parent = node->parent) {
        SplayNode *grandparent = parent->parent;
        if (grandparent == nullptr) {
            rotate(node);  // zig
            continue;
        }
        const bool parentIsLeft = grandparent->left == parent;
        const bool nodeIsLeft = parent->left == node;
        // zig-zig (same side): rotate parent first; zig-zag: rotate node twice.
        rotate(parentIsLeft == nodeIsLeft ? parent : node);
        rotate(node);
    }
    root = node;
}
void insert(int key) {
if (!root) {
root = new SplayNode(key);
return;
}
SplayNode *current = root, *parent = nullptr;
while (current) {
parent = current;
if (key == current->key) {
current->count++;
update(current);
splay(current);
return;
} else if (key < current->key) {
current = current->left;
} else {
current = current->right;
}
}
current = new SplayNode(key);
current->parent = parent;
if (key < parent->key) {
parent->left = current;
} else {
parent->right = current;
}
splay(current);
}
// Looks up `key`. When found, the node is splayed to the root and returned;
// otherwise returns nullptr and leaves the tree shape unchanged.
SplayNode* find(int key) {
    SplayNode *hit = root;
    while (hit != nullptr && hit->key != key) {
        hit = key < hit->key ? hit->left : hit->right;
    }
    if (hit != nullptr) splay(hit);
    return hit;
}
void remove(int key) {
SplayNode *node = find(key);
if (!node) return;
if (node->count > 1) {
node->count--;
update(node);
return;
}
if (!node->left && !node->right) {
root = nullptr;
} else if (!node->left) {
root = node->right;
root->parent = nullptr;
} else if (!node->right) {
root = node->left;
root->parent = nullptr;
} else {
SplayNode *successor = node->right;
while (successor->left) successor = successor->left;
if (successor->parent != node) {
successor->parent->left = successor->right;
if (successor->right) successor->right->parent = successor->parent;
successor->right = node->right;
node->right->parent = successor;
}
successor->left = node->left;
node->left->parent = successor;
successor->parent = nullptr;
update(successor);
root = successor;
}
delete node;
}
int getRank(int key) {
SplayNode *node = find(key);
if (!node) return 0;
return node->left ? node->left->size + 1 : 1;
}
int selectKth(int k) {
SplayNode *current = root;
while (current) {
int leftSize = current->left ? current->left->size : 0;
if (k <= leftSize) {
current = current->left;
} else if (k <= leftSize + current->count) {
break;
} else {
k -= leftSize + current->count;
current = current->right;
}
}
if (current) {
splay(current);
return current->key;
}
return -1;
}
int predecessor(int key) {
find(key);
if (!root) return key;
if (root->key < key) return root->key;
SplayNode *current = root->left;
if (!current) return key;
while (current->right) current = current->right;
return current ? current->key : key;
}
int successor(int key) {
find(key);
if (!root) return key;
if (root->key > key) return root->key;
SplayNode *current = root->right;
if (!current) return key;
while (current->left) current = current->left;
return current ? current->key : key;
}
// Reads n operations "op value" from stdin:
//   1 x: insert   2 x: remove   3 x: rank of x (0 if absent)   4 k: k-th smallest
//   5 x: predecessor of x       6 x: successor of x
// Stops early on malformed or truncated input.
int main() {
    std::ios_base::sync_with_stdio(false);
    int n = 0;
    if (!(std::cin >> n)) return 0;
    while (n-- > 0) {
        int op = 0;
        int value = 0;
        if (!(std::cin >> op >> value)) break;
        switch (op) {
            case 1: insert(value); break;
            case 2: remove(value); break;
            case 3: std::cout << getRank(value) << '\n'; break;
            case 4: std::cout << selectKth(value) << '\n'; break;
            case 5: std::cout << predecessor(value) << '\n'; break;
            case 6: std::cout << successor(value) << '\n'; break;
            default: break;  // unknown operations are ignored
        }
    }
    return 0;
}