题目链接
首先,我们来看一下平方的方程
然后,我们分析一下,我们现在可以很方便的求出一个点到其余各点的距离和的平方,那么假如要转移到另一个点的时候呢?怎么整呢?(换根思想!)
譬如说是一副这样的图,我们现在从now结点跳转至next结点,那么对于next结点他的信息应该如何计算呢?就是它到其余各点的距离的平方和。
首先,对now和next进行联合分析,now转移到next的话,相当于对上方结点和对它右边的子树节点的距离都加了1,但是呢对以next为根的子树的结点的距离都是减少了1。
如果原先到上方结点的边有三条a、b、c,那么的话现在距离变成了(a + 1)、(b + 1)、(c + 1),求一下平方和,就是、、,多出来的部分就是,其中这个3等于上方的结点个数。
那么,同样的,到下方结点的距离会减少,减少的部分就是与上式同理了~所以,由于这里(a + b + c)的存在,所以,我们要处理一个到上方结点距离和的那么一个函数,并且由于还要计算到下方结点的,我们同样还要预处理一个到下方子树结点的距离和的一个数组。
#include <iostream>
#include <cstdio>
#include <cmath>
#include <string>
#include <cstring>
#include <algorithm>
#include <limits>
#include <vector>
#include <stack>
#include <queue>
#include <set>
#include <map>
#include <unordered_map>
#include <unordered_set>
#define lowbit(x) ( x&(-x) )
#define pi 3.141592653589793
#define e 2.718281828459045
#define INF 0x3f3f3f3f
#define HalF (l + r)>>1
#define lsn rt<<1
#define rsn rt<<1|1
#define Lson lsn, l, mid
#define Rson rsn, mid+1, r
#define QL Lson, ql, qr
#define QR Rson, ql, qr
#define myself rt, l, r
#define MP(a, b) make_pair(a, b)
#define Min3(a, b, c) min(a, min(b, c))
using namespace std;
typedef unsigned long long ull;
typedef unsigned int uit;
typedef long long ll;
const ll mod = 998244353;
const int maxN = 1e6 + 7;
int N, head[maxN], cnt;
struct Eddge
{
int nex, to;
Eddge(int a=-1, int b=0):nex(a), to(b) {}
}edge[maxN << 1];
inline void addEddge(int u, int v)
{
edge[cnt] = Eddge(head[u], v);
head[u] = cnt++;
}
inline void _add(int u, int v) { addEddge(u, v); addEddge(v, u); }
int siz[maxN], father[maxN];
ll down_dis[maxN], dis[maxN], segema_dis[maxN];
ll now = 0;
void pre_dfs(int u, int fa)
{
father[u] = fa;
siz[u] = 1; now = (now + dis[u] * dis[u]) % mod;
down_dis[u] = 0;
for(int i=head[u], v; ~i; i=edge[i].nex)
{
v = edge[i].to;
if(v == fa) continue;
dis[v] = dis[u] + 1;
pre_dfs(v, u);
siz[u] += siz[v];
down_dis[u] += down_dis[v] + siz[v];
down_dis[u] %= mod;
}
}
ll all = 0;
void ex_dfs(int u, ll sta, ll up_dis)
{
all = (all + sta) % mod;
ll nex_sta, nex_up_dis;
for(int i=head[u], v; ~i; i=edge[i].nex)
{
v = edge[i].to;
if(v == father[u]) continue;
nex_up_dis = up_dis + down_dis[u] - (down_dis[v] + siz[v]) + (N - siz[v]);
nex_up_dis %= mod;
nex_sta = sta - (down_dis[v] + siz[v]) * 2 + siz[v] + nex_up_dis * 2 - (N - siz[v]);
nex_sta %= mod;
ex_dfs(v, nex_sta, nex_up_dis);
}
}
inline void init()
{
cnt = 0;
for(int i=1; i<=N; i++) head[i] = -1;
}
int main()
{
scanf("%d", &N);
init();
for(int i=1, u, v; i<N; i++)
{
scanf("%d%d", &u, &v);
_add(u, v);
}
pre_dfs(1, 0);
// printf("%lld\n", now);
ex_dfs(1, now, 0);
printf("%lld\n", all);
return 0;
}
/*
5
1 2
2 3
2 4
1 5
*/