Loading [MathJax]/extensions/TeX/newcommand.js
\newcommand{\ord}[1]{\mathcal{O}\left(#1\right)} \newcommand{\abs}[1]{\lvert #1 \rvert} \newcommand{\floor}[1]{\lfloor #1 \rfloor} \newcommand{\ceil}[1]{\lceil #1 \rceil} \newcommand{\opord}{\operatorname{\mathcal{O}}} \newcommand{\argmax}{\operatorname{arg\,max}} \newcommand{\str}[1]{\texttt{"#1"}}

2016年2月25日 星期四

[ dynamic kd tree ] 動態kd樹模板

參考資料:
kd樹的資料在網路上其實滿多的,但很多code效率不高或是功能不全,所以這邊稍微講解一下kd樹基本的一些操作
  1. 構造
    現在有一個point序列S[0~n-1],每個point都包含一個序列d[0~kd-1]表示維度,S[i].d[j]表示第i個點的第j維。
    節點的資料型別定義為node,裡面有l,r兩個指標分別指向左右兩顆子樹,還有一個point pid表示所存的point。
    構造的過程是一個遞迴結構,大致代碼如下:
    node* build(int k,int l,int r){
    if(l>r)return NULL;
    if(k==kd)k=0;
    int mid=(l+r)/2;
    找出S中對第k維排序的中位數M
    並把<M的point排在M左邊,>M的排右邊(相當於std::nth_element()操作)
    node *ret=S[mid];
    ret->l=build(k+1,l,mid-1);
    ret->r=build(k+1,mid+1,r);
    return ret;
    }
    當前這個節點的k稱為分裂維度
  2. 查找
    kd樹支援兩種查找法:
    一、給定一個點p和一個數字k,求離p前k近的點有哪些
    二、給定一個範圍,求範圍內有哪些點
    兩種方法跟一般爆搜有點像,但是利用了kd樹可以做到有效的剪枝,有時候可以判斷被切分出來的範圍內有沒有機會存在前k近點或是在範圍內,直接看模板應該就懂了
  3. 刪除
    模板裡沒有附刪除的code所以會寫在這裡。
    首先如果找到要刪除的節點是葉節點直接刪除就好了;如果不是葉節點,假設現在這個點的分裂維度是k,就拿他右子樹第k維最小節點mi去替代他,接著遞迴刪除右子樹的mi;如果沒有右子樹,就拿他左子樹第k維最小節點mi去替代他,然後把左右子樹互換,接著遞迴刪除右子樹的mi。

    找到要刪除的點用查找中的第一種方法就好了,這裡p=要刪除的點,k=1,查找的時候順便維護最近節點其位置與其對p的距離,若最近距離!=0則刪除失敗。

    那對一個子樹找第k維最小節點呢?方法很簡單,也是遞迴定義的:
    首先如果當前節點o的分裂維度剛好是第k維,則若o有右子樹的話答案必在o的右子樹,否則答案為o,如果o的分裂維度!=k,則遞迴搜尋左右子樹,把得到的答案和o自己進行比較,求最小。
    接下來附上只有刪除操作的模板:
    #ifndef SUNMOON_DYNEMIC_KD_TREE
    #define SUNMOON_DYNEMIC_KD_TREE
    #include<algorithm>
    #include<vector>
    template<typename T,size_t kd>
    class kd_tree{
    public:
    struct point{
    T d[kd];
    inline T dist(const point &x)const{
    T ret=0;
    for(size_t i=0;i<kd;++i)ret+=std::abs(d[i]-x.d[i]);
    return ret;
    }
    inline bool operator==(const point &p){
    for(size_t i=0;i<kd;++i){
    if(d[i]!=p.d[i])return 0;
    }
    return 1;
    }
    inline bool operator<(const point &b)const{
    return d[0]<b.d[0];
    }
    };
    private:
    struct node{
    node *l,*r;
    point pid;
    node(const point &p):l(0),r(0),pid(p){}
    }*root;
    const T INF;
    std::vector<node*> A;
    int s;
    struct __cmp{
    int sort_id;
    inline bool operator()(const node*x,const node*y)const{
    return operator()(x->pid,y->pid);
    }
    inline bool operator()(const point &x,const point &y)const{
    if(x.d[sort_id]!=y.d[sort_id])
    return x.d[sort_id]<y.d[sort_id];
    for(size_t i=0;i<kd;++i){
    if(x.d[i]!=y.d[i])return x.d[i]<y.d[i];
    }
    return 0;
    }
    }cmp;
    void clear(node *o){
    if(!o)return;
    clear(o->l);
    clear(o->r);
    delete o;
    }
    node* build(int k,int l,int r){
    if(l>r)return 0;
    if(k==kd)k=0;
    int mid=(l+r)/2;
    cmp.sort_id=k;
    std::nth_element(A.begin()+l,A.begin()+mid,A.begin()+r+1,cmp);
    node *ret=A[mid];
    ret->l=build(k+1,l,mid-1);
    ret->r=build(k+1,mid+1,r);
    return ret;
    }
    inline T heuristic(const T h[])const{
    T ret=0;
    for(size_t i=0;i<kd;++i)ret+=h[i];
    return ret;
    }
    node *findmin(node*o,int k){
    if(!o)return 0;
    if(cmp.sort_id==k)return o->l?findmin(o->l,(k+1)%kd):o;
    node *l=findmin(o->l,(k+1)%kd);
    node *r=findmin(o->r,(k+1)%kd);
    if(l&&!r)return cmp(l,o)?l:o;
    if(!l&&r)return cmp(r,o)?r:o;
    if(!l&&!r)return o;
    if(cmp(l,r))return cmp(l,o)?l:o;
    return cmp(r,o)?r:o;
    }
    bool erase(node *&u,int k,const point &x){
    if(!u)return 0;
    if(u->pid==x){
    if(u->r);
    else if(u->l){
    u->r=u->l;
    u->l=0;
    }else{
    delete u;
    u=0;
    return 1;
    }
    cmp.sort_id=k;
    u->pid=findmin(u->r,(k+1)%kd)->pid;
    return erase(u->r,(k+1)%kd,u->pid);
    }
    cmp.sort_id=k;
    return erase(cmp(x,u->pid)?u->l:u->r,(k+1)%kd,x);
    }
    void nearest_for_erase(node *&u,int k,const point &x,T *h,T &mndist){
    if(u==0||heuristic(h)>=mndist)return;
    T dist=u->pid.dist(x),old=h[k];
    if(dist<mndist){
    if(!(mndist=dist))return;
    }
    if(x.d[k]<u->pid.d[k]){
    nearest_for_erase(u->l,(k+1)%kd,x,h,mndist);
    h[k]=std::abs(x.d[k]-u->pid.d[k]);
    nearest_for_erase(u->r,(k+1)%kd,x,h,mndist);
    }else{
    nearest_for_erase(u->r,(k+1)%kd,x,h,mndist);
    h[k]=std::abs(x.d[k]-u->pid.d[k]);
    nearest_for_erase(u->l,(k+1)%kd,x,h,mndist);
    }
    h[k]=old;
    }
    public:
    kd_tree(const T &INF):root(0),INF(INF),s(0){}
    inline void clear(){
    clear(root),root=0;
    }
    inline void build(int n,const point *p){
    clear(root),A.resize(s=n);
    for(int i=0;i<n;++i)A[i]=new node(p[i]);
    root=build(0,0,n-1);
    }
    inline bool erase(const point &p){
    return erase(root,0,p);
    }
    inline T nearest(const point &x){
    T mndist=INF,h[kd]={};
    nearest_for_erase(root,0,x,h,mndist);
    return mndist;/*回傳離x最近的點的距離*/
    }
    inline int size(){return s;}
    };
    #endif

插入刪除通常配合替罪羊樹使用,插入很簡單模板有就不說了,接下來講一下複雜度:
設N=size(root)
  • 構造:
    很明顯就是\ord{N \; log \; N}就不說了
  • 查找:
    已經有人證明了,假設現在的維度有k維,則查找的最差複雜度是\ord{N^{1-1/k}}
    但是平均狀況下為\ord{log \; N}
  • 插入:
    複雜度為\ord{樹的高度},因為是套替罪羊樹,而且重構的時間是\ord{N \; log \; N},所以單次插入的均攤時間複雜度是\ord{log \; N \; * \; log \; N}
  • 刪除:
    有三個步驟需要分析,假設現在的維度有k維
    1. findmin:
      最多會遍歷\alpha^{(k-1)*(log_{\alpha}N)/k}=N^{1-1/k}個節點,所以是\ord{N^{1-1/k}}
      但實際操作量為N^{1-1/k}-n^{1-1/k},n是最小值的子樹大小
    2. nearest:
      就是查找操作,所以是\ord{N^{1-1/k}}
      但因為是找相同點,可以優化code,所以實際操作量為N^{1-1/k}-n^{1-1/k},n是相同點的子樹大小
    3. 刪除操作本身:
      看code明顯就是重複findmin直到把刪除的點變成葉節點為止,會感覺做了很多操作,但是我們發現把所有操作加起來後會=\ord{N^{1-1/k}}
      o_1,o_2,...,o_d=findmin(root)找到的點,findmin(o_1)找到的點,...,要刪除的葉節點,n_1,n_2,...,n_d=findmin(root)找到的點的子樹大小,findmin(o_1)找到的點的子樹大小,findmin(o_2)找到的點的子樹大小,...,要刪除的葉節點的子樹大小,d為樹的高度
      複雜度為:
      N^{1-1/k}-n1^{1-1/k}+n1^{1-1/k}-n2^{1-1/k}+n2^{1-1/k}...-nd^{1-1/k}=\ord{N^{1-1/k}}
模板裡的code找的是曼哈頓距離,如果要找歐基里德距離只需要修改point裡的dist跟kd tree裡的heuristic即可,以下提供修改方法:
inline T dist(const point &x)const{
T ret=0;
for(size_t i=0;i<kd;++i)ret+=(d[i]-x.d[i])*(d[i]-x.d[i]);
return ret;
}
inline T heuristic(const T h[])const{
T ret=0;
for(size_t i=0;i<kd;++i)ret+=h[i]*h[i];
return ret;
}
這樣查找時回傳的就是歐基里德距離的平方

以下附只有插入操作的模板:
#ifndef SUNMOON_DYNEMIC_KD_TREE
#define SUNMOON_DYNEMIC_KD_TREE
#include<algorithm>
#include<vector>
#include<queue>
#include<cmath>
template<typename T,size_t kd>//kd表示有幾個維度
class kd_tree{
public:
struct point{
T d[kd];
inline T dist(const point &x)const{
T ret=0;
for(size_t i=0;i<kd;++i)ret+=std::abs(d[i]-x.d[i]);
return ret;
}
inline bool operator<(const point &b)const{
return d[0]<b.d[0];
}
};
private:
struct node{
node *l,*r;
point pid;
int s;
node(const point &p):l(0),r(0),pid(p),s(1){}
inline void up(){
s=(l?l->s:0)+1+(r?r->s:0);
}
}*root;
const double alpha,loga;
const T INF;//記得要給INF,表示極大值
std::vector<node*> A;
int qM;
std::priority_queue<std::pair<T,point > >pQ;
struct __cmp{
int sort_id;
inline bool operator()(const node*x,const node*y)const{
return x->pid.d[sort_id]<y->pid.d[sort_id];
}
}cmp;
void clear(node *o){
if(!o)return;
clear(o->l);
clear(o->r);
delete o;
}
inline int size(node *o){
return o?o->s:0;
}
node* build(int k,int l,int r){
if(l>r)return 0;
if(k==kd)k=0;
int mid=(l+r)/2;
cmp.sort_id=k;
std::nth_element(A.begin()+l,A.begin()+mid,A.begin()+r+1,cmp);
node *ret=A[mid];
ret->l=build(k+1,l,mid-1);
ret->r=build(k+1,mid+1,r);
ret->up();
return ret;
}
inline bool isbad(node*o){
return size(o->l)>alpha*o->s||size(o->r)>alpha*o->s;
}
void flatten(node *u,typename std::vector<node*>::iterator &it){
if(!u)return;
flatten(u->l,it);
*it=u;
flatten(u->r,++it);
}
bool insert(node*&u,int k,const point &x,int dep){
if(!u){
u=new node(x);
return dep<=0;
}
++u->s;
if(insert(x.d[k]<u->pid.d[k]?u->l:u->r,(k+1)%kd,x,dep-1)){
if(!isbad(u))return 1;
if((int)A.size()<u->s)A.resize(u->s);
typename std::vector<node*>::iterator it=A.begin();
flatten(u,it);
u=build(k,0,u->s-1);
}
return 0;
}
inline T heuristic(const T h[])const{
T ret=0;
for(size_t i=0;i<kd;++i)ret+=h[i];
return ret;
}
void nearest(node *u,int k,const point &x,T *h,T &mndist){
if(u==0||heuristic(h)>=mndist)return;
T dist=u->pid.dist(x),old=h[k];
/*mndist=std::min(mndist,dist);*/
if(dist<mndist){
pQ.push(std::make_pair(dist,u->pid));
if((int)pQ.size()==qM+1){
mndist=pQ.top().first,pQ.pop();
}
}
if(x.d[k]<u->pid.d[k]){
nearest(u->l,(k+1)%kd,x,h,mndist);
h[k]=std::abs(x.d[k]-u->pid.d[k]);
nearest(u->r,(k+1)%kd,x,h,mndist);
}else{
nearest(u->r,(k+1)%kd,x,h,mndist);
h[k]=std::abs(x.d[k]-u->pid.d[k]);
nearest(u->l,(k+1)%kd,x,h,mndist);
}
h[k]=old;
}
public:
kd_tree(const T &INF,double a=0.75):root(0),alpha(a),loga(log2(1.0/a)),INF(INF){}
inline void clear(){
clear(root),root=0;
}
inline void build(int n,const point *p){
clear(root),A.resize(n);
for(int i=0;i<n;++i)A[i]=new node(p[i]);
root=build(0,0,n-1);
}
inline void insert(const point &x){
insert(root,0,x,std::__lg(size(root))/loga);
}
inline T nearest(const point &x,int k){
qM=k;
T mndist=INF,h[kd]={};
nearest(root,0,x,h,mndist);
mndist=pQ.top().first;
pQ=std::priority_queue<std::pair<T,point > >();
return mndist;/*回傳離x第k近的點的距離*/
}
inline int size(){return root?root->s:0;}
};
#endif

以下附支援插入刪除的模板:
#ifndef SUNMOON_DYNEMIC_KD_TREE
#define SUNMOON_DYNEMIC_KD_TREE
#include<algorithm>
#include<vector>
#include<queue>
#include<cmath>
template<typename T,size_t kd>//kd表示有幾個維度
class kd_tree{
public:
struct point{
T d[kd];
inline T dist(const point &x)const{
T ret=0;
for(size_t i=0;i<kd;++i)ret+=std::abs(d[i]-x.d[i]);
return ret;
}
inline bool operator==(const point &p){
for(size_t i=0;i<kd;++i){
if(d[i]!=p.d[i])return 0;
}
return 1;
}
inline bool operator<(const point &b)const{
return d[0]<b.d[0];
}
};
private:
struct node{
node *l,*r;
point pid;
int s;
node(const point &p):l(0),r(0),pid(p),s(1){}
inline void up(){
s=(l?l->s:0)+1+(r?r->s:0);
}
}*root;
const double alpha,loga;
const T INF;//記得要給INF,表示極大值
int maxn;
struct __cmp{
int sort_id;
inline bool operator()(const node*x,const node*y)const{
return operator()(x->pid,y->pid);
}
inline bool operator()(const point &x,const point &y)const{
if(x.d[sort_id]!=y.d[sort_id])
return x.d[sort_id]<y.d[sort_id];
for(size_t i=0;i<kd;++i){
if(x.d[i]!=y.d[i])return x.d[i]<y.d[i];
}
return 0;
}
}cmp;
void clear(node *o){
if(!o)return;
clear(o->l);
clear(o->r);
delete o;
}
inline int size(node *o){
return o?o->s:0;
}
std::vector<node*> A;
node* build(int k,int l,int r){
if(l>r)return 0;
if(k==kd)k=0;
int mid=(l+r)/2;
cmp.sort_id=k;
std::nth_element(A.begin()+l,A.begin()+mid,A.begin()+r+1,cmp);
node *ret=A[mid];
ret->l=build(k+1,l,mid-1);
ret->r=build(k+1,mid+1,r);
ret->up();
return ret;
}
inline bool isbad(node*o){
return size(o->l)>alpha*o->s||size(o->r)>alpha*o->s;
}
void flatten(node *u,typename std::vector<node*>::iterator &it){
if(!u)return;
flatten(u->l,it);
*it=u;
flatten(u->r,++it);
}
inline void rebuild(node*&u,int k){
if((int)A.size()<u->s)A.resize(u->s);
typename std::vector<node*>::iterator it=A.begin();
flatten(u,it);
u=build(k,0,u->s-1);
}
bool insert(node*&u,int k,const point &x,int dep){
if(!u){
u=new node(x);
return dep<=0;
}
++u->s;
cmp.sort_id=k;
if(insert(cmp(x,u->pid)?u->l:u->r,(k+1)%kd,x,dep-1)){
if(!isbad(u))return 1;
rebuild(u,k);
}
return 0;
}
node *findmin(node*o,int k){
if(!o)return 0;
if(cmp.sort_id==k)return o->l?findmin(o->l,(k+1)%kd):o;
node *l=findmin(o->l,(k+1)%kd);
node *r=findmin(o->r,(k+1)%kd);
if(l&&!r)return cmp(l,o)?l:o;
if(!l&&r)return cmp(r,o)?r:o;
if(!l&&!r)return o;
if(cmp(l,r))return cmp(l,o)?l:o;
return cmp(r,o)?r:o;
}
bool erase(node *&u,int k,const point &x){
if(!u)return 0;
if(u->pid==x){
if(u->r);
else if(u->l){
u->r=u->l;
u->l=0;
}else{
delete u;
u=0;
return 1;
}
--u->s;
cmp.sort_id=k;
u->pid=findmin(u->r,(k+1)%kd)->pid;
return erase(u->r,(k+1)%kd,u->pid);
}
cmp.sort_id=k;
if(erase(cmp(x,u->pid)?u->l:u->r,(k+1)%kd,x)){
--u->s;return 1;
}else return 0;
}
inline T heuristic(const T h[])const{
T ret=0;
for(size_t i=0;i<kd;++i)ret+=h[i];
return ret;
}
int qM;
std::priority_queue<std::pair<T,point > >pQ;
void nearest(node *u,int k,const point &x,T *h,T &mndist){
if(u==0||heuristic(h)>=mndist)return;
T dist=u->pid.dist(x),old=h[k];
/*mndist=std::min(mndist,dist);*/
if(dist<mndist){
pQ.push(std::make_pair(dist,u->pid));
if((int)pQ.size()==qM+1){
mndist=pQ.top().first,pQ.pop();
}
}
if(x.d[k]<u->pid.d[k]){
nearest(u->l,(k+1)%kd,x,h,mndist);
h[k]=std::abs(x.d[k]-u->pid.d[k]);
nearest(u->r,(k+1)%kd,x,h,mndist);
}else{
nearest(u->r,(k+1)%kd,x,h,mndist);
h[k]=std::abs(x.d[k]-u->pid.d[k]);
nearest(u->l,(k+1)%kd,x,h,mndist);
}
h[k]=old;
}
std::vector<point>in_range;
void range(node *u,int k,const point&mi,const point&ma){
if(!u)return;
bool is=1;
for(int i=0;i<kd;++i)
if(u->pid.d[i]<mi.d[i]||ma.d[i]<u->pid.d[i]){
is=0;break;
}
if(is)in_range.push_back(u->pid);
if(mi.d[k]<=u->pid.d[k])range(u->l,(k+1)%kd,mi,ma);
if(ma.d[k]>=u->pid.d[k])range(u->r,(k+1)%kd,mi,ma);
}
public:
kd_tree(const T &INF,double a=0.75):root(0),alpha(a),loga(log2(1.0/a)),INF(INF),maxn(1){}
inline void clear(){
clear(root),root=0,maxn=1;
}
inline void build(int n,const point *p){
clear(root),A.resize(maxn=n);
for(int i=0;i<n;++i)A[i]=new node(p[i]);
root=build(0,0,n-1);
}
inline void insert(const point &x){
insert(root,0,x,std::__lg(size(root))/loga);
if(root->s>maxn)maxn=root->s;
}
inline bool erase(const point &p){
bool d=erase(root,0,p);
if(root&&root->s<alpha*maxn)rebuild();
return d;
}
inline void rebuild(){
if(root)rebuild(root,0);
maxn=root->s;
}
inline T nearest(const point &x,int k){
qM=k;
T mndist=INF,h[kd]={};
nearest(root,0,x,h,mndist);
mndist=pQ.top().first;
pQ=std::priority_queue<std::pair<T,point > >();
return mndist;/*回傳離x第k近的點的距離*/
}
inline const std::vector<point> &range(const point&mi,const point&ma){
in_range.clear();
range(root,0,mi,ma);
return in_range;/*回傳介於mi到ma之間的點vector*/
}
inline int size(){return root?root->s:0;}
};
#endif
view raw kd_tree.cpp hosted with ❤ by GitHub
模板的使用方法請參考:

日月卦長的解題紀錄 [ IOICAMP2016 ] 動態曼哈頓最短距離


1 則留言:


  1. 卦長想看你寫的二元搜尋樹修改的code.
    複雜度要跟插入刪除一樣不然就沒意義了

    回覆刪除