Golang实现的二叉搜索树

写了一个二叉搜索树（也叫二叉查找树），采用递归实现。Go 的 nil 有点坑，和 Java 区别有点大，用起来没那么爽，不过也减少了大量的空指针错误，可能这也是 Go 的设计思想吧。

代码

// Key is the ordering contract every key stored in a BST must satisfy.
type Key interface {
    // CompareTo compares the receiver with that. The tree code treats a
    // negative result as "receiver sorts before that", a positive result as
    // "receiver sorts after that", and zero as "same key".
    CompareTo(that Key) int
}
// BST is a binary search tree that associates Key values with arbitrary
// values.
type BST struct {
    // root is the topmost node; it is nil while the tree holds no keys.
    root *node
}

// node is one entry of the tree together with the links to its subtrees.
type node struct {
    // key is the key this entry is ordered by.
    key         Key
    // val is the value stored under key.
    val         interface{}
    // num is the number of nodes in the subtree rooted at this node,
    // counting the node itself.
    num         int
    // left holds the keys that compare below key, right the keys above it.
    left, right *node
}

// NewBST allocates an empty tree and returns a pointer to it.
func NewBST() *BST {
	return new(BST)
}


// Size reports how many keys the tree currently holds. An empty tree has
// size 0; otherwise the count cached on the root node is returned.
func (n *BST) Size() int {
	if n.root == nil {
		return 0
	}
	return n.root.num
}
// size returns the node count cached on x, treating a nil subtree as empty.
func (n *BST) size(x *node) int {
	if x != nil {
		return x.num
	}
	return 0
}

// Get looks key up starting at the root and returns the value stored under
// it, or nil when the key is not present.
func (n *BST) Get(key Key) interface{} {
	found := n.get(n.root, key)
	return found
}

// get walks down from x, going left when key compares below the current
// node's key and right when it compares above, and returns the value of the
// node whose key compares equal. It returns nil when the walk runs off the
// tree without finding such a node.
func (n *BST) get(x *node, key Key) interface{} {
	for cur := x; cur != nil; {
		switch cmp := key.CompareTo(cur.key); {
		case cmp < 0:
			cur = cur.left
		case cmp > 0:
			cur = cur.right
		default:
			return cur.val
		}
	}
	return nil
}

// Put stores val under key. If the key is already present its value is
// replaced; otherwise a new entry is added. The root is reassigned because
// inserting into an empty tree creates it.
func (n *BST) Put(key Key, val interface{}) {
	updated := n.put(n.root, key, val)
	n.root = updated
}

// put inserts key/val into the subtree rooted at x and returns the root of
// the resulting subtree. A nil subtree becomes a fresh single-node subtree;
// an existing key has its value overwritten. The cached subtree size of x is
// recomputed on the way back up.
func (n *BST) put(x *node, key Key, val interface{}) *node {
	if x == nil {
		return &node{key: key, val: val, num: 1}
	}
	switch cmp := key.CompareTo(x.key); {
	case cmp < 0:
		x.left = n.put(x.left, key, val)
	case cmp > 0:
		x.right = n.put(x.right, key, val)
	default:
		x.val = val
	}
	x.num = 1 + n.size(x.left) + n.size(x.right)
	return x
}

// Min returns the smallest key in the tree, or nil when the tree is empty.
func (n *BST) Min() Key {
	// The min helper dereferences its argument, so an empty tree has to be
	// handled here instead of being passed down.
	if n.root == nil {
		return nil
	}
	return n.min(n.root).key
}

// min returns the leftmost node of the subtree rooted at x, which holds the
// smallest key of that subtree. x must not be nil.
func (n *BST) min(x *node) *node {
	cur := x
	for cur.left != nil {
		cur = cur.left
	}
	return cur
}

// Max returns the largest key in the tree, or nil when the tree is empty.
func (n *BST) Max() Key {
	// The max helper dereferences its argument, so an empty tree has to be
	// handled here instead of being passed down.
	if n.root == nil {
		return nil
	}
	return n.max(n.root).key
}

// max returns the rightmost node of the subtree rooted at x, which holds the
// largest key of that subtree. x must not be nil.
func (n *BST) max(x *node) *node {
	cur := x
	for cur.right != nil {
		cur = cur.right
	}
	return cur
}

// Floor returns the largest key in the tree that is less than or equal to
// key, or nil when every key in the tree is greater (or the tree is empty).
func (n *BST) Floor(key Key) Key {
	if hit := n.floor(n.root, key); hit != nil {
		return hit.key
	}
	return nil
}

