题目
题目链接:https://ac.nowcoder.com/acm/contest/2927/E
你有一颗大小为
n
\mathit n
n 的树,点从
1
\mathit 1
1 到
n
\mathit n
n 标号。
设
dis
(
x
,
y
)
\operatorname{dis}(x,y)
dis(x,y)表示
x
\mathit x
x 到
y
\mathit y
y 的距离。
求
∑
i
=
1
n
∑
j
=
1
n
d
i
s
2
(
i
,
j
)
\sum_{i=1}^n\sum_{j=1}^n dis^2(i,j)
∑i=1n∑j=1ndis2(i,j)对998244353取模的结果。
思路:
发现全部人写
O
(
n
)
O(n)
O(n)的就我一个
O
(
n
log
n
)
O(n\log n)
O(nlogn)卡过去了还行。
考虑每一个点的贡献,我们求出每一个点为根时分别到其他点的距离平方之和。
那么换根时,只要将即将转移到的子节点内的子树距离全部减一,其他全部加一即可。
线段树维护区间和以及区间平方和。
转移时相当于
a
2
+
b
2
+
c
2
a^2+b^2+c^2
a2+b2+c2转移成了
(
a
+
1
)
2
+
(
b
+
1
)
2
+
(
c
+
1
)
2
(a+1)^2+(b+1)^2+(c+1)^2
(a+1)2+(b+1)2+(c+1)2,完全平方公式拆开即可。
时间复杂度
O
(
n
log
n
)
O(n\log n)
O(nlogn),自带大常数,不知道怎么过的。
代码:
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
const int N=1000010,MOD=998244353;
int n,tot,head[N],dfn[N],rk[N],size[N],dep[N];
ll ans,sum;
struct edge
{
int next,to,dis;
}e[N*2];
struct Treenode
{
int l,r;
ll lazy,sum,ans;
};
struct Tree
{
Treenode tree[N*4];
void pushup(int x)
{
tree[x].ans=tree[x*2].ans+tree[x*2+1].ans;
tree[x].sum=tree[x*2].sum+tree[x*2+1].sum;
}
void addans(int x,ll val)
{
tree[x].ans+=tree[x].sum*val*2LL+val*val*(ll)(tree[x].r-tree[x].l+1);
}
void pushdown(int x)
{
if (tree[x].lazy)
{
tree[x*2].lazy+=tree[x].lazy;
tree[x*2+1].lazy+=tree[x].lazy;
addans(x*2,tree[x].lazy); addans(x*2+1,tree[x].lazy);
tree[x*2].sum+=tree[x].lazy*(ll)(tree[x*2].r-tree[x*2].l+1);
tree[x*2+1].sum+=tree[x].lazy*(ll)(tree[x*2+1].r-tree[x*2+1].l+1);
tree[x].lazy=0;
}
}
void build(int x,int l,int r)
{
tree[x].l=l; tree[x].r=r;
if (l==r)
{
tree[x].sum=dep[rk[l]];
tree[x].ans=1LL*dep[rk[l]]*dep[rk[l]];
return;
}
int mid=(l+r)>>1;
build(x*2,l,mid); build(x*2+1,mid+1,r);
pushup(x);
}
void update(int x,int l,int r,ll val)
{
if (l>r) return;
if (tree[x].l==l && tree[x].r==r)
{
tree[x].lazy+=val;
addans(x,val);
tree[x].sum+=val*(ll)(tree[x].r-tree[x].l+1);
return;
}
pushdown(x);
int mid=(tree[x].l+tree[x].r)>>1;
if (r<=mid) update(x*2,l,r,val);
else if (l>mid) update(x*2+1,l,r,val);
else update(x*2,l,mid,val),update(x*2+1,mid+1,r,val);
pushup(x);
}
}Tree;
void add(int from,int to)
{
e[++tot].to=to;
e[tot].next=head[from];
head[from]=tot;
}
void dfs1(int x,int fa)
{
dep[x]=dep[fa]+1; dfn[x]=++tot; rk[tot]=x;
for (int i=head[x];~i;i=e[i].next)
if (e[i].to!=fa)
{
dfs1(e[i].to,x);
size[x]+=size[e[i].to];
}
size[x]++;
}
void dfs2(int x,int fa)
{
ans=(ans+Tree.tree[1].ans)%MOD;
sum=Tree.tree[1].ans;
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=fa)
{
Tree.update(1,dfn[v],dfn[v]+size[v]-1,-2);
Tree.update(1,1,n,1);
dfs2(v,x);
Tree.update(1,1,n,-1);
Tree.update(1,dfn[v],dfn[v]+size[v]-1,2);
}
}
}
int main()
{
memset(head,-1,sizeof(head));
scanf("%d",&n);
for (int i=1,x,y;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y); add(y,x);
}
dep[0]=-1; tot=0;
dfs1(1,0);
Tree.build(1,1,n);
dfs2(1,0);
printf("%lld",ans);
return 0;
}