Description
Input
Output
Sample Input
Sample Output
Data Constraint
分析
第2问其实没有什么特殊的地方,所以这里只讲第1问的做法,然后第2问的思路也是一样的。
对于这种环套树的问题,可以考虑先删去环上的一条边,这样它就变成了一棵树,然后统计树上的答案,最后加上必须通过删去的边的答案。
树上的做法
这类问题很容易联想到点剖。
对于当前一个根为x的子树,用s[i]表示深度为i的节点的个数,初始为0。然后枚举x的每一个子树,然后bfs一下,统计答案,遍历完这个子树后再加到s数组里。
环上的做法
设环上的点为p1,p2……pn,其中(pn,p1)是上面删去的那条边。对于两个点(u,v),其中u是pi的子节点,v是pj的子节点。设dep[i]为节点i到环上任意一点的最短距离。如图所示:
那么它能被统计入答案,当且仅当满足所有下列条件:
1. dep[u]+dep[v]+i+n-j≤k //经过删去的边
2. dep[u]+dep[v]+(j-i)>k //不经过删去的边
3. i< j
但是可以发现,当i≥j时,j-i≤0,然后无论何时此时都满足i+n-j>0。此时如果满足第2个条件,那么第一个条件就一定不满足。所以可以无视掉第3个条件。
然后设x[u]=dep[u]+i(u为i的子节点),y[u]=dep[u]-i,那么条件变成:
1. x[u]≤k-n-y[v]
2. y[u]>k-x[v]
统计答案就可以用主席树来完成了。
注意细节
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int maxn=100005,maxm=200005;
typedef long long LL;
int n,m,tot,id,root,Size,V,N,s[maxn],h[maxn],e[maxm],next[maxm],v[maxn],data[maxn],size[maxn],dep[maxn],E[maxn],ID[maxm];
int fa[maxn],D[maxn],Root[maxm],L[maxm*20],R[maxm*20],s1[maxm*20];
LL ans1,ans2,t[maxn],s2[maxm*20];
bool visit[maxn],bz[maxn];
struct op
{
int X,Y,e;
}A[maxm];
bool cmp(op a,op b)
{
return a.X<b.X;
}
void add(int x,int y,int id)
{
e[++tot]=y; ID[tot]=id; next[tot]=h[x]; h[x]=tot;
}
void find(int x)
{
visit[x]=1;
for (int i=h[x];i;i=next[i])
if (!visit[e[i]])
{
bz[ID[i]]=1;
find(e[i]);
}
}
void calc(int x,int y)
{
size[x]=1;
int M=0;
for (int i=h[x];i;i=next[i]) if (ID[i]!=id && e[i]!=y && !visit[e[i]])
{
calc(e[i],x);
size[x]+=size[e[i]];
M=max(M,size[e[i]]);
}
M=max(M,Size-size[x]);
if (M<V)
{
V=M; root=x;
}
}
void get_dep(int x,int y)
{
for (int i=h[x];i;i=next[i]) if (ID[i]!=id && e[i]!=y && !visit[e[i]])
{
dep[e[i]]=dep[x]+1;
get_dep(e[i],x);
}
}
void dfs(int x)
{
V=Size;
calc(x,0);
int R=root,i,j,k,cnt,Cnt=0;
LL Sum,SS=0;
visit[R]=1;
dep[R]=0;
get_dep(R,0);
for (i=h[R];i;i=next[i]) if (ID[i]!=id && !visit[e[i]])
{
data[tot=1]=e[i];
cnt=Cnt; Sum=SS;
for (j=1;j<=tot;j++)
{
if (dep[data[j]]<m)
{
if (dep[data[j-1]]<dep[data[j]])
{
cnt-=s[m-dep[data[j-1]]];
Sum-=t[m-dep[data[j-1]]];
}
ans1+=cnt;
ans2+=Sum*v[data[j]];
}else break;
for (k=h[data[j]];k;k=next[k]) if (ID[k]!=id && !visit[e[k]] && dep[e[k]]==dep[data[j]]+1)
data[++tot]=e[k];
}
for (j=1;j<=tot;j++)
{
s[dep[data[j]]]++; t[dep[data[j]]]+=v[data[j]];
Cnt++; SS+=v[data[j]];
}
}
ans1+=Cnt; ans2+=SS*v[R];
data[tot=1]=R;
for (j=1;j<=tot;j++)
{
if (dep[data[j]]==m) break;
for (k=h[data[j]];k;k=next[k]) if (ID[k]!=id && !visit[e[k]] && dep[e[k]]==dep[data[j]]+1)
data[++tot]=e[k];
}
for (j=2;j<=tot;j++)
{
s[dep[data[j]]]--; t[dep[data[j]]]-=v[data[j]];
}
for (i=h[R];i;i=next[i]) if (ID[i]!=id && !visit[e[i]])
{
data[Size=1]=e[i];
for (j=1;j<=Size;j++)
for (k=h[data[j]];k;k=next[k]) if (ID[k]!=id && !visit[e[k]] && dep[e[k]]==dep[data[j]]+1)
data[++Size]=e[k];
dfs(e[i]);
}
}
bool Find(int x,int s,int y)
{
visit[x]=1;
D[s]=x;
if (x==y)
{
V=s; return 1;
}
for (int i=h[x];i;i=next[i]) if (ID[i]!=id && !visit[e[i]] && Find(e[i],s+1,y)) return 1;
return 0;
}
void insert(int l,int r,int g,int v,int y,int &x)
{
if (bz)
{
x=++tot;
s1[x]=s1[y]; s2[x]=s2[y];
L[x]=L[y]; R[x]=R[y];
}
s1[x]++; s2[x]+=v;
if (l==r) return;
int mid=(l+r)/2;
if (g<=mid) insert(l,mid,g,v,L[y],L[x]);
else insert(mid+1,r,g,v,R[y],R[x]);
}
int get1(int l,int r,int g,int x)
{
if (l==g || !x) return s1[x];
int mid=(l+r)/2;
if (g>mid) return get1(mid+1,r,g,R[x]);
return get1(l,mid,g,L[x])+s1[R[x]];
}
LL get2(int l,int r,int g,int x)
{
if (l==g || !x) return s2[x];
int mid=(l+r)/2;
if (g>mid) return get2(mid+1,r,g,R[x]);
return get2(l,mid,g,L[x])+s2[R[x]];
}
int get(int x)
{
int l=1,r=N,mid;
for (mid=(l+r)/2;l<r;mid=(l+r)/2)
if (A[mid].X<=x) l=mid+1;else r=mid;
if (A[l].X>x) l--;
return l;
}
int main()
{
freopen("pronet.in","r",stdin); freopen("pronet.out","w",stdout);
scanf("%d%d",&n,&m);
for (int i=1;i<=n;i++)
{
scanf("%d",&E[i]);
add(i,E[i],i); add(E[i],i,i);
}
find(1);
for (id=1;bz[id];id++);
for (int i=1;i<=n;i++) scanf("%d",&v[i]);
memset(visit,0,sizeof(visit));
Size=n;
dfs(1);
memset(visit,0,sizeof(visit));
Find(id,1,E[id]);
memset(bz,0,sizeof(bz));
for (int i=1;i<=V;i++) bz[D[i]]=1;
N=0;
for (int i=1;i<=V;i++)
{
A[++N].X=i; A[N].Y=-i; A[N].e=v[D[i]]; dep[D[i]]=0;
for (int x=h[D[i]];x;x=next[x]) if (!bz[e[x]])
{
data[tot=1]=e[x]; fa[e[x]]=D[i];
for (int j=1;j<=tot;j++)
{
dep[data[j]]=dep[fa[data[j]]]+1;
if (dep[data[j]]>m) break;
A[++N].X=dep[data[j]]+i; A[N].Y=dep[data[j]]-i; A[N].e=v[data[j]];
for (int k=h[data[j]];k;k=next[k]) if (e[k]!=fa[data[j]])
{
data[++tot]=e[k]; fa[e[k]]=data[j];
}
}
}
}
sort(A+1,A+N+1,cmp);
A[0].X=Root[0]=s1[0]=s2[0]=tot=0;
memset(L,0,sizeof(L)); memset(R,0,sizeof(R));
for (int i=1;i<=N;i++) insert(0,m+V,A[i].Y+V,A[i].e,Root[i-1],Root[i]);
for (int i=1;i<=N;i++)
{
int y=m-A[i].X+1,x=get(m-V-A[i].Y);
if (x>0 && y<=m)
{
y=max(0,y+V);
ans1+=get1(0,m+V,y,Root[x]);
ans2+=get2(0,m+V,y,Root[x])*A[i].e;
}
}
printf("%lld %lld\n",ans1,ans2);
fclose(stdin); fclose(stdout);
return 0;
}