1912: [Apio2010]patrol 巡逻
Time Limit: 4 Sec Memory Limit: 64 MBSubmit: 1806 Solved: 937
[ Submit][ Status][ Discuss]
Description
Input
第一行包含两个整数 n, K(1 ≤ K ≤ 2)。接下来 n – 1行,每行两个整数 a, b, 表示村庄a与b之间有一条道路(1 ≤ a, b ≤ n)。
Output
输出一个整数,表示新建了K 条道路后能达到的最小巡逻距离。
Sample Input
8 1
1 2
3 1
3 4
5 3
7 5
8 5
5 6
1 2
3 1
3 4
5 3
7 5
8 5
5 6
Sample Output
11
HINT
10%的数据中,n ≤ 1000, K = 1;
30%的数据中,K = 1;
80%的数据中,每个村庄相邻的村庄数不超过 25;
90%的数据中,每个村庄相邻的村庄数不超过 150;
100%的数据中,3 ≤ n ≤ 100,000, 1 ≤ K ≤ 2。
解析:
由题意可知,在不连接任意边的情况下巡逻距离为2*(n-1)。
当k==1时,连接树的直径的两个端点即可,此时答案为2*(n-1)-len+1,len为直径大小。
当k==2时,求完直径len1后,把直径上的边权由1改为为-1,再求一次直径len2,此时答案为2*(n-1)-(len1-1)-(len2-1)。
代码:
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <cctype>
#include <queue>
using namespace std;
const int Max=101000;
int n,k,ans,size=1,p,len1,len2;
int dis[Max],first[Max];
int pre[Max],v[Max],d[Max];
struct shu{int next,to,len;};
shu bian[Max*2];
inline int get_int()
{
int x=0,f=1;
char c;
for(c=getchar();(!isdigit(c))&&(c!='-');c=getchar());
if(c=='-') {f=-1;c=getchar();}
for(;isdigit(c);c=getchar()) x=(x<<3)+(x<<1)+c-'0';
return x*f;
}
inline void build(int x,int y)
{
size++;
bian[size].next=first[x];
first[x]=size;
bian[size].to=y;
bian[size].len=1;
}
inline int bfs(int s)
{
queue<int>q;q.push(s);
memset(dis,0x3f,sizeof(dis));
memset(pre,0,sizeof(pre));
dis[s]=0;
while(!q.empty())
{
int point=q.front();
q.pop();
for(int u=first[point];u;u=bian[u].next)
{
if(dis[bian[u].to]==0x3f3f3f3f)
{
dis[bian[u].to]=dis[point]+bian[u].len;
pre[bian[u].to]=u;
q.push(bian[u].to);
}
}
}
int x=1;
for(int i=1;i<=n;i++) if(dis[i]>dis[x]) x=i;
return x;
}
inline int calc()
{
p=bfs(1);
p=bfs(p);
return dis[p];
}
inline void change()
{
for(;pre[p];p=bian[pre[p]^1].to) bian[pre[p]].len=bian[pre[p]^1].len=-1;
}
inline void dp(int point)
{
v[point]=1;
for(int u=first[point];u;u=bian[u].next)
if(!v[bian[u].to])
{
dp(bian[u].to);
len2=max(len2,d[point]+d[bian[u].to]+bian[u].len);
d[point]=max(d[point],d[bian[u].to]+bian[u].len);
}
}
int main()
{
// freopen("lx.in","r",stdin);
//freopen("lx.out","w",stdout);
n=get_int();
k=get_int();
for(int i=1;i<=n-1;i++)
{
int x=get_int(),y=get_int();
build(x,y);
build(y,x);
}
len1=calc();
if(k==2)
{
change();
dp(1);
cout<<2*(n-1)-(len1-1)-(len2-1)<<"\n";
}
else cout<<2*(n-1)-len1+1<<"\n";
return 0;
}