// floor returns the node in the subtree rooted at x holding the largest key
// that is less than or equal to key, or nil when there is none.
func (n *BST) floor(x *node, key Key) *node {
	// best is the deepest node seen so far whose key is below the target;
	// any better candidate can only live in that node's right subtree.
	var best *node
	for cur := x; cur != nil; {
		cmp := key.CompareTo(cur.key)
		switch {
		case cmp == 0:
			return cur
		case cmp < 0:
			cur = cur.left
		default:
			best = cur
			cur = cur.right
		}
	}
	return best
}

// Ceiling returns the smallest key in the tree that is greater than or equal
// to key, or nil when every key in the tree is smaller (or the tree is
// empty).
func (n *BST) Ceiling(key Key) Key {
	if hit := n.ceiling(n.root, key); hit != nil {
		return hit.key
	}
	return nil
}

// ceiling returns the node in the subtree rooted at x holding the smallest
// key that is greater than or equal to key, or nil when there is none.
func (n *BST) ceiling(x *node, key Key) *node {
	// best is the deepest node seen so far whose key is above the target;
	// any better candidate can only live in that node's left subtree.
	var best *node
	for cur := x; cur != nil; {
		cmp := key.CompareTo(cur.key)
		switch {
		case cmp == 0:
			return cur
		case cmp > 0:
			cur = cur.right
		default:
			best = cur
			cur = cur.left
		}
	}
	return best
}


// SelectNode returns the key that has exactly k keys smaller than it in the
// tree (k is a zero-based position in sorted order), or nil when no key has
// that position.
func (n *BST) SelectNode(k int) Key {
	hit := n.selectNode(n.root, k)
	if hit == nil {
		return nil
	}
	return hit.key
}

// selectNode returns the node of the subtree rooted at x that has exactly k
// nodes ordered before it within that subtree, or nil when there is no such
// node.
func (n *BST) selectNode(x *node, k int) *node {
	for x != nil {
		leftCount := n.size(x.left)
		switch {
		case leftCount > k:
			x = x.left
		case leftCount < k:
			// Skip the whole left subtree and x itself.
			k -= leftCount + 1
			x = x.right
		default:
			return x
		}
	}
	return nil
}

// Rank returns the number of keys in the tree that are strictly smaller than
// key. The key itself does not have to be present.
func (n *BST) Rank(key Key) int {
	below := n.rank(n.root, key)
	return below
}

// rank counts the nodes in the subtree rooted at x whose keys are strictly
// smaller than key.
func (n *BST) rank(x *node, key Key) int {
	// below accumulates the nodes already known to be smaller: every time
	// the walk turns right, the node left behind and its whole left subtree
	// are smaller than key.
	below := 0
	for x != nil {
		cmp := key.CompareTo(x.key)
		switch {
		case cmp < 0:
			x = x.left
		case cmp > 0:
			below += n.size(x.left) + 1
			x = x.right
		default:
			return below + n.size(x.left)
		}
	}
	return below
}

// DeleteMin removes the entry with the smallest key. Calling it on an empty
// tree is a no-op.
func (n *BST) DeleteMin() {
	// The deleteMin helper dereferences its argument, so an empty tree has
	// to be handled here instead of being passed down.
	if n.root == nil {
		return
	}
	n.root = n.deleteMin(n.root)
}

// deleteMin unlinks the leftmost node of the subtree rooted at x and returns
// the root of what remains. When x itself is the leftmost node its right
// subtree takes its place. Cached subtree sizes along the left spine are
// recomputed. x must not be nil.
func (n *BST) deleteMin(x *node) *node {
	if x.left != nil {
		x.left = n.deleteMin(x.left)
		x.num = 1 + n.size(x.left) + n.size(x.right)
		return x
	}
	return x.right
}

// Delete removes the entry stored under key, if there is one. The root is
// reassigned because the removal may replace or eliminate it.
func (n *BST) Delete(key Key) {
	remaining := n.delete(n.root, key)
	n.root = remaining
}

// delete removes the node holding key from the subtree rooted at x and
// returns the root of the resulting subtree. A missing key leaves the
// structure unchanged. Cached subtree sizes are recomputed on the way back
// up.
func (n *BST) delete(x *node, key Key) *node {
	if x == nil {
		return nil
	}
	switch cmp := key.CompareTo(x.key); {
	case cmp < 0:
		x.left = n.delete(x.left, key)
	case cmp > 0:
		x.right = n.delete(x.right, key)
	case x.right == nil:
		// At most a left child: it simply takes the removed node's place.
		return x.left
	case x.left == nil:
		// Only a right child: it takes the removed node's place.
		return x.right
	default:
		// Two children: the smallest node of the right subtree replaces the
		// removed node. It must be looked up before it is unlinked from that
		// subtree, and only then adopt the removed node's left subtree.
		removed := x
		x = n.min(removed.right)
		x.right = n.deleteMin(removed.right)
		x.left = removed.left
	}
	x.num = 1 + n.size(x.left) + n.size(x.right)
	return x
}
文章目录
  1. 1. 代码