Remove Node in Binary Search Tree
Given the root of a binary search tree in which every node has a unique value, remove the node with the given value. If no node has that value, do nothing. The tree must remain a binary search tree after the removal.
Discussion
刚开始没有想到用recursion，用迭代写得漏洞百出。
Solution 1 iteration -- NOT CORRECT
class Solution {
public:
    /**
     * Iteratively removes the node holding `value` from a BST.
     *
     * Walks down with `link`, a pointer to the pointer that refers to the
     * current node (either `root` itself or a parent's left/right field).
     * Rewriting `*link` therefore re-wires the real parent directly, with no
     * dummy node and no separate parent / is-right bookkeeping.
     *
     * When the target has two children, its left subtree is attached as the
     * left child of the minimum node of its right subtree, and the right
     * subtree takes the target's place. This keeps the BST order (same
     * strategy as the recursive solution).
     *
     * Time: O(h), Space: O(1).
     *
     * @param root: The root of the binary search tree.
     * @param value: Remove the node with given value.
     * @return: The root of the binary search tree after removal
     *          (unchanged if `value` is not present).
     */
    TreeNode* removeNode(TreeNode* root, int value) {
        // 1. Locate the link (root pointer or parent's child field) that refers to the target.
        TreeNode **link = &root;
        while (*link != NULL && (*link)->val != value) {
            link = (value < (*link)->val) ? &(*link)->left : &(*link)->right;
        }

        TreeNode *target = *link;
        if (target == NULL) {
            return root;  // value not found: leave the tree untouched
        }

        // 2. Splice the target out of the tree.
        if (target->left == NULL) {
            // Leaf or right-child-only: the right subtree (possibly NULL) replaces it.
            *link = target->right;
        } else if (target->right == NULL) {
            // Left-child-only: the left subtree replaces it.
            *link = target->left;
        } else {
            // Two children: hang the left subtree under the minimum of the right
            // subtree. Every left key is smaller than every right key, so order holds.
            TreeNode *succ = target->right;
            while (succ->left != NULL) {
                succ = succ->left;
            }
            succ->left = target->left;
            *link = target->right;
        }

        delete target;
        return root;  // *link may have been root itself, so root is already updated
    }
};
Solution 2 recursion
用递归写就简单多了
// Time: O(h)
// Space: O(h)  (recursion depth)
class Solution {
public:
    /**
     * Recursively removes the node holding `value` from a BST.
     *
     * Descends toward `value` and rebinds each visited child pointer to the
     * result of the recursive call. When the target is found:
     *   - with at most one child, that child (possibly NULL) takes its place;
     *   - with two children, its left subtree is attached as the left child of
     *     the minimum node of its right subtree, and the right subtree takes its
     *     place. This keeps BST order but may increase the tree height.
     *
     * @param root: The root of the binary search tree.
     * @param value: Remove the node with given value.
     * @return: The root of the binary search tree after removal.
     */
    TreeNode* removeNode(TreeNode* root, int value) {
        if (root == NULL) {
            return NULL;  // empty subtree: nothing to remove
        }
        if (value < root->val) {
            root->left = removeNode(root->left, value);
            return root;
        }
        if (value > root->val) {
            root->right = removeNode(root->right, value);
            return root;
        }

        // root holds `value`: decide which subtree replaces it.
        TreeNode *replacement;
        if (root->left == NULL) {
            replacement = root->right;
        } else if (root->right == NULL) {
            replacement = root->left;
        } else {
            // Find the minimum node of the right subtree and hang the left subtree under it.
            TreeNode *leftmost = root->right;
            while (leftmost->left != NULL) {
                leftmost = leftmost->left;
            }
            leftmost->left = root->left;
            root->left = NULL;
            replacement = root->right;
        }
        delete root;
        return replacement;
    }
};
还有个偷懒的办法是重新构建BST：先中序（inorder）遍历，存下BST中除了要删除的那个node以外的所有nodes，然后用它们重新构建BST。时间和空间复杂度都是O(N)。
/**
* Definition of TreeNode:
* class TreeNode {
* public:
* int val;
* TreeNode *left, *right;
* TreeNode(int val) {
* this->val = val;
* this->left = this->right = NULL;
* }
 * };
*/
class Solution {
public:
    /**
     * Builds a height-balanced BST from a sorted slice of existing nodes.
     * The middle node becomes the root and each half is built recursively.
     *
     * @param v: Nodes in ascending order of `val`; their child pointers are overwritten.
     * @param left: First index of the slice (inclusive).
     * @param right: Last index of the slice (inclusive); `left > right` means empty.
     * @return: Root of the rebuilt subtree, or NULL for an empty slice.
     */
    TreeNode* buildTree(std::vector<TreeNode*> &v, int left, int right) {
        if (left > right) return NULL;
        int mid = left + ((right - left) >> 1);  // overflow-safe midpoint
        v[mid]->left = buildTree(v, left, mid - 1);
        v[mid]->right = buildTree(v, mid + 1, right);
        return v[mid];
    }

    /**
     * Removes the node holding `value` by rebuilding the tree.
     *
     * An iterative in-order traversal collects every node except the target,
     * in sorted order. The target is deleted and a balanced BST is rebuilt
     * from the remaining nodes. If `value` is absent, the tree is returned
     * untouched rather than rebuilt into a different shape.
     *
     * Time: O(N), Space: O(N).
     *
     * @param root: The root of the binary search tree.
     * @param value: Remove the node with given value.
     * @return: The root of the binary search tree after removal.
     */
    TreeNode* removeNode(TreeNode* root, int value) {
        // 1. In-order traversal: keep every node except the target, in sorted order.
        std::vector<TreeNode*> kept;
        std::stack<TreeNode*> stk;
        TreeNode *target = NULL;
        TreeNode *cur = root;
        while (cur != NULL || !stk.empty()) {
            if (cur != NULL) {
                stk.push(cur);
                cur = cur->left;
            } else {
                TreeNode *node = stk.top();
                stk.pop();
                cur = node->right;  // read before any delete
                if (node->val == value) {
                    target = node;
                } else {
                    kept.push_back(node);
                }
            }
        }

        // 2. "Do nothing" when the value is absent: leave the original structure intact.
        if (target == NULL) {
            return root;
        }
        delete target;

        // 3. Rebuild a balanced BST from the remaining (already sorted) nodes.
        return buildTree(kept, 0, static_cast<int>(kept.size()) - 1);
    }
};