链接
题解
第一种情况,$b$ 是 $a$ 的祖先,这种情况对答案的贡献是 $\min(k,\ depth_a-1) \times (size_a-1)$,其中 $size$ 是子树大小($a$ 共有 $depth_a-1$ 个祖先,其中与 $a$ 距离不超过 $k$ 的有 $\min(k,\ depth_a-1)$ 个;$c$ 可以取 $a$ 子树中除 $a$ 以外的任意点)。
第二种情况,$b$ 是 $a$ 的后代。枚举 $b$,使得 $b$ 的深度与 $a$ 的深度相差不超过 $k$;此时 $c$ 可以取 $b$ 子树中除 $b$ 以外的任意点,所以这种情况对答案的贡献为

$$\sum_{b} (size_b-1)\,[depth_a < depth_b \le depth_a + k]$$
可以发现这个式子只和深度有关,可以直接长链剖分,一边 $dp$ 一边求:设 $f_u[i]$ 为 $u$ 子树中深度为 $depth_u+i$ 的所有点的 $size-1$ 之和。除了普通的 $dp$ 数组 $f$ 之外,再顺便维护它的后缀和,就可以保证每次查询是 $O(1)$ 的。
代码
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
// Shorthand macros and type aliases (competitive-programming template).
// Array bounds: maxn sizes the vertex-indexed and query-indexed arrays and the DP pool;
// maxe sizes Graph's edge arrays. main() stores each undirected edge twice, so maxe
// must be at least 2*(n-1).
#define iinf 0x3f3f3f3f
#define linf (1ll<<60)
#define eps 1e-8
#define maxn 1000010
#define maxe 1000010
#define cl(x) memset(x,0,sizeof(x))
// Loop helpers: rep runs i = a..b upward, drep runs i = a..b downward (both inclusive).
// The loop variable must be declared by the caller.
#define rep(i,a,b) for(i=a;i<=b;i++)
#define drep(i,a,b) for(i=a;i>=b;i--)
#define em(x) emplace(x)
#define emb(x) emplace_back(x)
#define emf(x) emplace_front(x)
#define fi first
#define se second
// Debug print of an expression and its value to stderr.
#define de(x) cerr<<#x<<" = "<<x<<endl
// NOTE(review): iinf, linf, eps, cl, em, emf, fi, se, de, pii, pll and the pb_ds
// namespace are not referenced by the code in this file.
using namespace std;
using namespace __gnu_pbds;
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
// Fast reader: parse one decimal integer from stdin.
// Any non-digit characters before the number are skipped; every '-' seen while
// skipping flips the sign. x is the initial accumulator (0 by default).
// Returns the signed value. If input ends before any digit is found, it returns
// the current sign times x (0 for the default x) instead of looping forever on EOF.
// (Uses `long long`, which is exactly the file's `ll` typedef.)
long long read(long long x=0)
{
    int c;              // int, not char/ll: getchar() returns int so EOF is distinguishable
    long long f(1);
    for (c = std::getchar(); !std::isdigit(c); c = std::getchar())
    {
        if (c == EOF) return f * x;   // truncated/exhausted input: stop instead of spinning
        if (c == '-') f = -f;
    }
    for (; std::isdigit(c); c = std::getchar()) x = x * 10 + (c - '0');
    return f * x;
}
// Static adjacency-list graph: edges are numbered from 1 and chained per source
// vertex through next[]; edge id 0 terminates a list.
struct Graph
{
    int etot;        // number of edges added so far (also the id of the newest edge)
    int head[maxn];  // head[v]: id of the most recently added edge out of v, 0 if none
    int to[maxe];    // to[e]: target vertex of edge e
    int next[maxe];  // next[e]: next edge id in the same source list, 0 at the end
    int w[maxe];     // w[e]: weight of edge e

    // Forget the edges of vertices 1..N and restart edge numbering from 1.
    void clear(int N)
    {
        etot = 0;
        for (int v = N; v >= 1; --v) head[v] = 0;
    }

    // Add the directed edge a -> b with weight c, prepended to a's list.
    void adde(int a, int b, int c=0)
    {
        const int id = ++etot;
        to[id] = b;
        w[id] = c;
        next[id] = head[a];
        head[a] = id;
    }

