在網路上曾經看到這段話:
"實際上KM算法的複雜度是可以做到\ord{N^3}的。我們給每個y頂點一個“鬆弛量”函數slack,
每次開始找增廣路時初始化為無窮大。在尋找增廣路的過程中,檢查邊(i,j)時,如果它不在相等子圖中,則讓slack[j]變成原值與lx[i]+ly[j]-w[i,j]的較小值。這樣,在修改頂標時,取所有不在交錯樹中的Y頂點的slack值中的最小值作為d值即可。但還要注意一點:修 改頂標後,要把所有不在交錯樹中的Y頂點的slack值都減去d。"
這段話其實是錯的,請看底下的code:
\ord{N^4}常數優化版本
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#define MAXN 100 | |
#define INF INT_MAX | |
int n; | |
int g[MAXN][MAXN],lx[MAXN],ly[MAXN],slack_y[MAXN]; | |
int match_y[MAXN]; | |
bool vx[MAXN],vy[MAXN]; | |
bool dfs(int x){ | |
if(vx[x])return 0; | |
vx[x]=1; | |
for(int y=0,t;y<n;++y){ | |
if(vy[y])continue; | |
t=lx[x]+ly[y]-g[x][y]; | |
if(t==0){ | |
vy[y]=1; | |
if(match_y[y]==-1||dfs(match_y[y])){ | |
match_y[y]=x; | |
return 1; | |
} | |
}else if(slack_y[y]>t)slack_y[y]=t; | |
} | |
return 0; | |
} | |
inline int km(){ | |
memset(ly,0,sizeof(int)*n); | |
memset(match_y,-1,sizeof(int)*n); | |
for(int x=0;x<n;++x){ | |
lx[x]=-INF; | |
for(int y=0;y<n;++y){ | |
lx[x]=max(lx[x],g[x][y]); | |
} | |
} | |
for(int x=0;x<n;++x){ | |
for(int y=0;y<n;++y)slack_y[y]=INF; | |
for(;;){ | |
memset(vx,0,sizeof(bool)*n); | |
memset(vy,0,sizeof(bool)*n); | |
if(dfs(x))break; | |
int cut=INF; | |
for(int y=0;y<n;++y){ | |
if(!vy[y]&&cut>slack_y[y])cut=slack_y[y]; | |
} | |
for(int j=0;j<n;++j){ | |
if(vx[j])lx[j]-=cut; | |
if(vy[j])ly[j]+=cut; | |
else slack_y[j]-=cut; | |
} | |
} | |
} | |
int ans=0; | |
for(int y=0;y<n;++y)if(g[match_y[y]][y]!=-INF)ans+=g[match_y[y]][y]; | |
return ans; | |
} |
第一層是一個for迴圈,執行N次;第二層有一個無線迴圈,但她最多執行N次;裡面有一個DFS,最多執行\ord{N^2}次,還有一些迴圈但是不影響複雜度所以不討論。
總複雜度是\ord{N}*\ord{N}*\ord{N^2}為\ord{N^4},因此網路上大多數\ord{N^3}的code其實是錯的。
接下來看一下我修改後真正的\ord{N^3}code:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#define MAXN 100 | |
#define INF INT_MAX | |
int n; | |
int g[MAXN][MAXN],lx[MAXN],ly[MAXN],slack_y[MAXN]; | |
int match_y[MAXN];//要保證g是完全二分圖 | |
bool vx[MAXN],vy[MAXN]; | |
bool dfs(int x,bool adjust=1){//DFS找增廣路,is=1表示要擴充 | |
if(vx[x])return 0; | |
vx[x]=1; | |
for(int y=0;y<n;++y){ | |
if(vy[y])continue; | |
int t=lx[x]+ly[y]-g[x][y]; | |
if(t==0){ | |
vy[y]=1; | |
if(match_y[y]==-1||dfs(match_y[y],adjust)){ | |
if(adjust)match_y[y]=x; | |
return 1; | |
} | |
}else if(slack_y[y]>t)slack_y[y]=t; | |
} | |
return 0; | |
} | |
inline int km(){ | |
memset(match_y,-1,sizeof(int)*n); | |
memset(ly,0,sizeof(int)*n); | |
for(int x=0;x<n;++x){ | |
lx[x]=-INF; | |
for(int y=0;y<n;++y){ | |
lx[x]=max(lx[x],g[x][y]); | |
} | |
} | |
for(int x=0;x<n;++x){ | |
for(int y=0;y<n;++y)slack_y[y]=INF; | |
memset(vx,0,sizeof(bool)*n); | |
memset(vy,0,sizeof(bool)*n); | |
if(dfs(x))continue; | |
bool flag=1; | |
while(flag){ | |
int cut=INF; | |
for(int y=0;y<n;++y){ | |
if(!vy[y]&&cut>slack_y[y])cut=slack_y[y]; | |
} | |
for(int j=0;j<n;++j){ | |
if(vx[j])lx[j]-=cut; | |
if(vy[j])ly[j]+=cut; | |
else slack_y[j]-=cut; | |
} | |
for(int y=0;y<n;++y){ | |
if(!vy[y]&&slack_y[y]==0){ | |
vy[y]=1; | |
if(match_y[y]==-1||dfs(match_y[y],0)){ | |
flag=0;//測試成功,有增廣路 | |
break; | |
} | |
} | |
} | |
} | |
memset(vx,0,sizeof(bool)*n); | |
memset(vy,0,sizeof(bool)*n); | |
dfs(x);//最後要記得擴充增廣路 | |
} | |
int ans=0; | |
for(int y=0;y<n;++y)if(g[match_y[y]][y]!=-INF)ans+=g[match_y[y]][y]; | |
return ans; | |
} |
主函數dfs的部分移到while的外面,而修改lx,ly和slack_y後的地方增加了一個迴圈來判斷這次的修改有沒有產生可行的增廣路,對於新增加的「等邊」,如果他有機會是增廣路的話,會對此「等邊」連像的點y進行dfs「判斷」有沒有增廣路(把adjust設成0)。如果有增廣路,就會跳出while,在進行一次dfs來擴充增廣路。
總複雜度是
\ord{N}[第一層迴圈] * ( \ord{N^2}[dfs的時間] + \ord{N}[while的次數]*\ord{N} ) = \ord{N^3}
這才是真正的\ord{N^3}算法,前者只是常數優化而已
最後是\ord{N^3}算法常數較小的版本:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#define MAXN 100 | |
#define INF INT_MAX | |
int g[MAXN][MAXN],lx[MAXN],ly[MAXN],slack_y[MAXN]; | |
int px[MAXN],py[MAXN],match_y[MAXN],par[MAXN]; | |
int n; | |
void adjust(int y){//把增廣路上所有邊反轉 | |
match_y[y]=py[y]; | |
if(px[match_y[y]]!=-2) | |
adjust(px[match_y[y]]); | |
} | |
bool dfs(int x){//DFS找增廣路 | |
for(int y=0;y<n;++y){ | |
if(py[y]!=-1)continue; | |
int t=lx[x]+ly[y]-g[x][y]; | |
if(t==0){ | |
py[y]=x; | |
if(match_y[y]==-1){ | |
adjust(y); | |
return 1; | |
} | |
if(px[match_y[y]]!=-1)continue; | |
px[match_y[y]]=y; | |
if(dfs(match_y[y]))return 1; | |
}else if(slack_y[y]>t){ | |
slack_y[y]=t; | |
par[y]=x; | |
} | |
} | |
return 0; | |
} | |
inline int km(){ | |
memset(ly,0,sizeof(int)*n); | |
memset(match_y,-1,sizeof(int)*n); | |
for(int x=0;x<n;++x){ | |
lx[x]=-INF; | |
for(int y=0;y<n;++y){ | |
lx[x]=max(lx[x],g[x][y]); | |
} | |
} | |
for(int x=0;x<n;++x){ | |
for(int y=0;y<n;++y)slack_y[y]=INF; | |
memset(px,-1,sizeof(int)*n); | |
memset(py,-1,sizeof(int)*n); | |
px[x]=-2; | |
if(dfs(x))continue; | |
bool flag=1; | |
while(flag){ | |
int cut=INF; | |
for(int y=0;y<n;++y) | |
if(py[y]==-1&&cut>slack_y[y])cut=slack_y[y]; | |
for(int j=0;j<n;++j){ | |
if(px[j]!=-1)lx[j]-=cut; | |
if(py[j]!=-1)ly[j]+=cut; | |
else slack_y[j]-=cut; | |
} | |
for(int y=0;y<n;++y){ | |
if(py[y]==-1&&slack_y[y]==0){ | |
py[y]=par[y]; | |
if(match_y[y]==-1){ | |
adjust(y); | |
flag=0; | |
break; | |
} | |
px[match_y[y]]=y; | |
if(dfs(match_y[y])){ | |
flag=0; | |
break; | |
} | |
} | |
} | |
} | |
} | |
int ans=0; | |
for(int y=0;y<n;++y)if(g[match_y[y]][y]!=-INF)ans+=g[match_y[y]][y]; | |
return ans; | |
} |
沒有留言:
張貼留言