这道题很像topcoder里的一道题kingdomtour,是它的弱化版,可以看我的那道题的博客,一个树型dp,复杂度 O(nk2) O ( n k 2 )
#include <cstdio>
#include <iostream>
#include <cstring>
#include <string>
#include <cstdlib>
#include <utility>
#include <cctype>
#include <algorithm>
#include <bitset>
#include <set>
#include <map>
#include <vector>
#include <queue>
#include <deque>
#include <stack>
#include <cmath>
#define LL long long
#define LB long double
#define x first
#define y second
#define Pair pair<int,int>
#define pb push_back
#define pf push_front
#define mp make_pair
#define LOWBIT(x) x & (-x)
using namespace std;
const int MOD=998244353;
const LL LINF=2e16;
const int INF=1e9;
const int magic=348;
const double eps=1e-10;
inline int getint()
{
char ch;int res;bool f;
while (!isdigit(ch=getchar()) && ch!='-') {}
if (ch=='-') f=false,res=0; else f=true,res=ch-'0';
while (isdigit(ch=getchar())) res=res*10+ch-'0';
return f?res:-res;
}
int n,K;
vector<int> v[100048];
int dp[100048][10],tmp[100048][10];
inline void dfs(int cur,int father)
{
int i,j,k,y,cc=0;
for (i=0;i<int(v[cur].size());i++)
{
y=v[cur][i];
if (y!=father) dfs(y,cur);
}
for (i=0;i<=K*2;i++) tmp[cc][i]=0;
for (i=0;i<int(v[cur].size());i++)
{
y=v[cur][i];
if (y!=father)
{
cc++;
for (j=0;j<=K*2;j++) tmp[cc][j]=INF;
for (j=0;j<=K*2;j++)
for (k=0;k<=j;k++)
tmp[cc][j]=min(tmp[cc][j],tmp[cc-1][j-k]+dp[y][k]+((k&1)?1:2));
}
}
for (i=0;i<=K*2;i++) dp[cur][i]=tmp[cc][i];
}
int main ()
{
int i,x,y;
n=getint();K=getint();
for (i=1;i<=n-1;i++)
{
x=getint();y=getint();
v[x].pb(y);v[y].pb(x);
}
dfs(1,-1);
printf("%d\n",dp[1][K*2]+K);
return 0;
}
上网看了一些题解,发现k比较小的时候,有一种找直径的算法非常巧妙
先考虑k=1的情况
我们发现,连接u,v两个点之后,u~v的路径上的点都只会被访问一次,所以要使得减少的边最多,我们应该找树的直径,这个比较显然
再考虑k=2的情况,我们发现,如果两条链有公共部分,那么这个公共部分的边还是要走两次的
进而我们发现,如果两条链有公共部分,那么一定可以在使答案不变坏的前提下换一种方案,使得两条链没有公共部分,大概是现在的第2条链的一半和第1条链的一半接起来形成新链之类的
于是我们这样做:
先对原树找一条直径
然后把直径上的边的边权从1改成-1
然后再对原树找一条直径
这两次的答案合起来就是最优方案
我们发现这个把1改成-1的操作很像网络流里面的反向边,选择这条边相当于把原来选的直径退掉
*注意当一棵树内的边有负权的时候,不能通过两次dfs找直径,得老老实实写个树型dp(其实反而更好写?雾)
#include <cstdio>
#include <iostream>
#include <cstring>
#include <string>
#include <cstdlib>
#include <utility>
#include <cctype>
#include <algorithm>
#include <bitset>
#include <set>
#include <map>
#include <vector>
#include <queue>
#include <deque>
#include <stack>
#include <cmath>
#define LL long long
#define LB long double
#define x first
#define y second
#define Pair pair<int,int>
#define pb push_back
#define pf push_front
#define mp make_pair
#define LOWBIT(x) x & (-x)
using namespace std;
const int MOD=998244353;
const LL LINF=2e16;
const int INF=1e9;
const int magic=348;
const double eps=1e-10;
inline int getint()
{
char ch;int res;bool f;
while (!isdigit(ch=getchar()) && ch!='-') {}
if (ch=='-') f=false,res=0; else f=true,res=ch-'0';
while (isdigit(ch=getchar())) res=res*10+ch-'0';
return f?res:-res;
}
int n,k;
vector<int> v[100048];
struct Edge
{
int x,y;
int len;
}edge[100048];
int ans=0;
int sum[100048],fa[100048],faind[100048];
inline int Get(int ind,int cur)
{
if (edge[ind].x==cur) return edge[ind].y; else return edge[ind].x;
}
inline void dfs(int cur,int father)
{
int i,y;
for (i=0;i<int(v[cur].size());i++)
{
y=Get(v[cur][i],cur);
if (y!=father)
{
sum[y]=sum[cur]+edge[v[cur][i]].len;
fa[y]=cur;faind[y]=v[cur][i];
dfs(y,cur);
}
}
}
inline int find_dia()
{
sum[1]=0;fa[1]=-1;dfs(1,-1);
int maxn=-INF,maxpos,i;
for (i=1;i<=n;i++)
if (sum[i]>maxn)
{
maxn=sum[i];
maxpos=i;
}
sum[maxpos]=0;fa[maxpos]=-1;dfs(maxpos,-1);
maxn=-INF;
for (i=1;i<=n;i++)
if (sum[i]>maxn)
{
maxn=sum[i];
maxpos=i;
}
return maxpos;
}
inline void update(int cur)
{
while (fa[cur]!=-1)
{
edge[faind[cur]].len=-1;
cur=fa[cur];
}
}
int res=0,dp[100048];
inline void Dfs(int cur,int father)
{
int i,y;dp[cur]=0;
for (i=0;i<int(v[cur].size());i++)
{
y=Get(v[cur][i],cur);
if (y!=father)
{
Dfs(y,cur);
res=max(res,dp[y]+edge[v[cur][i]].len+dp[cur]);
dp[cur]=max(dp[cur],dp[y]+edge[v[cur][i]].len);
}
}
}
int main ()
{
int i,x,y,ans=0;
n=getint();k=getint();
for (i=1;i<=n-1;i++)
{
x=getint();y=getint();
v[x].pb(i);v[y].pb(i);
edge[i]=Edge{x,y,1};
}
int ed=find_dia();
ans+=sum[ed];
if (k==1) {printf("%d\n",(n-1)*2-ans+1);return 0;}
update(ed);
Dfs(1,-1);ans+=res;
printf("%d\n",(n-1)*2-ans+2);
return 0;
}