    // forp(u, g): iterate over the edge ids p leaving vertex u in graph g
    // (g.to[p] is the neighbour).
#define forp(_,__) for(auto p=__.head[_];p;p=__.next[p])
}G;
struct Longest_Chain_Decomposition
{
ll tot, len[maxn], son[maxn], depth[maxn], istop[maxn];
void dfs(Graph& G, ll u, ll fa)
{
son[u]=0;
len[u]=1;
depth[u]=depth[fa]+1;
istop[u]=false;
forp(u,G)
{
ll v(G.to[p]); if(v==fa)continue;
dfs(G,v,u);
if(len[v]+1>len[u])len[u]=len[v]+1, son[u]=v;
}
forp(u,G)
{
ll v(G.to[p]); if(v==fa or v==son[u])continue;
istop[v]=true;
}
}
void run(Graph& G, ll root)
{
tot=0;
depth[0]=0;
dfs(G,root,0);
istop[root]=true;
}
}lcd;
// Long-path DP storage used by dfs(). A chain head u takes 2*lcd.len[u] slots from
// pool: lcd.len[u] for f[u] and lcd.len[u] for s[u]. Other vertices alias into their
// chain's slots, shifted by one per step down the chain.
// NOTE: chain lengths sum to n, so pool needs 2*n entries (n <= maxn/2).
ll *f[maxn];            // f[u][i]: sum of (sz[x]-1) over x in subtree(u) with depth[x] == depth[u]+i
ll *s[maxn];            // s[u][i]: suffix sum f[u][i] + f[u][i+1] + ...
ll tot;                 // next free index in pool
ll pool[maxn];          // zero-initialised backing storage for f and s
ll sz[maxn];            // sz[u]: subtree size of u
ll p[maxn];             // p[id]: vertex of query id
ll k[maxn];             // k[id]: distance bound of query id
ll ans[maxn];           // ans[id]: answer of query id
vector<ll> qlis[maxn];  // qlis[u]: ids of the queries asked at vertex u
// Long-path DP over the subtree of u (parent fa); also answers every query attached to u.
// Maintained arrays (see lcd for len/son/istop):
//   f[u][i] = sum of (sz[y]-1) over vertices y in subtree(u) with depth[y] == depth[u]+i
//   s[u][i] = f[u][i] + f[u][i+1] + ... + f[u][lcd.len[u]-1]   (suffix sums)
// A chain head takes fresh slots from the global (zero-initialised) pool. Its heavy child
// reuses the same arrays shifted by one (f[son] = f[u] + 1), so the heavy child's results
// become u's entries at offsets 1.. for free. Each light child v is merged in O(len[v]).
// Queries at u must be answered before returning: once u returns, its ancestors merge
// their light children into the same shared arrays and overwrite u's entries.
// NOTE(review): recursion depth equals tree height; very deep trees may need a larger stack.
void dfs(ll u, ll fa)
{
ll i;
sz[u]=1;
// Chain head: allocate len[u] slots for f[u], then len[u] slots for s[u].
if(lcd.istop[u])
{
f[u] = pool + tot;
tot += lcd.len[u];
s[u] = pool + tot;
tot += lcd.len[u];
}
// Heavy child first: it shares u's arrays, shifted one level deeper.
if(lcd.son[u])
{
f[lcd.son[u]] = f[u] + 1;
s[lcd.son[u]] = s[u] + 1;
dfs(lcd.son[u],u);
sz[u]+=sz[lcd.son[u]];
}
// Light children: solve each one in its own arrays, then add them into u's arrays.
forp(u,G)
{
ll v(G.to[p]); if(v==fa or v==lcd.son[u])continue;
dfs(v,u);
sz[u] += sz[v];
// v's depth offset i is u's depth offset i+1.
rep(i,0,lcd.len[v]-1)f[u][i+1]+=f[v][i];
// Only f[u][1..len[v]] changed, so recompute suffix sums from index len[v] down to 0.
// len[v]+1 <= len[u] always holds, so every index used here is in range.
// s[u][0] is provisional here (f[u][0] is not set yet); it is recomputed below.
drep(i,lcd.len[v],0)
{
if(i+1<lcd.len[u])s[u][i]=s[u][i+1]+f[u][i];
else s[u][i]=f[u][i];
}
}
// Depth offset 0 is u itself: it contributes sz[u]-1.
f[u][0]=sz[u]-1;
if(lcd.len[u]>1)s[u][0]=s[u][1]+f[u][0];
else s[u][0]=f[u][0];
// Answer = (ancestor case) min(K, depth[u]-1) * (sz[u]-1)
//        + (descendant case) f[u][1] + ... + f[u][K], i.e. s[u][1] - s[u][K+1],
//          where the K+1 term is dropped when it runs past the end of the arrays.
for(auto id:qlis[u])
{
auto K = k[id];
ans[id] = min(K,lcd.depth[u]-1) * (sz[u]-1);
if(lcd.len[u]>1)
{
if(K+1<lcd.len[u])ans[id] += s[u][1] - s[u][K+1];
else ans[id] += s[u][1];
}
}
}
int main()
{
ll i, u, v, n, q;
n = read(), q = read();
rep(i,1,n-1)
{
u = read(), v = read();
G.adde(u,v), G.adde(v,u);
}
lcd.run(G,1);
rep(i,1,q)
{
p[i]=read(), k[i]=read();
qlis[p[i]].emb(i);
}
dfs(1,0);
rep(i,1,q)
{
printf("%lld\n",ans[i]);
}
return 0;
}