- k-d tree教學投影片
- http://andrewd.ces.clemson.edu/courses/cpsc805/references/nearest_search.pdf
- https://cw.fel.cvut.cz/wiki/_media/courses/a4m33pal/paska13.pdf
- 在二維空間的查找優化
- K-D Tree在信息学竞赛中的应用
- K-D tree的估价
- 从K近邻算法、距离度量谈到KD树、SIFT+BBF算法
- 感謝Morris大大的模板
- 構造
現在有一個point序列S[0~n-1],每個point都包含一個陣列d[0~kd-1],存放它在各維度上的座標,S[i].d[j]表示第i個點的第j維座標。
節點的資料型別定義為node,裡面有l,r兩個指標分別指向左右兩棵子樹,還有一個point pid表示所存的point。
構造的過程是一個遞迴結構,大致代碼如下:
node* build(int k,int l,int r){
    if(l>r)return NULL;
    if(k==kd)k=0;
    int mid=(l+r)/2;
    // 找出S[l~r]中對第k維排序的中位數M(放在S[mid]),
    // 並把<M的point排在M左邊,>M的排右邊(相當於std::nth_element()操作)
    node *ret=new node(S[mid]);
    ret->l=build(k+1,l,mid-1);
    ret->r=build(k+1,mid+1,r);
    return ret;
}
- 查找
kd樹支援兩種查找法:
一、給定一個點p和一個數字k,求離p前k近的點有哪些
二、給定一個範圍,求範圍內有哪些點
兩種方法跟一般爆搜有點像,但利用kd樹的結構可以做到有效的剪枝:可以判斷某個被切分出來的區域內是否有機會存在前k近的點(或落在查詢範圍內的點),沒有機會就整塊跳過。直接看模板應該就懂了。
- 刪除
模板裡沒有附刪除的code所以會寫在這裡。
首先如果找到要刪除的節點是葉節點直接刪除就好了;如果不是葉節點,假設現在這個點的分裂維度是k,就拿他右子樹第k維最小節點mi去替代他,接著遞迴刪除右子樹的mi;如果沒有右子樹,就拿他左子樹第k維最小節點mi去替代他,然後把左右子樹互換,接著遞迴刪除右子樹的mi。
找到要刪除的點用查找中的第一種方法就好了,這裡p=要刪除的點,k=1,查找的時候順便維護最近節點的位置及其與p的距離,若最近距離!=0則刪除失敗。(下方模板的erase則是利用帶有平手處理的比較函式cmp,直接沿著樹往下走來定位要刪除的點。)
那對一個子樹找第k維最小節點呢?方法很簡單,也是遞迴定義的:
首先如果當前節點o的分裂維度剛好是第k維,則若o有左子樹的話答案必在o的左子樹,否則答案為o;如果o的分裂維度!=k,則遞迴搜尋左右子樹,把得到的答案和o自己進行比較,取最小者。
接下來附上只有刪除操作的模板:
#ifndef SUNMOON_DYNEMIC_KD_TREE
#define SUNMOON_DYNEMIC_KD_TREE
#include<algorithm>
#include<cmath>
#include<cstdlib>
#include<vector>
// k-d tree that supports build(), erase() and nearest-point distance queries
// (no insertion). T is the coordinate/distance type, kd the number of
// dimensions. Distances are Manhattan (L1), see point::dist.
// NOTE(review): nodes are released only by clear()/build(); there is no
// destructor and copies share nodes, so keep a single instance per tree.
template<typename T,size_t kd>
class kd_tree{
public:
    struct point{
        T d[kd];
        // Manhattan (L1) distance to x.
        inline T dist(const point &x)const{
            T ret=0;
            for(size_t i=0;i<kd;++i)ret+=std::abs(d[i]-x.d[i]);
            return ret;
        }
        // True iff every coordinate matches (const so it also works on const points).
        inline bool operator==(const point &p)const{
            for(size_t i=0;i<kd;++i){
                if(d[i]!=p.d[i])return 0;
            }
            return 1;
        }
        // Orders by the first coordinate only.
        inline bool operator<(const point &b)const{
            return d[0]<b.d[0];
        }
    };
private:
    struct node{
        node *l,*r;
        point pid;
        node(const point &p):l(0),r(0),pid(p){}
    }*root;
    const T INF;          // caller-supplied value larger than any distance
    std::vector<node*> A; // scratch array of nodes used by build()
    int s;                // number of points currently stored
    // Strict total order on points: compare dimension sort_id first, then break
    // ties lexicographically over all dimensions so distinct points never
    // compare equal. build(), findmin() and erase() all rely on this order.
    struct cmp_t{
        int sort_id;
        inline bool operator()(const node*x,const node*y)const{
            return operator()(x->pid,y->pid);
        }
        inline bool operator()(const point &x,const point &y)const{
            if(x.d[sort_id]!=y.d[sort_id])
                return x.d[sort_id]<y.d[sort_id];
            for(size_t i=0;i<kd;++i){
                if(x.d[i]!=y.d[i])return x.d[i]<y.d[i];
            }
            return 0;
        }
    }cmp;
    // Frees every node of the subtree rooted at o.
    void clear(node *o){
        if(!o)return;
        clear(o->l);
        clear(o->r);
        delete o;
    }
    // Builds a balanced subtree from A[l..r] that splits on dimension k at this
    // level (cycling through the dimensions); the cmp-median becomes the root.
    node* build(int k,int l,int r){
        if(l>r)return 0;
        if(k==(int)kd)k=0;
        int mid=(l+r)/2;
        cmp.sort_id=k;
        std::nth_element(A.begin()+l,A.begin()+mid,A.begin()+r+1,cmp);
        node *ret=A[mid];
        ret->l=build(k+1,l,mid-1);
        ret->r=build(k+1,mid+1,r);
        return ret;
    }
    // Lower bound on the distance from the query to any point of the region
    // being searched: the sum of the per-axis gaps h[i].
    inline T heuristic(const T h[])const{
        T ret=0;
        for(size_t i=0;i<kd;++i)ret+=h[i];
        return ret;
    }
    // Returns the cmp-smallest node of subtree o (i.e. minimal along dimension
    // cmp.sort_id); k is o's split dimension. When o splits on the target
    // dimension only its left subtree can contain something smaller than o.
    node *findmin(node*o,int k){
        if(!o)return 0;
        if(cmp.sort_id==k)return o->l?findmin(o->l,(k+1)%kd):o;
        node *l=findmin(o->l,(k+1)%kd);
        node *r=findmin(o->r,(k+1)%kd);
        if(l&&!r)return cmp(l,o)?l:o;
        if(!l&&r)return cmp(r,o)?r:o;
        if(!l&&!r)return o;
        if(cmp(l,r))return cmp(l,o)?l:o;
        return cmp(r,o)?r:o;
    }
    // Removes x from subtree u (split dimension k); returns true iff x was found.
    // A matching inner node is overwritten by the cmp-minimum of its right
    // subtree (moving a lone left subtree to the right first) and that minimum
    // is erased recursively, so only leaves are ever physically deleted.
    bool erase(node *&u,int k,const point &x){
        if(!u)return 0;
        if(u->pid==x){
            if(u->r);
            else if(u->l){
                u->r=u->l;
                u->l=0;
            }else{
                delete u;
                u=0;
                return 1;
            }
            cmp.sort_id=k;
            u->pid=findmin(u->r,(k+1)%kd)->pid;
            return erase(u->r,(k+1)%kd,u->pid);
        }
        cmp.sort_id=k;
        return erase(cmp(x,u->pid)?u->l:u->r,(k+1)%kd,x);
    }
    // Branch-and-bound nearest search: lowers mndist to the distance from x to
    // the closest point seen, stopping early on an exact match (distance 0).
    void nearest_for_erase(node *&u,int k,const point &x,T *h,T &mndist){
        if(u==0||heuristic(h)>=mndist)return;
        T dist=u->pid.dist(x),old=h[k];
        if(dist<mndist){
            if(!(mndist=dist))return;
        }
        // Visit the side containing x first; the other side is at least
        // |x.d[k]-u->pid.d[k]| away along axis k.
        if(x.d[k]<u->pid.d[k]){
            nearest_for_erase(u->l,(k+1)%kd,x,h,mndist);
            h[k]=std::abs(x.d[k]-u->pid.d[k]);
            nearest_for_erase(u->r,(k+1)%kd,x,h,mndist);
        }else{
            nearest_for_erase(u->r,(k+1)%kd,x,h,mndist);
            h[k]=std::abs(x.d[k]-u->pid.d[k]);
            nearest_for_erase(u->l,(k+1)%kd,x,h,mndist);
        }
        h[k]=old;
    }
public:
    // INF must exceed every possible distance.
    kd_tree(const T &INF):root(0),INF(INF),s(0){}
    // Frees all nodes and empties the tree.
    inline void clear(){
        clear(root),root=0,s=0;
    }
    // Replaces the contents with the n points p[0..n-1].
    inline void build(int n,const point *p){
        clear(root),A.resize(s=n);
        for(int i=0;i<n;++i)A[i]=new node(p[i]);
        root=build(0,0,n-1);
    }
    // Erases one occurrence of p; returns whether it was present and keeps
    // size() in sync.
    inline bool erase(const point &p){
        if(!erase(root,0,p))return 0;
        --s;
        return 1;
    }
    // Returns the distance from x to its nearest point (INF if the tree is empty).
    inline T nearest(const point &x){
        T mndist=INF,h[kd]={};
        nearest_for_erase(root,0,x,h,mndist);
        return mndist;
    }
    inline int size(){return s;}
};
#endif
設N=size(root)
- 構造:
每一層的nth_element加起來是\ord{N},總共有\ord{\log N}層,所以是\ord{N \log N}。
- 查找:
已經有人證明了,假設現在的維度有k維,則範圍查找的最差複雜度是\ord{N^{1-1/k}+m}(m為回傳的點數);k近點查找沒有這麼好的最差保證,最差情況下可能要走遍整棵樹。
但是平均狀況下為\ord{\log N}。
- 插入:
複雜度為\ord{樹的高度},因為是套替罪羊樹,而且重構一棵大小為N的子樹的時間是\ord{N \log N},所以單次插入的均攤時間複雜度是\ord{\log^2 N}。
- 刪除:
有三個步驟需要分析,假設現在的維度有k維:
- findmin:
每連續k層中,只有分裂維度剛好是目標維度的那一層可以只往一邊走,其餘k-1層都要兩邊都走,所以最多會遍歷\alpha^{\frac{k-1}{k}\log_{\alpha}N}=N^{1-1/k}個節點,也就是\ord{N^{1-1/k}}。
但實際操作量約為N^{1-1/k}-n^{1-1/k},n是最小值所在節點的子樹大小。
- nearest:
就是查找操作,所以是\ord{N^{1-1/k}}
但因為是找相同點,可以優化code,所以實際操作量約為N^{1-1/k}-n^{1-1/k},n是相同點所在節點的子樹大小。
- 刪除操作本身:
看code明顯就是重複findmin直到把刪除的點變成葉節點為止,會感覺做了很多操作,但是我們發現把所有操作加起來後會=\ord{N^{1-1/k}}
設o_1,o_2,\dots,o_d依序為findmin(root)找到的點、findmin(o_1)找到的點、……、要刪除的葉節點;n_1,n_2,\dots,n_d依序為這些點的子樹大小;d不超過樹的高度。
複雜度為:
N^{1-1/k}-n_1^{1-1/k}+n_1^{1-1/k}-n_2^{1-1/k}+\dots+n_{d-1}^{1-1/k}-n_d^{1-1/k}=N^{1-1/k}-n_d^{1-1/k}=\ord{N^{1-1/k}}
- 改用歐幾里得距離:若要用歐幾里得距離(取平方,避免開根號),把模板中的dist與heuristic換成:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
inline T dist(const point &x)const{
    // Squared Euclidean distance to x; the square root is skipped because it
    // does not change the ordering of distances.
    T sum=T();
    for(size_t axis=0;axis!=kd;++axis)
        sum+=(d[axis]-x.d[axis])*(d[axis]-x.d[axis]);
    return sum;
}
inline T heuristic(const T h[])const{
    // Lower bound on the squared Euclidean distance to the region being
    // searched: sum of the squared per-axis gaps h[axis].
    T bound=T();
    for(size_t axis=0;axis!=kd;++axis)
        bound+=h[axis]*h[axis];
    return bound;
}
以下附只有插入操作的模板:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifndef SUNMOON_DYNEMIC_KD_TREE
#define SUNMOON_DYNEMIC_KD_TREE
#include<algorithm>
#include<cmath>
#include<cstdlib>
#include<queue>
#include<utility>
#include<vector>
// k-d tree supporting build(), insert() and "distance to the k-th nearest
// point" queries. Balance is kept scapegoat-style: when an insertion lands
// deeper than about log_{1/alpha}(size), the first ancestor found on the way
// back up that has a child holding more than alpha of its nodes is rebuilt.
// T is the coordinate/distance type; distances are Manhattan (L1), see
// point::dist.
// NOTE(review): nodes are released only by clear()/build(); there is no
// destructor and copies share nodes, so keep one instance per tree.
template<typename T,size_t kd>// kd: number of dimensions
class kd_tree{
public:
    struct point{
        T d[kd];
        // Manhattan (L1) distance to x.
        inline T dist(const point &x)const{
            T ret=0;
            for(size_t i=0;i<kd;++i)ret+=std::abs(d[i]-x.d[i]);
            return ret;
        }
        // Orders by the first coordinate only; also breaks distance ties in
        // the (distance, point) priority queue.
        inline bool operator<(const point &b)const{
            return d[0]<b.d[0];
        }
    };
private:
    struct node{
        node *l,*r;
        point pid;
        int s;// size of the subtree rooted here
        node(const point &p):l(0),r(0),pid(p),s(1){}
        // Recomputes s from the children's sizes.
        inline void up(){
            s=(l?l->s:0)+1+(r?r->s:0);
        }
    }*root;
    const double alpha,loga;// balance factor and log2(1/alpha)
    const T INF;// must be supplied by the caller: a value larger than any distance
    std::vector<node*> A;// scratch array used by build() and rebuilds
    int qM;// k of the nearest() query in progress
    std::priority_queue<std::pair<T,point> >pQ;// max-heap of (distance, point) candidates
    // Orders nodes by coordinate sort_id only.
    struct cmp_t{
        int sort_id;
        inline bool operator()(const node*x,const node*y)const{
            return x->pid.d[sort_id]<y->pid.d[sort_id];
        }
    }cmp;
    // floor(log2(n)) for n>=1 (-1 for n<=0); portable stand-in for the
    // GCC-only std::__lg.
    static int floor_log2(int n){
        int r=-1;
        while(n>0){
            n>>=1;
            ++r;
        }
        return r;
    }
    // Frees every node of the subtree rooted at o.
    void clear(node *o){
        if(!o)return;
        clear(o->l);
        clear(o->r);
        delete o;
    }
    inline int size(node *o){
        return o?o->s:0;
    }
    // Builds a balanced subtree from A[l..r] splitting on dimension k (cycling
    // through the dimensions level by level); the median becomes the root.
    node* build(int k,int l,int r){
        if(l>r)return 0;
        if(k==(int)kd)k=0;
        int mid=(l+r)/2;
        cmp.sort_id=k;
        std::nth_element(A.begin()+l,A.begin()+mid,A.begin()+r+1,cmp);
        node *ret=A[mid];
        ret->l=build(k+1,l,mid-1);
        ret->r=build(k+1,mid+1,r);
        ret->up();
        return ret;
    }
    // True if one child of o holds more than alpha of o's nodes.
    inline bool isbad(node*o){
        return size(o->l)>alpha*o->s||size(o->r)>alpha*o->s;
    }
    // Writes subtree u's nodes, in order, into consecutive slots starting at it.
    void flatten(node *u,typename std::vector<node*>::iterator &it){
        if(!u)return;
        flatten(u->l,it);
        *it=u;
        flatten(u->r,++it);
    }
    // Inserts x below u, whose split dimension is k; dep is the remaining depth
    // budget. Returns true while the new node is too deep and no unbalanced
    // ancestor has been rebuilt yet.
    bool insert(node*&u,int k,const point &x,int dep){
        if(!u){
            u=new node(x);
            return dep<=0;
        }
        ++u->s;
        if(insert(x.d[k]<u->pid.d[k]?u->l:u->r,(k+1)%kd,x,dep-1)){
            if(!isbad(u))return 1;
            if((int)A.size()<u->s)A.resize(u->s);
            typename std::vector<node*>::iterator it=A.begin();
            flatten(u,it);
            u=build(k,0,u->s-1);
        }
        return 0;
    }
    // Lower bound on the distance from the query to any point of the region
    // being searched: the sum of the per-axis gaps h[i].
    inline T heuristic(const T h[])const{
        T ret=0;
        for(size_t i=0;i<kd;++i)ret+=h[i];
        return ret;
    }
    // Branch-and-bound k-nearest search. Points closer than mndist are pushed
    // into pQ; once it holds qM+1 entries the farthest is popped and its
    // distance becomes the new pruning bound.
    void nearest(node *u,int k,const point &x,T *h,T &mndist){
        if(u==0||heuristic(h)>=mndist)return;
        T dist=u->pid.dist(x),old=h[k];
        if(dist<mndist){
            pQ.push(std::make_pair(dist,u->pid));
            if((int)pQ.size()==qM+1){
                mndist=pQ.top().first,pQ.pop();
            }
        }
        // Visit the side containing x first; the other side is at least
        // |x.d[k]-u->pid.d[k]| away along axis k.
        if(x.d[k]<u->pid.d[k]){
            nearest(u->l,(k+1)%kd,x,h,mndist);
            h[k]=std::abs(x.d[k]-u->pid.d[k]);
            nearest(u->r,(k+1)%kd,x,h,mndist);
        }else{
            nearest(u->r,(k+1)%kd,x,h,mndist);
            h[k]=std::abs(x.d[k]-u->pid.d[k]);
            nearest(u->l,(k+1)%kd,x,h,mndist);
        }
        h[k]=old;
    }
public:
    // INF must exceed every possible distance; a is the balance factor alpha.
    kd_tree(const T &INF,double a=0.75):root(0),alpha(a),loga(std::log2(1.0/a)),INF(INF){}
    // Frees all nodes and empties the tree.
    inline void clear(){
        clear(root),root=0;
    }
    // Replaces the contents with the n points p[0..n-1].
    inline void build(int n,const point *p){
        clear(root),A.resize(n);
        for(int i=0;i<n;++i)A[i]=new node(p[i]);
        root=build(0,0,n-1);
    }
    inline void insert(const point &x){
        // Depth budget floor(log2 n)/log2(1/alpha) ~ log_{1/alpha}(n). An empty
        // tree gets budget 0 instead of taking the logarithm of 0.
        const int n=size(root);
        insert(root,0,x,n>0?(int)(floor_log2(n)/loga):0);
    }
    // Returns the distance from x to its k-th nearest point (the farthest point
    // if the tree holds fewer than k points), or INF if the tree is empty (or
    // k==0), instead of reading the top of an empty heap.
    inline T nearest(const point &x,int k){
        qM=k;
        T mndist=INF,h[kd]={};
        nearest(root,0,x,h,mndist);
        mndist=pQ.empty()?INF:pQ.top().first;
        pQ=std::priority_queue<std::pair<T,point> >();
        return mndist;
    }
    inline int size(){return root?root->s:0;}
};
#endif
以下附支援插入刪除的模板:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifndef SUNMOON_DYNEMIC_KD_TREE
#define SUNMOON_DYNEMIC_KD_TREE
#include<algorithm>
#include<cmath>
#include<cstdlib>
#include<queue>
#include<utility>
#include<vector>
// k-d tree supporting build(), insert(), erase(), "distance to the k-th
// nearest point" and axis-aligned range queries. Insertions keep balance
// scapegoat-style (see insert()); erasures rebuild the whole tree once its size
// drops below alpha*maxn, maxn being the largest size since the last full
// rebuild. T is the coordinate/distance type; distances are Manhattan (L1).
// NOTE(review): nodes are released only by clear()/build(); there is no
// destructor and copies share nodes, so keep one instance per tree.
template<typename T,size_t kd>// kd: number of dimensions
class kd_tree{
public:
    struct point{
        T d[kd];
        // Manhattan (L1) distance to x.
        inline T dist(const point &x)const{
            T ret=0;
            for(size_t i=0;i<kd;++i)ret+=std::abs(d[i]-x.d[i]);
            return ret;
        }
        // True iff every coordinate matches (const so it also works on const points).
        inline bool operator==(const point &p)const{
            for(size_t i=0;i<kd;++i){
                if(d[i]!=p.d[i])return 0;
            }
            return 1;
        }
        // Orders by the first coordinate only; also breaks distance ties in
        // the (distance, point) priority queue.
        inline bool operator<(const point &b)const{
            return d[0]<b.d[0];
        }
    };
private:
    struct node{
        node *l,*r;
        point pid;
        int s;// size of the subtree rooted here
        node(const point &p):l(0),r(0),pid(p),s(1){}
        // Recomputes s from the children's sizes.
        inline void up(){
            s=(l?l->s:0)+1+(r?r->s:0);
        }
    }*root;
    const double alpha,loga;// balance factor and log2(1/alpha)
    const T INF;// must be supplied by the caller: a value larger than any distance
    int maxn;// largest tree size since the last full rebuild
    // Strict total order on points: compare dimension sort_id first, then break
    // ties lexicographically over all dimensions so distinct points never
    // compare equal. build(), insert(), findmin() and erase() rely on it.
    struct cmp_t{
        int sort_id;
        inline bool operator()(const node*x,const node*y)const{
            return operator()(x->pid,y->pid);
        }
        inline bool operator()(const point &x,const point &y)const{
            if(x.d[sort_id]!=y.d[sort_id])
                return x.d[sort_id]<y.d[sort_id];
            for(size_t i=0;i<kd;++i){
                if(x.d[i]!=y.d[i])return x.d[i]<y.d[i];
            }
            return 0;
        }
    }cmp;
    // floor(log2(n)) for n>=1 (-1 for n<=0); portable stand-in for the
    // GCC-only std::__lg.
    static int floor_log2(int n){
        int r=-1;
        while(n>0){
            n>>=1;
            ++r;
        }
        return r;
    }
    // Frees every node of the subtree rooted at o.
    void clear(node *o){
        if(!o)return;
        clear(o->l);
        clear(o->r);
        delete o;
    }
    inline int size(node *o){
        return o?o->s:0;
    }
    std::vector<node*> A;// scratch array used by build() and rebuilds
    // Builds a balanced subtree from A[l..r] splitting on dimension k (cycling
    // through the dimensions level by level); the cmp-median becomes the root.
    node* build(int k,int l,int r){
        if(l>r)return 0;
        if(k==(int)kd)k=0;
        int mid=(l+r)/2;
        cmp.sort_id=k;
        std::nth_element(A.begin()+l,A.begin()+mid,A.begin()+r+1,cmp);
        node *ret=A[mid];
        ret->l=build(k+1,l,mid-1);
        ret->r=build(k+1,mid+1,r);
        ret->up();
        return ret;
    }
    // True if one child of o holds more than alpha of o's nodes.
    inline bool isbad(node*o){
        return size(o->l)>alpha*o->s||size(o->r)>alpha*o->s;
    }
    // Writes subtree u's nodes, in order, into consecutive slots starting at it.
    void flatten(node *u,typename std::vector<node*>::iterator &it){
        if(!u)return;
        flatten(u->l,it);
        *it=u;
        flatten(u->r,++it);
    }
    // Rebuilds subtree u (split dimension k) into a perfectly balanced one.
    inline void rebuild(node*&u,int k){
        if((int)A.size()<u->s)A.resize(u->s);
        typename std::vector<node*>::iterator it=A.begin();
        flatten(u,it);
        u=build(k,0,u->s-1);
    }
    // Inserts x below u, whose split dimension is k; dep is the remaining depth
    // budget. Returns true while the new node is too deep and no unbalanced
    // ancestor has been rebuilt yet.
    bool insert(node*&u,int k,const point &x,int dep){
        if(!u){
            u=new node(x);
            return dep<=0;
        }
        ++u->s;
        cmp.sort_id=k;
        if(insert(cmp(x,u->pid)?u->l:u->r,(k+1)%kd,x,dep-1)){
            if(!isbad(u))return 1;
            rebuild(u,k);
        }
        return 0;
    }
    // Returns the cmp-smallest node of subtree o (i.e. minimal along dimension
    // cmp.sort_id); k is o's split dimension. When o splits on the target
    // dimension only its left subtree can contain something smaller than o.
    node *findmin(node*o,int k){
        if(!o)return 0;
        if(cmp.sort_id==k)return o->l?findmin(o->l,(k+1)%kd):o;
        node *l=findmin(o->l,(k+1)%kd);
        node *r=findmin(o->r,(k+1)%kd);
        if(l&&!r)return cmp(l,o)?l:o;
        if(!l&&r)return cmp(r,o)?r:o;
        if(!l&&!r)return o;
        if(cmp(l,r))return cmp(l,o)?l:o;
        return cmp(r,o)?r:o;
    }
    // Removes x from subtree u (split dimension k); returns true iff x was found
    // and keeps the subtree sizes on the path up to date. A matching inner node
    // is overwritten by the cmp-minimum of its right subtree (moving a lone left
    // subtree to the right first) and that minimum is erased recursively, so
    // only leaves are ever physically deleted.
    bool erase(node *&u,int k,const point &x){
        if(!u)return 0;
        if(u->pid==x){
            if(u->r);
            else if(u->l){
                u->r=u->l;
                u->l=0;
            }else{
                delete u;
                u=0;
                return 1;
            }
            --u->s;
            cmp.sort_id=k;
            u->pid=findmin(u->r,(k+1)%kd)->pid;
            return erase(u->r,(k+1)%kd,u->pid);
        }
        cmp.sort_id=k;
        if(erase(cmp(x,u->pid)?u->l:u->r,(k+1)%kd,x)){
            --u->s;return 1;
        }else return 0;
    }
    // Lower bound on the distance from the query to any point of the region
    // being searched: the sum of the per-axis gaps h[i].
    inline T heuristic(const T h[])const{
        T ret=0;
        for(size_t i=0;i<kd;++i)ret+=h[i];
        return ret;
    }
    int qM;// k of the nearest() query in progress
    std::priority_queue<std::pair<T,point> >pQ;// max-heap of (distance, point) candidates
    // Branch-and-bound k-nearest search. Points closer than mndist are pushed
    // into pQ; once it holds qM+1 entries the farthest is popped and its
    // distance becomes the new pruning bound.
    void nearest(node *u,int k,const point &x,T *h,T &mndist){
        if(u==0||heuristic(h)>=mndist)return;
        T dist=u->pid.dist(x),old=h[k];
        if(dist<mndist){
            pQ.push(std::make_pair(dist,u->pid));
            if((int)pQ.size()==qM+1){
                mndist=pQ.top().first,pQ.pop();
            }
        }
        // Visit the side containing x first; the other side is at least
        // |x.d[k]-u->pid.d[k]| away along axis k.
        if(x.d[k]<u->pid.d[k]){
            nearest(u->l,(k+1)%kd,x,h,mndist);
            h[k]=std::abs(x.d[k]-u->pid.d[k]);
            nearest(u->r,(k+1)%kd,x,h,mndist);
        }else{
            nearest(u->r,(k+1)%kd,x,h,mndist);
            h[k]=std::abs(x.d[k]-u->pid.d[k]);
            nearest(u->l,(k+1)%kd,x,h,mndist);
        }
        h[k]=old;
    }
    std::vector<point>in_range;// output buffer of range()
    // Appends to in_range every point of subtree u inside the box [mi,ma]
    // (inclusive on every axis); k is u's split dimension.
    void range(node *u,int k,const point&mi,const point&ma){
        if(!u)return;
        bool is=1;
        for(size_t i=0;i<kd;++i)
            if(u->pid.d[i]<mi.d[i]||ma.d[i]<u->pid.d[i]){
                is=0;break;
            }
        if(is)in_range.push_back(u->pid);
        // The left subtree only holds d[k]<=u's and the right only d[k]>=u's.
        if(mi.d[k]<=u->pid.d[k])range(u->l,(k+1)%kd,mi,ma);
        if(ma.d[k]>=u->pid.d[k])range(u->r,(k+1)%kd,mi,ma);
    }
public:
    // INF must exceed every possible distance; a is the balance factor alpha.
    kd_tree(const T &INF,double a=0.75):root(0),alpha(a),loga(std::log2(1.0/a)),INF(INF),maxn(1){}
    // Frees all nodes and empties the tree.
    inline void clear(){
        clear(root),root=0,maxn=1;
    }
    // Replaces the contents with the n points p[0..n-1].
    inline void build(int n,const point *p){
        clear(root),A.resize(maxn=n);
        for(int i=0;i<n;++i)A[i]=new node(p[i]);
        root=build(0,0,n-1);
    }
    inline void insert(const point &x){
        // Depth budget floor(log2 n)/log2(1/alpha) ~ log_{1/alpha}(n). An empty
        // tree gets budget 0 instead of taking the logarithm of 0.
        const int n=size(root);
        insert(root,0,x,n>0?(int)(floor_log2(n)/loga):0);
        if(root->s>maxn)maxn=root->s;
    }
    // Erases one occurrence of p; returns whether it was present.
    inline bool erase(const point &p){
        bool d=erase(root,0,p);
        if(root&&root->s<alpha*maxn)rebuild();
        return d;
    }
    // Rebuilds the whole tree into perfect balance and resets maxn to its size;
    // a no-op on an empty tree.
    inline void rebuild(){
        if(!root)return;
        rebuild(root,0);
        maxn=root->s;
    }
    // Returns the distance from x to its k-th nearest point (the farthest point
    // if the tree holds fewer than k points), or INF if the tree is empty (or
    // k==0), instead of reading the top of an empty heap.
    inline T nearest(const point &x,int k){
        qM=k;
        T mndist=INF,h[kd]={};
        nearest(root,0,x,h,mndist);
        mndist=pQ.empty()?INF:pQ.top().first;
        pQ=std::priority_queue<std::pair<T,point> >();
        return mndist;
    }
    // Returns the points p with mi.d[i]<=p.d[i]<=ma.d[i] on every axis. The
    // reference stays valid until the next range() call.
    inline const std::vector<point> &range(const point&mi,const point&ma){
        in_range.clear();
        range(root,0,mi,ma);
        return in_range;
    }
    inline int size(){return root?root->s:0;}
};
#endif
回覆刪除卦長想看你寫的二元搜尋樹修改的code.
複雜度要跟插入刪除一樣不然就沒意義了