题目大意:给定一个值S和一棵树。在树的每个节点有一个正整数,问有多少条路径的节点总和达到S。路径中节点的深度必须是升序的。假设节点1是根节点,根的深度是0,它的儿子节点的深度为1。路径不必一定从根节点开始。
题解:倍增+二分O(nlog^2n),(Orz hzwer神奇做法)
我的收获:倍增维护其他信息还是比较字词的,Orz set
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
const int M=100005;
int n,s,t;
int head[M],w[M],f[M][22],l[M][22];
struct edge{
int to,nex;
}e[M*2];
int read(){
int x=0,f=1;char c=getchar();
while(c>'9'||c<'0') {if(c=='-') f=-1; c=getchar();}
while(c>='0'&&c<='9') x=x*10+c-48,c=getchar();
return x*f;
}
void add(int i,int j){
e[t].to=j;
e[t].nex=head[i];
head[i]=t++;
}
void dfs(int x){
for(int i=head[x];i!=-1;i=e[i].nex){
int v=e[i].to;
if(v!=f[x][0]){
f[v][0]=x;
l[v][0]=w[x];
dfs(v);
}
}
}
void st()
{
for(int j=1;j<=20;j++)
for(int i=1;i<=n;i++)
f[i][j]=f[f[i][j-1]][j-1],
l[i][j]=l[i][j-1]+l[f[i][j-1]][j-1];
}
int get_path(int x,int dep)
{
int ret=0;
for(int i=18;i>=0;i--)
{
if(dep>=(1<<i)){
dep-=(1<<i);
ret+=l[x][i];
x=f[x][i];
}
}
return ret;
}
bool solve(int x)
{
int l=0,r=10000,mid;
while(l<=r)
{
mid=(l+r)>>1;
int len=get_path(x,mid)+w[x];
if(len>s) r=mid-1;
else if(len<s) l=mid+1;
else return 1;
}
return 0;
}
void work()
{
int tot=0;
dfs(1);
st();
for(int i=1;i<=n;i++)
tot+=solve(i);
cout<<tot<<endl;
}
void init()
{
int x,y;
memset(head,-1,sizeof(head));
cin>>n>>s;
for(int i=1;i<=n;i++)
w[i]=read();
for(int i=1;i<n;i++){
x=read();y=read();
add(x,y),add(y,x);
}
}
int main()
{
init();
work();
return 0;
}
hzwer学长的set做法(Orz)
#include <iostream>
#include <cstdio>
#include <cstring>
#include <set>
#include <algorithm>
#define inf 2000000000
using namespace std;
const int M=100005;
int n,s,cnt,ans,t;
int w[M],head[M],sum[M];
multiset<int> st;
struct edge{
int to,nex;
}e[M*2];
int read()
{
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
void add(int i,int j){
e[t].to=j;
e[t].nex=head[i];
head[i]=t++;
}
void dfs(int x,int fa)
{
if(st.find(sum[x]-s)!=st.end()) ans++;
st.insert(sum[x]);
for(int i=head[x];i!=-1;i=e[i].nex)
{
int v=e[i].to;
if(v!=fa){
sum[v]=sum[x]+w[v];
dfs(v,x);
}
}
st.erase(sum[x]);
}
void work()
{
dfs(1,0);
printf("%d\n",ans);
}
void init()
{
int u,v;
memset(head,-1,sizeof(head));
cin>>n>>s;
st.insert(0);//Orz
for(int i=1;i<=n;i++)
w[i]=read();
for(int i=1;i<n;i++)
{
u=read(),v=read();
add(u,v),add(v,u);
}
sum[1]=w[1];
}
int main()
{
init();
work();
return 0;
}