题目大意
给定一颗树,对树进行树分块使得每块点数相同,求方案数
TLE算法
容易观察出,假如块大小定了,那么至多只有一种方案。
怎么分块呢?设size[x]表示x子树中还未被分块的节点数量。
像普通size一样求。
退出x时,如果size[x]=c即块大小,那么可以形成一块,size[x]清0。
最后若size[1]为0,代表分块成功。
复杂度n根号n,TLE
#include<cstdio>
#include<algorithm>
#include<cmath>
#define fo(i,a,b) for(i=a;i<=b;i++)
using namespace std;
const int maxn=1000000+10;
int size[maxn],h[maxn],go[maxn*2],next[maxn*2];
int i,j,k,l,t,n,m,tot,ans,c;
int read(){
int x=0,f=1;
char ch=getchar();
while (ch<'0'||ch>'9'){
if (ch=='-') f=-1;
ch=getchar();
}
while (ch>='0'&&ch<='9'){
x=x*10+ch-'0';
ch=getchar();
}
return x*f;
}
void add(int x,int y){
go[++tot]=y;
next[tot]=h[x];
h[x]=tot;
}
void dfs(int x,int y){
size[x]=1;
int t=h[x];
while (t){
if (go[t]!=y){
dfs(go[t],x);
size[x]+=size[go[t]];
}
t=next[t];
}
if (size[x]==c) size[x]=0;
}
void work(int v){
c=v;
dfs(1,0);
if (!size[1]) ans++;
}
int main(){
n=read();
fo(i,1,n-1){
j=read();k=read();
add(j,k);add(k,j);
}
t=floor(sqrt(n));
fo(i,1,t)
if (n%i==0){
work(i);
if (n/i!=i) work(n/i);
}
printf("%d\n",ans);
}
AC算法
我们发现,我们会在那些原本的size(这里的size指子树大小)就是c的倍数的地方划分出新的一块。
大概意思就是,我们设s[x]表示size[x]%c。
那么我删除一个子树y满足s[y]=0而且子树y内所有除y节点s均不为0,接下来整颗树其余未删除部分的s仍然不变吧?
而回忆TLE算法,每删一次就意味着分出了一块。所以如果要分块成功,我们需要分出恰好n/c个块,也就是意味s值为0的节点要有n/c个,那么就是说——size是c的倍数的节点要有n/c个!
我们开个桶记录每种size的节点个数,枚举块大小c,然后去暴力计算其倍数size的个数,这样是n log n的。
#include<cstdio>
#include<algorithm>
#include<cmath>
#define fo(i,a,b) for(i=a;i<=b;i++)
using namespace std;
const int maxn=1000000+10;
int size[maxn],h[maxn],go[maxn*2],next[maxn*2],cnt[maxn];
int i,j,k,l,t,n,m,tot,ans,c;
int read(){
int x=0,f=1;
char ch=getchar();
while (ch<'0'||ch>'9'){
if (ch=='-') f=-1;
ch=getchar();
}
while (ch>='0'&&ch<='9'){
x=x*10+ch-'0';
ch=getchar();
}
return x*f;
}
void add(int x,int y){
go[++tot]=y;
next[tot]=h[x];
h[x]=tot;
}
void dfs(int x,int y){
int t=h[x];
size[x]=1;
while (t){
if (go[t]!=y){
dfs(go[t],x);
size[x]+=size[go[t]];
}
t=next[t];
}
}
int main(){
n=read();
fo(i,1,n-1){
j=read();k=read();
add(j,k);add(k,j);
}
dfs(1,0);
fo(i,1,n) cnt[size[i]]++;
fo(i,1,n){
if (n%i!=0) continue;
t=0;
fo(j,1,n/i) t+=cnt[i*j];
if (t==n/i) ans++;
}
printf("%d\n",ans);
}