题意
给你一棵n
个点的树,边权都是1。设树的直径为D
。你可以选至少两个点染色,求染色的点两两之间的距离都是D
的染色方案数,模998244353。
思路
答案往往藏在题干里。我们必须围绕树的直径来思考。但是我们已掌握的关于树的直径的模型,只有一个:把树的直径看成一条链,链上每个点带一个子树。就像一条藤条上长着一些花朵。
这里采用的模型是,以树的直径的中心为分界点,把树分为两份,并设x
和y
为树的直径的两端(任取即可)。
如果D
为奇数
因为树的直径的中心处于边的中间,所以我们也可以说成是以连接pt1
和pt2
的边为分界,把树分为两棵子树。
对于左侧子树(右侧同理)的每一个点v
,dis[v,y] = dis[v,pt1] + dis[pt1,pt2] + dis[pt2,y] = dis[v,pt1] + (D+1)/2 <= D
,所以dis[v,pt1] <= (D-1)/2
,所以如果设左侧子树的另一点为v0
,则dis[v,v0] <= D-1
。于是有:被染色的点不可能处于同一棵子树。进而很快得知,被染色的点,只能是左侧子树和右侧子树各选1个,且它们分别距离pt1
和pt2
是最远的:(D-1)/2
,乘法原理即知答案。
这种情况对应我代码的solve1()
。
如果D
为偶数
此时树的直径的中心就是图中的pt1
。有了奇数情况的经验,我们很快能推广出求法。
同理,以pt1
为中心,每个pt1
的邻居作为根,划分出子树。同理可证被染色的点不可能处于同一棵子树,并且被染色的点一定是距离它所在的子树的根最远的。
于是建模出这种情况的计数问题如下:有k
个盒子,盒子有x[1],...,x[k]
个球,所有球两两不相同。每个盒子你可以选0到1个球,选出的球的个数>=2。
我们用减法来解决它。如果没有个数限制,则方案数是(x[1]+1)*...*(x[k]+1)
,如果选出0个球,方案数为1,如果选出1个球,方案数为x[1]+...+x[k]
。所以答案为(x[1]+1)*...*(x[k]+1)-1-(x[1]+...+x[k])
。
这种情况对应我代码的solve2()
。
代码实现的小技巧
奇数情况和偶数情况都要先找到分界点,这玩意我看很多人写得很复杂,但其实不需要写那么复杂。它们的共性就是距离可以表达为D>>1
,所以我们dfs的过程中维护深度,并且以该子树是否存在“树的直径的另一端点”为返回值。于是在u
的深度为D>>1
且已知v
的子树存在树的直径的另一端点的时候,就找到了分界点是u
和v
。
参考:官方题解
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
#define rep(i,a,b) for(int i = (a);i <= (b);++i)
#define re_(i,a,b) for(int i = (a);i < (b);++i)
#define dwn(i,a,b) for(int i = (a);i >= (b);--i)
const int N = 2e5 + 5;
const int mod = 998244353;
int n,dis[N],maxd,x,y;int pt1,pt2;
vector<int> G[N];
template<typename Type>inline void read(Type &xx){
Type f = 1;char ch;xx = 0;
for(ch = getchar();ch < '0' || ch > '9';ch = getchar()) if(ch == '-') f = -1;
for(;ch >= '0' && ch <= '9';ch = getchar()) xx = xx * 10 + ch - '0';
xx *= f;
}
int bfs(int st){
rep(i,1,n) dis[i] = -1;
queue<int> q;
q.push(st);dis[st] = 0;
while(!q.empty()){
int u = q.front();q.pop();
for(int v: G[u]){
if(~dis[v]) continue;
dis[v] = dis[u]+1;
q.push(v);
}
}
maxd = *max_element(dis+1,dis+n+1);
rep(i,1,n) if(dis[i] == maxd) return i;
return 0;
}
bool dfs1(int u,int ufa,int want,int dep = 0){
if(u == want) return true;
bool ret = false;
for(int v: G[u]){
if(v == ufa) continue;
bool has = dfs1(v,u,want,dep+1);
ret |= has;
if(has && dep == (maxd >> 1)){
pt1 = u;pt2 = v;return true;
}
}
return ret;
}
int dfs2(int u,int ufa,int target,int dep = 0){
if(dep == target) return 1;
int ret = 0;
for(int v: G[u]){
if(v == ufa) continue;
ret += dfs2(v,u,target,dep+1);
}
return ret;
}
LL solve1(){
dfs1(x,0,y);
return 1LL * dfs2(pt1,pt2,maxd >> 1) * dfs2(pt2,pt1,maxd >> 1) % mod;
}
LL solve2(){
dfs1(x,0,y);
LL ans = 1,s = 1;
for(int v: G[pt1]){
int val = dfs2(v,pt1,(maxd >> 1) - 1);
ans = ans * (1 + val) % mod;
s = (s + val) % mod;
}
return (ans - s + mod) % mod;
}
int main(int argc, char** argv) {
read(n);
re_(i,1,n){
int x,y;read(x);read(y);
G[x].push_back(y);
G[y].push_back(x);
}
x = bfs(1);
y = bfs(x);
printf("%lld\n",(maxd&1) ? solve1() : solve2());
return 0;
}