链接: Tree
题意:
求一棵树中包含每个点的联通点集的数量,答案对 1e9+7取模。
思路:
- 大致思路是,求出每个点的子树内包含该点的联通点集数量,然后换根就可以得到每个点的答案,但因为换根的时候涉及到除法,而分母为 0 的时候是不能求逆元的,所以还要把这个节点特殊考虑一下。
- 求子树内包含每个节点的点集合=数很好求 ,就是
dp1[u]=(dp1[u]%mod*(dp1[to]+1)%mod)%mod;
- 然后考虑换根,当前节点的答案为 (子树外的答案+1)* 子树内的答案,子树外的答案可以由当前节点的上一个节点的答案得到。
pr[u]=(ans[pre]%mod*poww((dp1[u]+1),mod-2))%mod;
ans[u]=((pr[u]+1)%mod*dp1[u]%mod)%mod;
- 考虑特殊情况不能用除法的,子树外的答案其实就是 父节点不包含该节点的的答案,所以可以用乘法把它算出来,只需要枚举子节点乘起来就好了,当然还有子树外的之前已经得到,可以直接用。
- 如果不用除法,全部用 特殊情况推,也是可以得到答案的,但是复杂度有点高 会 tle.
代码:
#include<iostream>
#include<cstdio>
#include<map>
#include<math.h>
#include<queue>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int maxn=1e6+7;
const int mod=1e9+7;
int head[maxn],num,n;
ll dp1[maxn],ans[maxn],pr[maxn],fa[maxn];
struct node{
int to,next;
}e[maxn<<1];
long long poww(long long a,long long b){
long long ans=1;
while(b>0){
if(b&1) ans=ans*a%mod;
a=a*a%mod;
b>>=1;
}
return ans;
}
void add(int u,int v){
e[num].next=head[u];
e[num].to=v;
head[u]=num++;
}
void dfs1(int u,int pre){
dp1[u]=1;
for(int i=head[u];i!=-1;i=e[i].next){
int to=e[i].to;
if(to==pre) continue;
fa[to]=u;
dfs1(to,u);
dp1[u]=(dp1[u]%mod*(dp1[to]+1)%mod)%mod;
}
}
void dfs2(int u,int pre){
if(u==1) ans[u]=dp1[u]%mod;
else {
if((dp1[u]+1)%mod==0){
pr[u]=(pr[pre]+1);
for(int i=head[pre];i!=-1;i=e[i].next){
int v=e[i].to;
if(v==fa[pre]||v==u) continue;
pr[u]=(pr[u]%mod*(dp1[v]+1)%mod)%mod;
}
}
else pr[u]=(ans[pre]%mod*poww((dp1[u]+1),mod-2))%mod;
ans[u]=((pr[u]+1)%mod*dp1[u]%mod)%mod;
}
for(int i=head[u];i!=-1;i=e[i].next){
int to=e[i].to;
if(to==pre) continue;
dfs2(to,u);
}
}
int main (){
cin>>n;
memset(head,-1,sizeof(head));
for(int i=0,u,v;i<n-1;i++){
scanf("%d%d",&u,&v);
add(u,v);
add(v,u);
}
dfs1(1,-1);
dfs2(1,-1);
for(int i=1;i<=n;i++) printf ("%lld\n",ans[i]%mod);
}
Tle 代码:
#include<iostream>
#include<cstdio>
#include<map>
#include<math.h>
#include<queue>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int maxn=1e6+7;
const int mod=1e9+7;
int head[maxn],num,n;
ll dp1[maxn],ans[maxn],pr[maxn],fa[maxn];
struct node{
int to,next;
}e[maxn<<1];
long long poww(long long a,long long b){
long long ans=1;
while(b>0){
if(b&1) ans=ans*a%mod;
a=a*a%mod;
b>>=1;
}
return ans;
}
void add(int u,int v){
e[num].next=head[u];
e[num].to=v;
head[u]=num++;
}
void dfs1(int u,int pre){
dp1[u]=1;
for(int i=head[u];i!=-1;i=e[i].next){
int to=e[i].to;
if(to==pre) continue;
fa[to]=u;
dfs1(to,u);
dp1[u]=(dp1[u]%mod*(dp1[to]+1)%mod)%mod;
}
}
void dfs2(int u,int pre){
if(u==1) ans[u]=dp1[u]%mod;
else {
pr[u]=(pr[pre]+1); //父节点子树外的答案
for(int i=head[pre];i!=-1;i=e[i].next){
int v=e[i].to;
if(v==fa[pre]||v==u) continue; // 跳过当前儿子节点,和父节点
pr[u]=(pr[u]%mod*(dp1[v]+1)%mod)%mod;
}
ans[u]=((pr[u]+1)%mod*dp1[u]%mod)%mod;
}
for(int i=head[u];i!=-1;i=e[i].next){
int to=e[i].to;
if(to==pre) continue;
dfs2(to,u);
}
}
int main (){
cin>>n;
memset(head,-1,sizeof(head));
for(int i=0,u,v;i<n-1;i++){
scanf("%d%d",&u,&v);
add(u,v);
add(v,u);
}
dfs1(1,-1);
dfs2(1,-1);
for(int i=1;i<=n;i++) printf ("%lld\n",ans[i]%mod);
}