解题思路
PS:投诉出题人,竟然卡dfs。
首先一个很显然的结论是,每块的大小一定是n的约数,
我们考虑一下将原树看做一个有根树,一个节点可以作一个块的”根”,当且仅当该节点的 size 能被块的大小整除 预处理出每个节点的 size,枚举树的大小 k,判断 size 为 k 的倍数的节点数量是否为 n / k n/k n/k就好了
代码
暴力
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<iomanip>
#include<cstring>
#include<cmath>
#include<map>
#include<queue>
#define ll long long
#define ldb long double
using namespace std;
int n,k,ans,x,y,l[1000010],num[1000010],head[2000010],q1[1000010],q2[1000010],m[10000010],v[1000010];
struct c{
int x,next;
}a[2000010];
int read() {
int x=0,f=1;
char c=getchar();
while(c<'0'||c>'9'){if(c=='-') f=-1;c=getchar();}
while(c>='0'&&c<='9') x=x*10+c-'0',c=getchar();
return x*f;
}
void write(int x) {
if(x<0) putchar('-'),x=-x;
if(x>9) write(x/10);
putchar(x%10+'0');
}
void add(int x,int y)
{
a[++k].x=y;
a[k].next=head[x];
head[x]=k;
}
void dfs2(int x,int fa){
v[x]=1;
for(int i=head[x];i;i=a[i].next)
{
if(a[i].x!=fa&&!m[i])
{
dfs2(a[i].x,x);
v[x]+=v[a[i].x];
}
}
}
bool check(){
memset(v,0,sizeof(v));
int k=0;
for(int i=1;i<=n;i++)
{
if(!v[i])
{
dfs2(i,0);
if(i==1) k=v[i];
else if(v[i]!=k)return 0;
}
}
return 1;
}
void dfs(int dep)
{
if(dep>n-1)
{
if(check())
ans++;
return;
}
dfs(dep+1);
m[q1[dep]]=m[q2[dep]]=1;
dfs(dep+1);
m[q1[dep]]=m[q2[dep]]=0;
}
int main(){
n=read();
for(int i=1;i<=n-1;i++)
{
x=read();y=read();
add(x,y);
q1[i]=k;
add(y,x);
q2[i]=k;
}
dfs(1);
cout<<ans<<endl;
}
BFS正解
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<iomanip>
#include<cstring>
#include<cmath>
#include<map>
#include<queue>
#define ll long long
#define ldb long double
using namespace std;
int n,k,ans,x,y,l[1000010],num[1000010],head[2000010],q[1000010],fa[1000010];
struct c{
int x,next;
}a[2000010];
int read() {
int x=0,f=1;
char c=getchar();
while(c<'0'||c>'9'){if(c=='-') f=-1;c=getchar();}
while(c>='0'&&c<='9') x=x*10+c-'0',c=getchar();
return x*f;
}
void write(int x) {
if(x<0) putchar('-'),x=-x;
if(x>9) write(x/10);
putchar(x%10+'0');
}
void add(int x,int y)
{
a[++k].x=y;
a[k].next=head[x];
head[x]=k;
}
void bfs()
{
int h=0,t=1;
q[1]=1;
while(h<t)
{
h++;
int x=q[h];
l[x]=1;
for(int i=head[x];i;i=a[i].next)
{
int y=a[i].x;
if(y!=fa[x])
{
fa[y]=x;
q[++t]=y;
}
}
}
}
int main(){
n=read();
for(int i=1;i<=n-1;i++)
{
x=read();y=read();
add(x,y);
add(y,x);
}
bfs();
for(int i=n;i>0;i--)
{
int u=q[i];
for(int j=head[u];j;j=a[j].next)
{
if(a[j].x!=fa[u])
l[u]+=l[a[j].x];
}
}
for(int i=1;i<=n;i++)
num[l[i]]++;
for(int i=1;i<=n;i++)
{
if(n%i==0)
{
int s=0;
for(int k=1;k<=n/i;k++)
s+=num[k*i];
if(s==n/i)
ans++;
}
}
printf("%d",ans);
}