好恶心的一道最小生成树......(dalao勿喷,我觉得是最小生成树)
我居然写了140行,
真是感觉自己没救了。
题干中,并没有说滑雪一定是一条路走T个点,而是可以反向,所以说,一棵树也可视为一条路,然后就有了最小生成树。
从每个点相四周加边,并记录高度差,
然后做最小生成树,当一个起点在树中并且树的size达到了T的时候,那当前的边长就是等级了。
那么如何处理在某棵并查集中的起点呢?
从dalao哪里学了一招链表,然而TLE了。dalao说是单向的,然而我比较笨,只能用双向的。
然后我惊奇地发现,居然能在merge的时候维护链表(虽然因为这件事WA了几次,但是这是值得happy一下的),
链表只存起点(题目中给的点),
每一次merge,把两个链表相接,
每一次获得一个解,就把那一个元素去掉。(节省很大时间)。
然后相加。
还是感觉自己的代码很丑······。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<cmath>
#include<cstdlib>
#include<algorithm>
using namespace std;
int g[501][501];
struct node
{
int point;int next;
};
node list[500*500+666];
int first[500*500+666];
int last[500*500+666];
struct edge
{
int f,t;
long long dis;
};
int dx[5]={0,0,0,1,-1};
int dy[5]={0,1,-1,0,0};
int hh,tt;
long long ans;
edge e[500*500*4+666];
int p;
int fa[500*500+666];
int size[500*500+666];
bool yes[500*500+666];
int n,m;int k;
int listp;
int id[501][501];
void add(int a,int b,long long dis)
{
p++;
e[p].f=a;
e[p].t=b;
e[p].dis=dis;
}
int find(int a)
{
if(fa[a]==a)return a;
else return fa[a]=find(fa[a]);
}
inline int ra()
{
int x=0;char ch=getchar();int flag=1;
while(ch>'9'||ch<'0'){if(ch=='-')flag=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x*=10;x+=ch-'0';ch=getchar();}
return x*flag;
}
void merge(int a,int b)
{
int aa=find(a);
int bb=find(b);
fa[bb]=aa;
size[aa]+=size[bb];
size[bb]=0;
if(first[aa]==0&&first[bb]!=0)
{
first[aa]=first[bb];
last[aa]=last[bb];
first[bb]=0;
last[bb]=0;
}
else if(first[bb]!=0)
{
list[last[aa]].next=first[bb];
first[bb]=0;
last[aa]=last[bb];
last[bb]=0;
}
}
bool cmp(edge x,edge y)
{
return x.dis<y.dis;
}
int main()
{
listp=5;
n=ra();m=ra();k=ra();
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
g[i][j]=ra(),id[i][j]=i*500+j;
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
{
int flag=ra();
if(flag)yes[i*500+j]=1;
}
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
{
for(int l=1;l<=4;l++)
{
int nx=i+dx[l];int ny=j+dy[l];
if(nx<1||nx>n||ny>m||ny<1)continue;
add(id[i][j],id[nx][ny],abs(g[i][j]-g[nx][ny]));
}
}
hh=tt=0;
sort(e+1,e+p+1,cmp);
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
{
fa[i*500+j]=i*500+j;
size[id[i][j]]=1;
if(yes[id[i][j]])
{
listp++;
list[listp].point=id[i][j];
first[id[i][j]]=listp;
last[id[i][j]]=listp;
}
}
long long now=0;
for(int i=1;i<=p;i++)
{
int a=e[i].f;int b=e[i].t;
now=max(now,e[i].dis);
int aa=find(a);int bb=find(b);
if(aa!=bb)
{
merge(aa,bb);
if(size[aa]>=k&&(first[aa]!=0))
{
for(int j=first[aa];j;j=list[j].next)
if(yes[list[j].point])
{
ans+=now;
first[aa]=list[j].next;
yes[list[j].point]=0;
}
}
}
}
cout<<ans<<endl;
return 0;
}