题意:
给出一棵树,每次操作可以删除一个点,操作过程中会生成数组 a a a , a i a_i ai 代表删除节点 i i i 时与 i i i 相邻的节点还有多少未被删除。当删完所有节点时,生成了一个长度为 n n n的数组 a a a ,求对于 1 ≤ k ≤ n 1\leq k \leq n 1≤k≤n ,有多少不同的数组 a a a 满足: g c d ( a 1 , a 2 , a 3 , . . . ) = k gcd(a_1,a_2,a_3,...)=k gcd(a1,a2,a3,...)=k 。
题解:
初看此题,确实无从下手。对于这种问题,可以考虑往每条边的贡献思考。
对于相邻节点 ( u , v ) (u,v) (u,v),他们之间连了一条边,假设 u u u节点先被删除,那么这条边的贡献就给了 u u u ,那么 a [ u ] + = 1 a[u]+=1 a[u]+=1 ,反之如果先删除了 v v v ,那么 a [ v ] + = 1 a[v]+=1 a[v]+=1 。所以不会出现重复的序列 a a a。
考虑构建一个数组 n u m [ i ] num[i] num[i] 表示 数组 a a a的每个元素都能除以 i i i 的个数。
那么显然, n u m [ 1 ] = 2 n − 1 num[1]=2^{n-1} num[1]=2n−1 ,因为所有数都可以整除1,且对于每条边都有两种选择。
还可以发现,只有当 ( n − 1 ) % k = = 0 (n-1)\%k==0 (n−1)%k==0 时,才会有答案。因为 a 1 + a 2 + a 3 + . . . + a n = n − 1 a_1+a_2+a_3+...+a_n=n-1 a1+a2+a3+...+an=n−1 ,所有元素都能整除 k k k ,那么 n − 1 n-1 n−1肯定也要能整除 k k k 。所有我们在枚举 k k k 的时候,只需枚举 n − 1 n-1 n−1的因子即可。
对于一个固定的 k k k ,如何求出 n u m [ k ] num[k] num[k] , d f s dfs dfs这棵树,考虑当前节点 u u u与父亲节点 f a fa fa之间的这条边。
1. 1. 1. 如果 a [ u ] % k = = 0 a[u]\%k==0 a[u]%k==0 ,那么我们只能将这条边贡献给 f a fa fa。
2. 2. 2. 如果 a [ u ] % k ! = 0 a[u]\%k!=0 a[u]%k!=0 ,那么这条边必定只能给 u u u , 但如果 ( a [ u ] + 1 ) % k ! = 0 (a[u]+1)\%k!=0 (a[u]+1)%k!=0 ,那么一定就构造不出来
求出 n u m num num数组后,还要将所有倍数减去,才是题目要求的答案,即 n u m [ i ] − = ∑ j = 2 i , j + = i n n u m [ j ] num[i]-=\sum\limits_{j=2i,j+=i}^{n} num[j] num[i]−=j=2i,j+=i∑nnum[j]。
代码:
#pragma GCC diagnostic error "-std=c++11"
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<queue>
#include<map>
#include<stack>
#include<set>
#include<ctime>
#define iss ios::sync_with_stdio(false)
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
typedef pair<int,int> pii;
const int mod=998244353;
const int MAXN=2e5+5;
const int inf=0x3f3f3f3f;
std::vector<int> g[MAXN];
ll dp[MAXN];
ll num[MAXN];
int flag;
ll qpow(ll a,ll b){
ll res=1;
while(b){
if(b&1) res=res*a%mod;
a=a*a%mod;
b>>=1;
}
return res;
}
void dfs(int u,int f,int k)
{
if(flag==1) return;
for(auto j:g[u]){
if(j!=f) dfs(j,u,k);
}
if(dp[u]%k==0){
dp[f]++;
}
else{
if(f) dp[u]++;
if(dp[u]%k!=0) flag=1;
}
if(flag==1) return;
}
int main()
{
int t;
scanf("%d",&t);
while(t--){
int n;
scanf("%d",&n);
for(int i=1;i<=n-1;i++){
int u,v;
scanf("%d%d",&u,&v);
g[u].push_back(v);
g[v].push_back(u);
}
num[1]=qpow(2,n-1);
for(int i=2;i<=n;i++){
if((n-1)%i==0){
flag=0;
dfs(1,0,i);
num[i]=(flag^1);
for(int j=1;j<=n;j++) dp[j]=0;
}
}
for(int i=n;i>=1;i--){
for(int j=2*i;j<=n;j+=i){
num[i]-=num[j];
}
}
for(int i=1;i<=n;i++){
num[i]=(num[i]+mod)%mod;
printf("%lld ",num[i]);
}
printf("\n");
for(int i=1;i<=n;i++){
g[i].clear();
num[i]=0;
}
}
}