Processing math: 100%
\newcommand{\ord}[1]{\mathcal{O}\left(#1\right)} \newcommand{\abs}[1]{\lvert #1 \rvert} \newcommand{\floor}[1]{\lfloor #1 \rfloor} \newcommand{\ceil}[1]{\lceil #1 \rceil} \newcommand{\opord}{\operatorname{\mathcal{O}}} \newcommand{\argmax}{\operatorname{arg\,max}} \newcommand{\str}[1]{\texttt{"#1"}}

2016年5月19日 星期四

[ Kuhn-Munkres Algorithm ] 二分圖最大權完美匹配KM算法

關於此算法的介紹及教學請看這份投影片,這篇文章注重於討論\ord{N^4}的slack優化及\ord{N^4}\ord{N^3}的過程。

在網路上曾經看到這段話:

"實際上KM算法的複雜度是可以做到\ord{N^3}的。我們給每個y頂點一個“鬆弛量”函數slack,
每次開始找增廣路時初始化為無窮大。在尋找增廣路的過程中,檢查邊(i,j)時,如果它不在相等子圖中,則讓slack[j]變成原值與lx[i]+ly[j]-w[i,j]的較小值。這樣,在修改頂標時,取所有不在交錯樹中的Y頂點的slack值中的最小值作為d值即可。但還要注意一點:修 改頂標後,要把所有不在交錯樹中的Y頂點的slack值都減去d。"

這段話其實是錯的,請看底下的code:

\ord{N^4}常數優化版本
#define MAXN 100
#define INF INT_MAX
int n;
int g[MAXN][MAXN],lx[MAXN],ly[MAXN],slack_y[MAXN];
int match_y[MAXN];
bool vx[MAXN],vy[MAXN];
bool dfs(int x){
if(vx[x])return 0;
vx[x]=1;
for(int y=0,t;y<n;++y){
if(vy[y])continue;
t=lx[x]+ly[y]-g[x][y];
if(t==0){
vy[y]=1;
if(match_y[y]==-1||dfs(match_y[y])){
match_y[y]=x;
return 1;
}
}else if(slack_y[y]>t)slack_y[y]=t;
}
return 0;
}
inline int km(){
memset(ly,0,sizeof(int)*n);
memset(match_y,-1,sizeof(int)*n);
for(int x=0;x<n;++x){
lx[x]=-INF;
for(int y=0;y<n;++y){
lx[x]=max(lx[x],g[x][y]);
}
}
for(int x=0;x<n;++x){
for(int y=0;y<n;++y)slack_y[y]=INF;
for(;;){
memset(vx,0,sizeof(bool)*n);
memset(vy,0,sizeof(bool)*n);
if(dfs(x))break;
int cut=INF;
for(int y=0;y<n;++y){
if(!vy[y]&&cut>slack_y[y])cut=slack_y[y];
}
for(int j=0;j<n;++j){
if(vx[j])lx[j]-=cut;
if(vy[j])ly[j]+=cut;
else slack_y[j]-=cut;
}
}
}
int ans=0;
for(int y=0;y<n;++y)if(g[match_y[y]][y]!=-INF)ans+=g[match_y[y]][y];
return ans;
}
view raw KM_O(N4).cpp hosted with ❤ by GitHub
我們仔細看一下迴圈的層數。
第一層是一個for迴圈,執行N次;第二層有一個無線迴圈,但她最多執行N次;裡面有一個DFS,最多執行\ord{N^2}次,還有一些迴圈但是不影響複雜度所以不討論。
總複雜度是\ord{N}*\ord{N}*\ord{N^2}\ord{N^4},因此網路上大多數\ord{N^3}的code其實是錯的。

接下來看一下我修改後真正的\ord{N^3}code:
#define MAXN 100
#define INF INT_MAX
int n;
int g[MAXN][MAXN],lx[MAXN],ly[MAXN],slack_y[MAXN];
int match_y[MAXN];//要保證g是完全二分圖
bool vx[MAXN],vy[MAXN];
bool dfs(int x,bool adjust=1){//DFS找增廣路,is=1表示要擴充
if(vx[x])return 0;
vx[x]=1;
for(int y=0;y<n;++y){
if(vy[y])continue;
int t=lx[x]+ly[y]-g[x][y];
if(t==0){
vy[y]=1;
if(match_y[y]==-1||dfs(match_y[y],adjust)){
if(adjust)match_y[y]=x;
return 1;
}
}else if(slack_y[y]>t)slack_y[y]=t;
}
return 0;
}
inline int km(){
memset(match_y,-1,sizeof(int)*n);
memset(ly,0,sizeof(int)*n);
for(int x=0;x<n;++x){
lx[x]=-INF;
for(int y=0;y<n;++y){
lx[x]=max(lx[x],g[x][y]);
}
}
for(int x=0;x<n;++x){
for(int y=0;y<n;++y)slack_y[y]=INF;
memset(vx,0,sizeof(bool)*n);
memset(vy,0,sizeof(bool)*n);
if(dfs(x))continue;
bool flag=1;
while(flag){
int cut=INF;
for(int y=0;y<n;++y){
if(!vy[y]&&cut>slack_y[y])cut=slack_y[y];
}
for(int j=0;j<n;++j){
if(vx[j])lx[j]-=cut;
if(vy[j])ly[j]+=cut;
else slack_y[j]-=cut;
}
for(int y=0;y<n;++y){
if(!vy[y]&&slack_y[y]==0){
vy[y]=1;
if(match_y[y]==-1||dfs(match_y[y],0)){
flag=0;//測試成功,有增廣路
break;
}
}
}
}
memset(vx,0,sizeof(bool)*n);
memset(vy,0,sizeof(bool)*n);
dfs(x);//最後要記得擴充增廣路
}
int ans=0;
for(int y=0;y<n;++y)if(g[match_y[y]][y]!=-INF)ans+=g[match_y[y]][y];
return ans;
}
view raw KM_O(N3).cpp hosted with ❤ by GitHub
可以看到我在dfs裡面加了一個參數adjust,如果他是true,就跟原本的dfs沒兩樣,可以在找到增廣路後順便將增廣路擴充;如果他是false,整個dfs就變成只能判斷有沒有增廣路了。
主函數dfs的部分移到while的外面,而修改lx,ly和slack_y後的地方增加了一個迴圈來判斷這次的修改有沒有產生可行的增廣路,對於新增加的「等邊」,如果他有機會是增廣路的話,會對此「等邊」連像的點y進行dfs「判斷」有沒有增廣路(把adjust設成0)。如果有增廣路,就會跳出while,在進行一次dfs來擴充增廣路。
總複雜度是
\ord{N}[第一層迴圈] * ( \ord{N^2}[dfs的時間] + \ord{N}[while的次數]*\ord{N} ) = \ord{N^3}
這才是真正的\ord{N^3}算法,前者只是常數優化而已

最後是\ord{N^3}算法常數較小的版本:
#define MAXN 100
#define INF INT_MAX
int g[MAXN][MAXN],lx[MAXN],ly[MAXN],slack_y[MAXN];
int px[MAXN],py[MAXN],match_y[MAXN],par[MAXN];
int n;
void adjust(int y){//把增廣路上所有邊反轉
match_y[y]=py[y];
if(px[match_y[y]]!=-2)
adjust(px[match_y[y]]);
}
bool dfs(int x){//DFS找增廣路
for(int y=0;y<n;++y){
if(py[y]!=-1)continue;
int t=lx[x]+ly[y]-g[x][y];
if(t==0){
py[y]=x;
if(match_y[y]==-1){
adjust(y);
return 1;
}
if(px[match_y[y]]!=-1)continue;
px[match_y[y]]=y;
if(dfs(match_y[y]))return 1;
}else if(slack_y[y]>t){
slack_y[y]=t;
par[y]=x;
}
}
return 0;
}
inline int km(){
memset(ly,0,sizeof(int)*n);
memset(match_y,-1,sizeof(int)*n);
for(int x=0;x<n;++x){
lx[x]=-INF;
for(int y=0;y<n;++y){
lx[x]=max(lx[x],g[x][y]);
}
}
for(int x=0;x<n;++x){
for(int y=0;y<n;++y)slack_y[y]=INF;
memset(px,-1,sizeof(int)*n);
memset(py,-1,sizeof(int)*n);
px[x]=-2;
if(dfs(x))continue;
bool flag=1;
while(flag){
int cut=INF;
for(int y=0;y<n;++y)
if(py[y]==-1&&cut>slack_y[y])cut=slack_y[y];
for(int j=0;j<n;++j){
if(px[j]!=-1)lx[j]-=cut;
if(py[j]!=-1)ly[j]+=cut;
else slack_y[j]-=cut;
}
for(int y=0;y<n;++y){
if(py[y]==-1&&slack_y[y]==0){
py[y]=par[y];
if(match_y[y]==-1){
adjust(y);
flag=0;
break;
}
px[match_y[y]]=y;
if(dfs(match_y[y])){
flag=0;
break;
}
}
}
}
}
int ans=0;
for(int y=0;y<n;++y)if(g[match_y[y]][y]!=-INF)ans+=g[match_y[y]][y];
return ans;
}

沒有留言:

張貼留言