红黑树实现map和set
为了让map和set复用同一套红黑树代码,我们给红黑树设置了3个模板参数:
第一个表示key,第二个表示节点中存储的数据类型(set存key,map存key_value键值对),第三个是仿函数,用来从存储的数据中取出key,以解决key之间的比较。
同时添加了一个类模板RBTreeIterator来作为红黑树的迭代器。
RBTree.h:
#pragma once
#include <iostream>
#include <vector>
#include <assert.h>
using namespace std;
// Colour tag stored in every red-black tree node.
// RED keeps the value 0 and BLACK the value 1, as in the implicit numbering.
enum color { RED = 0, BLACK = 1 };
// One node of the red-black tree.
// Holds a copy of the stored value plus links to both children and to the
// parent. A freshly constructed node is detached (all links null) and RED.
template<class T>
struct RBTreeNode {
    RBTreeNode<T>* _left = nullptr;    // left child, null when absent
    RBTreeNode<T>* _right = nullptr;   // right child, null when absent
    RBTreeNode<T>* _parent = nullptr;  // parent node, null until linked in
    T _data;                           // stored value, copied from the ctor argument
    color _col = RED;                  // new nodes start out red

    // Builds a detached red node that stores a copy of `data`.
    RBTreeNode(const T& data)
        : _data(data)
    {}
};
// Bidirectional in-order iterator over a tree of RBTreeNode<T>.
//
// Template parameters:
//   T   - value type stored in the nodes.
//   Ptr - type returned by operator-> (T* or const T*).
//   Ref - type returned by operator*  (T& or const T&).
//
// The iterator only wraps a node pointer; a null pointer is the
// past-the-end position. Dereferencing, incrementing or decrementing the
// past-the-end iterator is not supported (the asserts below catch it in
// debug builds).
template<class T, class Ptr, class Ref>
struct RBTreeIterator {
    typedef RBTreeNode<T> Node;
    typedef RBTreeIterator<T, Ptr, Ref> Self;

    Node* _node;  // current node, or nullptr for the past-the-end position

    RBTreeIterator(Node* node)
        :_node(node)
    {}

    // Member access to the value stored in the current node.
    Ptr operator->() const {
        return &_node->_data;
    }

    // Reference to the value stored in the current node.
    Ref operator*() const {
        return _node->_data;
    }

    // Advances to the in-order successor; becomes past-the-end (null) after
    // the last node.
    Self& operator++() {
        assert(_node != nullptr);
        if (_node->_right) {
            // Successor is the leftmost node of the right subtree.
            Node* SubLeft = _node->_right;
            while (SubLeft->_left)
                SubLeft = SubLeft->_left;
            _node = SubLeft;
        }
        else {
            // Climb until we leave a left subtree; that ancestor is next.
            Node* cur = _node;
            Node* parent = cur->_parent;
            while (parent && cur == parent->_right) {
                cur = parent;
                parent = cur->_parent;
            }
            _node = parent;
        }
        return *this;
    }

    // Steps back to the in-order predecessor; becomes null when stepping
    // back from the first node.
    Self& operator--() {
        assert(_node != nullptr);
        if (_node->_left) {
            // Predecessor is the rightmost node of the left subtree.
            Node* SubRight = _node->_left;
            while (SubRight->_right)
                SubRight = SubRight->_right;
            _node = SubRight;
        }
        else {
            // Climb until we leave a right subtree; that ancestor is previous.
            Node* cur = _node;
            Node* parent = _node->_parent;
            while (parent && cur == parent->_left) {
                cur = parent;
                parent = parent->_parent;
            }
            _node = parent;
        }
        return *this;
    }

    // Two iterators are equal when they refer to the same node.
    bool operator!=(const Self& s) const {
        return _node != s._node;
    }

