题目链接:点击这里
题目大意:
给一颗
n
n
n 个节点的树,边权为
1
1
1 ,求
∑
i
=
1
n
∑
j
=
1
n
d
i
s
2
(
i
,
j
)
\sum_{i=1}^n\sum_{j=1}^ndis^2(i,j)
∑i=1n∑j=1ndis2(i,j) 对
998244353
998244353
998244353 取模的结果
题目分析:
d
i
s
(
i
,
j
)
=
d
e
p
i
+
d
e
p
j
−
2
d
e
p
l
c
a
(
i
,
j
)
dis(i,j)=dep_i+dep_j-2dep_{lca(i,j)}
dis(i,j)=depi+depj−2deplca(i,j)
d
i
s
2
(
i
,
j
)
=
(
d
e
p
i
+
d
e
p
j
)
2
−
4
d
e
p
l
c
a
(
i
,
j
)
(
d
e
p
i
+
d
e
p
j
)
+
4
d
e
p
l
c
a
(
i
,
j
)
2
dis^2(i,j)=(dep_i+dep_j)^2-4dep_{lca(i,j)}(dep_i+dep_j)+4dep^2_{lca(i,j)}
dis2(i,j)=(depi+depj)2−4deplca(i,j)(depi+depj)+4deplca(i,j)2
=
d
e
p
i
2
+
d
e
p
j
2
+
2
d
e
p
i
d
e
p
j
−
4
d
e
p
l
c
a
(
i
,
j
)
(
d
e
p
i
+
d
e
p
j
)
+
4
d
e
p
l
c
a
(
i
,
j
)
2
=dep_i^2+dep_j^2+2dep_idep_j-4dep_{lca(i,j)}(dep_i+dep_j)+4dep^2_{lca(i,j)}
=depi2+depj2+2depidepj−4deplca(i,j)(depi+depj)+4deplca(i,j)2
可以发现前
3
3
3 项是可以通过预处理各个节点深度后
O
(
n
)
O(n)
O(n) 求解,此时问题转化成了如何求后两项
我们可以转换思路,通过枚举
l
c
a
(
i
,
j
)
=
x
lca(i,j)=x
lca(i,j)=x ,求出
l
c
a
(
i
,
j
)
=
x
lca(i,j)=x
lca(i,j)=x 的对数和
d
e
p
i
+
d
e
p
j
dep_i+dep_j
depi+depj 就可以求解出上式的答案了
我们定义
s
i
z
i
siz_i
sizi 表示根为
i
i
i 的子树的大小,
s
u
m
i
sum_i
sumi 表示根为
i
i
i 的子树的深度之和
那么
l
c
a
(
i
,
j
)
=
x
lca(i,j)=x
lca(i,j)=x 对数即为
s
i
z
x
∗
s
i
z
i
−
∑
y
∈
s
o
n
x
s
i
z
y
∗
s
i
z
y
siz_x*siz_i-\sum_{y\in son_x}siz_y*siz_y
sizx∗sizi−∑y∈sonxsizy∗sizy ,此式子可以理解为我们先假设
x
x
x 与其子树的所有节点都相关联,再从中减去其所有子树中之和子树相关联的点的个数
这些点对的
d
e
p
i
+
d
e
p
j
dep_i+dep_j
depi+depj 的和为
2
∑
y
∈
s
o
n
x
s
u
m
y
∗
(
s
i
z
x
−
s
i
z
y
)
+
2
d
e
p
x
∗
s
i
z
x
2\sum_{y\in son_x} sum_y*(siz_x-siz_y)+2dep_x*siz_x
2∑y∈sonxsumy∗(sizx−sizy)+2depx∗sizx
前半部分乘
2
2
2 是因为
(
i
,
j
)
,
(
j
,
i
)
(i,j),(j,i)
(i,j),(j,i) 这样的点对是都要计数的,后半部分是点
x
x
x 自身产生的贡献
具体细节见代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<vector>
#include<set>
#include<map>
#define ll long long
#define inf 0x3f3f3f3f
#define Inf 0x3f3f3f3f3f3f3f3f
#define int ll
using namespace std;
int read()
{
int res = 0,flag = 1;
char ch = getchar();
while(ch<'0' || ch>'9')
{
if(ch == '-') flag = -1;
ch = getchar();
}
while(ch>='0' && ch<='9')
{
res = (res<<3)+(res<<1)+(ch^48);//res*10+ch-'0';
ch = getchar();
}
return res*flag;
}
const int maxn = 1e6+5;
const int maxm = 3e4+5;
const int mod = 998244353;
const double pi = acos(-1);
const double eps = 1e-8;
struct Edge{
int nxt,to;
}edge[maxn<<1];
int n,cnt,ans,head[maxn],dep[maxn],siz[maxn],sum[maxn];
void addedge(int from,int to)
{
edge[++cnt].nxt = head[from];
edge[cnt].to = to;
head[from] = cnt;
}
void dfs1(int u,int fa)
{
siz[u] = 1;dep[u] = dep[fa]+1;sum[u] = dep[u];
for(int i = head[u];i;i = edge[i].nxt)
{
int to = edge[i].to;
if(to == fa) continue;
dfs1(to,u);
siz[u] += siz[to];
sum[u] += sum[to];
}
}
void dfs2(int u,int fa)
{
int num = siz[u]*siz[u],res = 2*dep[u]*siz[u]%mod;
for(int i = head[u];i;i = edge[i].nxt)
{
int to = edge[i].to;
if(to == fa) continue;
dfs2(to,u);
num -= siz[to]*siz[to];
res += 2*sum[to]*(siz[u]-siz[to])%mod;
}
ans = ((ans+4*num*dep[u]%mod*dep[u]%mod-4*dep[u]*res%mod)%mod+mod)%mod;
}
signed main()
{
n = read();
for(int i = 1;i < n;i++)
{
int u = read(),v = read();
addedge(u,v); addedge(v,u);
}
dfs1(1,0);
int tmp = 0;
for(int i = 1;i <= n;i++) tmp = (tmp+dep[i])%mod,ans = (ans+dep[i]*dep[i]%mod*2*n%mod)%mod;
ans = (ans+2*tmp*tmp%mod)%mod;
dfs2(1,0);
printf("%lld\n",ans);
return 0;
}