题目
给一棵树,每条边有权.求一条简单路径,权值和等于K,且边的数量最小.N <= 200000, K <= 1000000
分析
开一个100W的数组t,t[i]表示权值为i的路径最少边数
找到重心分成若干子树后, 得出一棵子树的所有点到根的权值和x,到根a条边,用t[k-x]+a更新答案,全部查询完后,再用所有a更新t[x],这样可以保证不出现点分治中的不合法情况。
把一棵树的所有子树搞完后再遍历所有子树恢复T数组,如果用memset应该会比较慢
然后我初始化打错了。。。
code
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<queue>
#include<string>
#include<cmath>
using namespace std;
const int maxn=200000+10;
const int INF=1000000000;
struct arr{
int x,y;
int w;
int next;
}edge[maxn*2];
int ls[maxn];
int edge_m;
int t[1000010];
int ans;
int n,k;
bool done[maxn];
void add(int x,int y,int w)
{
edge[++edge_m]=(arr){x,y,w,ls[x]},ls[x]=edge_m;
edge[++edge_m]=(arr){y,x,w,ls[y]},ls[y]=edge_m;
}
int sz[maxn];
int f[maxn];
int rt,size;
void getrt(int x,int fa)
{
sz[x]=1;f[x]=0;
for(int i=ls[x];i;i=edge[i].next)
{
int u=edge[i].y;
if(u==fa||done[u]) continue;
getrt(u,x);
sz[x]+=sz[u];
f[x]=max(f[x],sz[u]);
}
f[x]=max(f[x],size-sz[x]);
if(f[x]<f[rt]) rt=x;
}
int deep[maxn];
int dis[maxn];
void dfs(int x,int r)
{
if (dis[x]<=k) ans=min(ans,deep[x]+t[k-dis[x]]);
for (int i=ls[x];i;i=edge[i].next)
{
if (done[edge[i].y]) continue;
if (edge[i].y==r) continue;
dis[edge[i].y]=dis[x]+edge[i].w;
deep[edge[i].y]=deep[x]+1;
dfs(edge[i].y,x);
}
}
void clean(int x,int r,int flag)
{
if (dis[x]<=k)
{
if (flag) t[dis[x]]=min(t[dis[x]],deep[x]);
else t[dis[x]]=INF;
}
for (int i=ls[x];i;i=edge[i].next)
{
if (edge[i].y==r) continue;
if (done[edge[i].y]) continue;
clean(edge[i].y,x,flag);
}
}
void cale(int x)
{
for (int i=ls[x];i;i=edge[i].next)
{
if (!done[edge[i].y])
{
deep[edge[i].y]=1;
dis[edge[i].y]=edge[i].w;
dfs(edge[i].y,0);
clean(edge[i].y,x,1);
}
}
for (int i=ls[x];i;i=edge[i].next)
{
if (done[edge[i].y]) continue;
clean(edge[i].y,0,0);
}
}
int work(int x)
{
done[x]=1;
t[0]=0;
cale(x);
for (int i=ls[x];i;i=edge[i].next)
{
if (done[edge[i].y]) continue;
f[0]=size=sz[edge[i].y];
getrt(edge[i].y,rt=0);
work(rt);
}
}
int main()
{
scanf("%d%d",&n,&k);
for (int i=1;i<=k;i++) t[i]=n;
memset(ls,0,sizeof(ls));
memset(edge,0,sizeof(edge));
memset(done,0,sizeof(done));
memset(sz,0,sizeof(sz));
memset(f,0,sizeof(f));
edge_m=0;
for (int i=1;i<n;i++)
{
int x,y,w;
scanf("%d%d%d",&x,&y,&w);
x++; y++;
add(x,y,w);
}
f[0]=ans=size=n;
getrt(1,rt=0);
work(rt);
if (ans==n) printf("-1");
else printf("%d\n",ans);
}