    bool operator==(const Self& s) const {
        return _node == s._node;
    }
};
template<class K, class T, class KeyOfT>
struct RBTree {
public:
typedef RBTreeNode<T> Node;
typedef RBTreeIterator<T, T*, T&> iterator;
typedef RBTreeIterator<T, const T*, const T&> const_iterator;
const_iterator begin() const {
Node* SubLeft = _root;
while (SubLeft && SubLeft->_left)
SubLeft = SubLeft->_left;
return const_iterator(SubLeft);
}
const_iterator end() const {
const_iterator(nullptr);
}
iterator begin() {
Node* SubLeft = _root;
while (SubLeft && SubLeft->_left)
SubLeft = SubLeft->_left;
return iterator(SubLeft);
}
iterator end() {
return iterator(nullptr);
}
iterator Find(const K& key) {
KeyOfT kot;
Node* cur = _root;
while (cur) {
if (kot(cur->_data) < key)
cur = cur->_right;
else if (kot(cur->_data) > key)
cur = cur->_left;
else
return iterator(cur);
}
return end();
}
pair<iterator, bool> Insert(const T& data) {
if (_root == nullptr) {
_root = new Node(data);
_root->_col = BLACK;
return make_pair(iterator(_root), true);
}
KeyOfT kot;
Node* cur = _root;
Node* parent = nullptr;
while (cur) {
if (kot(cur->_data) > kot(data)) {
parent = cur;
cur = cur->_left;
}
else if (kot(cur->_data) < kot(data)) {
parent = cur;
cur = cur->_right;
}
else
return make_pair(iterator(cur), false);
}
cur = new Node(data);
Node* newnode = cur;
if (kot(parent->_data) < kot(data))
parent->_right = cur;
else
parent->_left = cur;
cur->_parent = parent;
while (parent && parent->_col == RED) {
Node* grandfather = parent->_parent;
if (parent == grandfather->_left) {
Node* uncle = grandfather->_right;
if (uncle && uncle->_col == RED) {
// 情况一:uncle存在且为红
parent->_col = uncle->_col = BLACK;
grandfather->_col = RED;
cur = grandfather;
parent = cur->_parent;
}
else {
//情况二:u不存在或者u为黑
if (parent->_left == cur) {
// g
// p u
// c
RotateR(grandfather);
parent->_col = BLACK;
grandfather->_col = RED;
}
else {
// g
// p u
// c
RotateL(parent);
RotateR(grandfather);
cur->_col = BLACK;
grandfather->_col = RED;
}
break;
}
}
else {
Node* uncle = grandfather->_left;
if (uncle && uncle->_col == RED) {
// 情况一:uncle存在且为红
parent->_col = uncle->_col = BLACK;
grandfather->_col = RED;
cur = grandfather;
parent = cur->_parent;
}
else {
//情况二:u不存在或者u为黑
if (parent->_right == cur) {
// g
// u p
// c
RotateL(grandfather);
parent->_col = BLACK;
grandfather->_col = RED;
}
else {
// g
// u p
// c
RotateR(parent);
RotateL(grandfather);
cur->_col = BLACK;
grandfather->_col = RED;
}
break;
}
}
}
_root->_col = BLACK;
return make_pair(iterator(newnode), false);
}
void RotateL(Node* parent)
{
Node* subR = parent->_right;
Node* subRL = subR->_left;
parent->_right = subRL;
if (subRL)
subRL->_parent = parent;
subR->_left = parent;
Node* ppnode = parent->_parent;
parent->_parent = subR;
if (parent == _root)
{
_root = subR;
subR->_parent = nullptr;
}
else
{
if (ppnode->_left == parent)
{
ppnode->_left = subR;
}
else
{
ppnode->_right = subR;
}
subR->_parent = ppnode;
}
}
void RotateR(Node* parent)
{
Node* subL = parent->_left;
Node* subLR = subL->_right;
parent->_left = subLR;
if (subLR)
subLR->_parent = parent;
subL->_right = parent;
Node* ppnode = parent->_parent;
parent->_parent = subL;
if (parent == _root)
{
_root = subL;
subL->_parent = nullptr;
}
else
{
if (ppnode->_left == parent)
{
ppnode->_left = subL;
}
else
{
ppnode->_right = subL;
}
subL->_parent = ppnode;
}
}
bool Check(Node* root, int BlackNum, int RefBlackNum) {
if (root == nullptr) {
if (BlackNum != RefBlackNum) {
cout << "黑色个数不相等" << endl;
return false;
}
return true;
}
if (root->_col == BLACK)
BlackNum++;
if (root->_col == RED && root->_parent->_col == RED) {
cout << "出现连续红色节点" << endl;
return false;
}
return Check(root->_left, BlackNum, RefBlackNum)
&& Check(root->_right, BlackNum, RefBlackNum);
}
bool IsBalance() {
if (_root && _root->_col == RED)
return false;
int RefBlackNum = 0;
Node* cur = _root;
while (cur) {
if (cur->_col == BLACK)
RefBlackNum++;
cur = cur->_left;
}
return Check(_root, 0, RefBlackNum);
}
private:
Node* _root = nullptr;
};
MyMap.h:
#pragma once
#include"RBTree.h"
namespace lkt {
template<class K, class V>
class map {
struct MapKeyOfT {
const K& operator()(const pair<K, V>& kv) {
return kv.first;
}
};
public:
typedef typename RBTree<K, pair<const K, V>, MapKeyOfT>::iterator iterator;
typedef typename RBTree<K, pair<const K, V>, MapKeyOfT>::const_iterator const_iterator;
iterator begin() {
return _t.begin();
}
iterator end() {
return _t.end();
}
pair<iterator, bool> insert(const pair<K, V>& kv) {
return _t.Insert(kv);
}
iterator find(const K& key) {
return _t.Find(key);
}
V& operator[] (const K& key) {
pair<iterator, bool> ret = insert(make_pair(key, V()));
return ret.first->second;
}
private:
RBTree<K, pair<const K, V>, MapKeyOfT> _t;
};
// Demo for lkt::map: inserts ten keys (each mapped to itself), then walks the
// map in key order, adds 100 to every mapped value and prints "key:value"
// one per line, followed by an empty line.
// Declared inline because it is defined in a header; a non-inline definition
// would break the build once the header is included in a second .cpp file.
inline void TestMap() {
    map<int, int> m;
    int a[] = { 4, 2, 6, 1, 3, 5, 15, 7, 16, 14 };
    for (auto e : a)
    {
        m.insert(std::make_pair(e, e));
    }
    map<int, int>::iterator it = m.begin();
    while (it != m.end())
    {
        // The key is const inside the stored pair, so this must not compile:
        //it->first += 100;
        it->second += 100;
        std::cout << it->first << ":" << it->second << std::endl;
        ++it;
    }
    std::cout << std::endl;
}
}
MySet.h:
#pragma once
#include "RBTree.h"
namespace lkt {
template<class K>
class set {
struct SetKeyOfT {
const K& operator() (const K& key) {
return key;
}
};
public:
typedef typename RBTree<K, const K, SetKeyOfT>::iterator iterator;
typedef typename RBTree<K, const K, SetKeyOfT>::const_iterator const_iterator;
iterator begin() {
return _t.begin();
}
iterator end() {
return _t.end();
}
pair<iterator, bool> insert(const K& key) {
return _t.Insert(key);
}
iterator find(const K& key) {
return _t.Find(key);
}
private:
RBTree<K, const K, SetKeyOfT> _t;
};
// Demo for lkt::set: inserts ten integers and prints them in ascending order
// on a single line, each followed by a space, then ends the line.
// Declared inline because it is defined in a header; a non-inline definition
// would break the build once the header is included in a second .cpp file.
inline void TestSet() {
    set<int> s;
    int a[] = { 4, 2, 6, 1, 3, 5, 15, 7, 16, 14 };
    for (auto e : a)
    {
        s.insert(e);
    }
    set<int>::iterator it = s.begin();
    while (it != s.end()) {
        std::cout << *it << " ";
        ++it;
    }
    std::cout << std::endl;
}
}
test.cpp:
#define _CRT_SECURE_NO_WARNINGS 1
#include"MySet.h"
#include "MyMap.h"
// Program entry point: runs the set demo and then the map demo (both print
// to standard output) and reports success to the caller.
int main()
{
    lkt::TestSet();
    lkt::TestMap();

    return 0;
}