题目大意
一个
n
个点的环套树,每个点有点权
1≤n≤105,k≤n,ei≤104
题目分析
Algorithm Alpha
我比赛时想到的就是这种方法。
我们删除环上一条边,将其变成一棵树,做点分治。如果不考虑删除的边的影响,这就是一道点剖经典题(楼天城:《树中点对个数》),弄出到分治中心的距离,排序后统计即可。但是我们在这里还需要考虑这条边的影响。我的方法是对楼天城那题的扩展,所以很多东西我都省用原本的算法,省略介绍。
可以列出:答案=删边后满足要求点对个数-删边前后都满足要求点对个数+删边前满足要求点对个数。
第一个在点分治中直接计算即可。
处理后两个时,判重是个很棘手的问题,我是采用以下方法。令删除的边为
(p1,p2)
,我们将所有点分成两个集合,
A
集满足
那么第二个东西的计算,我们可以在点剖中顺便进行,对于排好序的点,我们使用树状数组维护当前有多少
B
集合的点满足
第三个部分就很简单了,直接将所有
时间复杂度
Algorithm Beta
其实这题直接将整个环去掉会更加简便。
对每棵子树点剖,然后再加上环,使用树状数组在环上统计即可。
具体实现请读者自行思考。
时间复杂度
O(nlog2n)
。
代码实现
吐槽:模仿 STL 的 sort ,第一次用指针打的桶排。
#include <iostream>
#include <cstring>
#include <cstdio>
#include <cctype>
#include <cmath>
using namespace std;
typedef long long LL;
int read()
{
int x=0,f=1;
char ch=getchar();
while (!isdigit(ch))
{
if (ch=='-') f=-1;
ch=getchar();
}
while (isdigit(ch))
{
x=x*10+ch-'0';
ch=getchar();
}
return x*f;
}
const int N=100500;
const int EL=N<<1;
const int LGEL=18;
const int M=N<<1;
int size[N],fa[N],last[N],t[2][N],da[N],db[N],a[N],h[N],pos[N],d[N];
int n,k,tot,e1,e2,p1,p2,el,lgel;
int tov[M],next[M],rev[M];
int rmq[EL][LGEL];
int euler[EL];
bool vis[N];
LL cnt,ans;
int getrmq(int l,int r)
{
int lgr=trunc(log(r-l+1)/log(2));
if (h[rmq[l][lgr]]<h[rmq[r-(1<<lgr)+1][lgr]])
return rmq[l][lgr];
else
return rmq[r-(1<<lgr)+1][lgr];
}
int lca(int x,int y)
{
x=pos[x],y=pos[y];
if (x>y) x^=y^=x^=y;
return getrmq(x,y);
}
int dist(int x,int y)
{
int z=lca(x,y);
return h[x]+h[y]-h[z]*2;
}
bool find(int x,int e)
{
vis[x]=true;
int i=last[x],y;
while (i)
{
y=tov[i];
if (!vis[y])
{
bool g=find(y,i);
if (g) return g;
}
else
if (rev[e]!=i)
return e1=i,e2=rev[i],p1=tov[e2],p2=tov[e1];
i=next[i];
}
return false;
}
void dfs(int x)
{
euler[pos[x]=++el]=x;
int i=last[x],y;
while (i)
{
y=tov[i];
if (i!=e1&&i!=e2&&y!=fa[x])
h[y]=h[fa[y]=x]+1,dfs(y),euler[++el]=x;
i=next[i];
}
}
void pre()
{
lgel=trunc(log(el)/log(2))+1;
for (int i=1;i<=el;i++)
rmq[i][0]=euler[i];
for (int j=1;j<=lgel;j++)
for (int i=1;i+(1<<j)-1<=el;i++)
if (h[rmq[i][j-1]]<h[rmq[i+(1<<j-1)][j-1]])
rmq[i][j]=rmq[i][j-1];
else
rmq[i][j]=rmq[i+(1<<j-1)][j-1];
for (int i=1;i<=n;i++)
da[i]=dist(i,p1),db[i]=dist(i,p2);
}
int que[N],head,tail;
int core(int rt)
{
int x,y,i;
head=0,tail=1,que[1]=rt,fa[rt]=0;
while (head!=tail)
{
size[x=que[++head]]=1,i=last[x];
while (i)
{
y=tov[i];
if (!vis[y]&&y!=fa[x]&&i!=e1&&i!=e2)
fa[que[++tail]=y]=x;
i=next[i];
}
}
for (head=tail;head;head--)
size[fa[que[head]]]+=size[que[head]];
int mi=n,ret,tmp;
for (head=1;head<=tail;head++)
{
x=que[head],i=last[x],tmp=0;
while (i)
{
y=tov[i];
if (!vis[y]&&y!=fa[x]&&i!=e1&&i!=e2)
tmp=max(tmp,size[y]);
i=next[i];
}
tmp=max(tmp,size[rt]-size[x]);
if (tmp<mi)
ret=x,mi=tmp;
}
return ret;
}
int q1[N],q2[N];
bool cmp(int x,int y)
{
return d[x]<d[y];
}
int lowbit(int x){return x&-x;}
void add(int x,int y,LL edit)
{
y++;
while (y<=n)
{
t[x][y]+=edit;
y+=lowbit(y);
}
}
LL query(int x,int y)
{
LL ret=0;
y++;
while (y>0)
{
ret+=t[x][y];
y-=lowbit(y);
}
return ret;
}
int sss[N],num[N];
void sort(int *st,int *en,int mxl)
{
for (int i=0;i<=mxl;i++) sss[i]=0;
for (int *cur=st;cur!=en;cur++)
sss[d[*cur]]++;
for (int i=1;i<=mxl;i++) sss[i]+=sss[i-1];
for (int *cur=st,i=1;cur!=en;cur++,i++)
num[sss[d[*cur]]--]=*cur;
for (int *cur=st,i=1;cur!=en;i++,cur++)
*cur=num[i];
}
void calc(int x)
{
LL tcnt=0,tans=0;
int c=core(x),i0=last[c],y0,i,y,mx=0,mx0=0;
q1[0]=1,q1[1]=c,d[c]=0;
while (i0)
{
y0=tov[i0],q2[0]=0;
if (!vis[y0]&&i0!=e1&&i0!=e2)
{
head=0,tail=1,que[1]=y0,fa[y0]=c;
while (head!=tail)
{
x=que[++head],d[x]=d[fa[x]]+1;
if (d[x]>mx0) mx0=d[x];
if (d[x]>mx) mx=d[x];
q1[++q1[0]]=x,q2[++q2[0]]=x;
i=last[x];
while (i)
{
y=tov[i];
if (!vis[y]&&y!=fa[x]&&i!=e1&&i!=e2)
fa[que[++tail]=y]=x;
i=next[i];
}
}
sort(q2+1,q2+1+q2[0],mx0);
int sum=0;
for (int i=1;i<=q2[0];i++)
{
sum+=a[q2[i]];
if (da[q2[i]]>db[q2[i]])
add(0,db[q2[i]],1),add(1,db[q2[i]],a[q2[i]]);
}
int cur=q2[0];
for (int i=1;i<=q2[0];i++)
{
while (cur&&d[q2[cur]]+d[q2[i]]>k)
{
if (da[q2[cur]]>db[q2[cur]])
add(0,db[q2[cur]],-1),add(1,db[q2[cur]],-a[q2[cur]]);
sum-=a[q2[cur--]];
}
tcnt-=cur;
tans-=1ll*a[q2[i]]*sum;
if (cur&&da[q2[i]]<db[q2[i]])
{
cnt+=query(0,k-1-da[q2[i]]);
ans+=a[q2[i]]*1ll*query(1,k-1-da[q2[i]]);
}
}
while (cur)
{
if (da[q2[cur]]>db[q2[cur]])
add(0,db[q2[cur]],-1),add(1,db[q2[cur]],-a[q2[cur]]);
cur--;
}
}
i0=next[i0];
}
sort(q1+1,q1+1+q1[0],mx);
int sum=0;
for (int i=1;i<=q1[0];i++)
{
sum+=a[q1[i]];
if (da[q1[i]]>db[q1[i]])
add(0,db[q1[i]],1),add(1,db[q1[i]],a[q1[i]]);
}
int cur=q1[0];
for (int i=1;i<=q1[0];i++)
{
while (cur&&d[q1[cur]]+d[q1[i]]>k)
{
if (da[q1[cur]]>db[q1[cur]])
add(0,db[q1[cur]],-1),add(1,db[q1[cur]],-a[q1[cur]]);
sum-=a[q1[cur--]];
}
tcnt+=cur;
tans+=1ll*a[q1[i]]*sum;
if (cur&&da[q1[i]]<db[q1[i]])
{
cnt-=query(0,k-1-da[q1[i]]);
ans-=a[q1[i]]*1ll*query(1,k-1-da[q1[i]]);
}
}
while (cur)
{
if (da[q1[cur]]>db[q1[cur]])
add(0,db[q1[cur]],-1),add(1,db[q1[cur]],-a[q1[cur]]);
cur--;
}
tcnt--,tans-=a[c]*1ll*a[c];
cnt+=tcnt/2,ans+=tans/2;
vis[c]=true,i=last[c];
while (i)
{
y=tov[i];
if (!vis[y]&&i!=e1&&i!=e2)
calc(y);
i=next[i];
}
}
void solve()
{
for (int i=1;i<=n;i++)
if (da[i]>db[i])
add(0,db[i],1),add(1,db[i],a[i]);
for (int i=1;i<=n;i++)
if (da[i]<db[i])
cnt+=query(0,k-1-da[i]),ans+=a[i]*1ll*query(1,k-1-da[i]);
}
void insert(int x,int y)
{
tov[++tot]=y;
next[tot]=last[x];
last[x]=tot;
}
int main()
{
freopen("pronet.in","r",stdin);
freopen("pronet.out","w",stdout);
n=read(),k=read();
for (int i=1,x;i<=n;i++)
{
x=read();
insert(i,x),rev[tot]=tot+1;
insert(x,i),rev[tot]=tot-1;
}
for (int i=1;i<=n;i++) a[i]=read();
find(1,0),h[1]=1,dfs(1),pre();
memset(vis,0,sizeof vis),calc(1),solve();
printf("%lld %lld\n",cnt,ans);
fclose(stdin);
fclose(stdout);
return 0;
}