题意:
就是给你一个树,然后问你有多少个不同的连通图,这个连通图所有度为1的点颜色都一样。
思考:
- 刚开始我以为是有多少子树是这样,那么感觉是树上启发式。实际上是问联通子图。这该怎么求啊,完全没有思路。dp?但是dp定义什么呢,我怎么定义才能让dp都存的是叶子节点呢,这好像很复杂的样子。实际上不是,dp你定义什么就是什么,定义叶子节点都是黑色或者白色的方案数,那么就是这样的。所以不要胡思乱想,直接定义找转移状态就好。
- 既然这样方案数的问题一般都是dp,这里定义dp为什么呢?题目说度为1的点的颜色都全部一样,那么我就定义dp[i][0]以i点为根的子树中度为1的所有点颜色为0的方案数(但是不包括i这个点,即i点如果度为1但也不管他是什么颜色)。dp[i][1]代表以i为根的子树中度为1的所有点颜色为1的方案数(不包括i这个点)。 现在说一下,为什么定义的时候说不包括i这个点,因为一些不合法的情况我要保留下来,保留下来给父亲用的时候这样就变成了合法情况。但是统计答案的时候自己再把这些不合法的删去就行了,但是不能在dp里删,我求的dp并不是真正的答案只是一个伪状态。
显然1-2这个子图在dp[1][0]中是不合法的,因为1的度也为1但是颜色不是0。但是这个答案我不存在dp[1][0]里面吗?还是存的,因为5-1-2这个子图是合法的,所以这个答案必须存。
代码:
#include<bits/stdc++.h>
#define fi first
#define se second
#define pb push_back
#define db double
#define int long long
#define PII pair<int,int >
#define mem(a,b) memset(a,b,sizeof(a))
#define IOS std::ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
using namespace std;
const int mod = 1e9+7,inf = 1e18;
const int N = 3e5+10,M = 2010;
int T,n,m,k;
int va[N];
char s[N];
int in[N];
int dp[N][2];
int ans;
vector<int > e[N];
void dfs(int now,int p)
{
if(s[now]=='0') dp[now][0] = 1; //如果只选now这个点
if(s[now]=='1') dp[now][1] = 1;
int sum0 = 1,sum1 = 1;
int res0 = 0,res1 = 0;
for(auto spot:e[now])
{
if(spot==p) continue;
dfs(spot,now);
sum0 = sum0*(dp[spot][0]+1)%mod; //spot的总方案数+1,1就是不选择这个子树
sum1 = sum1*(dp[spot][1]+1)%mod;
res0 = (res0+dp[spot][0])%mod; //方案数总和,因为我下面要减去每个只选根节点和某一个子树的方案数,那么加起来就是所有子树的方案数
res1 = (res1+dp[spot][1])%mod;
}
dp[now][0] = dp[now][0]+sum0-1,dp[now][1] = dp[now][1]+sum1-1; //-1是减去那种全不选的情况
ans = (ans+dp[now][0]+dp[now][1])%mod;
if(s[now]=='0') ans = (ans-res1)%mod; //由于这种不合法的情况是需要存在dp里面的,所以就在答案里面减去这种不合法的情况
if(s[now]=='1') ans = (ans-res0)%mod;
}
void init()
{
ans = 0;
for(int i=1;i<=n;i++)
{
e[i].clear();
in[i] = dp[i][0] = dp[i][1] = 0;
}
}
signed main()
{
IOS;
while(cin>>n)
{
cin>>s+1;
init();
for(int i=1;i<n;i++)
{
int a,b;
cin>>a>>b;
e[a].pb(b);
e[b].pb(a);
in[a]++,in[b]++;
}
dfs(1,0);
ans = (ans%mod+mod)%mod;
cout<<ans<<"\n";
}
return 0;
}
另一种写法:
#include<bits/stdc++.h>
#define fi first
#define se second
#define pb push_back
#define db double
#define int long long
#define PII pair<int,int >
#define mem(a,b) memset(a,b,sizeof(a))
#define IOS std::ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
using namespace std;
const int mod = 1e9+7,inf = 1e18;
const int N = 3e5+10,M = 2010;
int T,n,m,k;
int va[N];
char s[N];
int in[N];
int dp[N][2];
int ans;
vector<int > e[N];
void dfs(int now,int p)
{
if(in[now]==1)
{
if(s[now]=='0') dp[now][0] = 1;
if(s[now]=='1') dp[now][1] = 1;
return ;
}
int sum0 = 1,sum1 = 1;
for(auto spot:e[now])
{
if(spot==p) continue;
dfs(spot,now);
sum0 = sum0*(dp[spot][0]+1)%mod;
sum1 = sum1*(dp[spot][1]+1)%mod;
}
dp[now][0] = sum0,dp[now][1] = sum1;
if(s[now]=='0') dp[now][1] = (dp[now][1]-1)%mod;
if(s[now]=='1') dp[now][0] = (dp[now][0]-1)%mod;
}
void get(int now,int p)
{
int sum0 = 0,sum1 = 0;
for(auto spot:e[now])
{
if(spot==p) continue;
get(spot,now);
sum0 = (sum0+dp[spot][0])%mod;
sum1 = (sum1+dp[spot][1])%mod;
}
if(s[now]=='0') ans = (ans-sum1)%mod;
if(s[now]=='1') ans = (ans-sum0)%mod;
ans = (ans+dp[now][0]+dp[now][1])%mod;
}
void init()
{
ans = 0;
for(int i=1;i<=n;i++)
{
e[i].clear();
in[i] = dp[i][0] = dp[i][1] = 0;
}
}
signed main()
{
IOS;
while(cin>>n)
{
cin>>s+1;
init();
for(int i=1;i<n;i++)
{
int a,b;
cin>>a>>b;
e[a].pb(b);
e[b].pb(a);
in[a]++,in[b]++;
}
dfs(1,0);get(1,0);
ans = (ans%mod+mod)%mod;
cout<<ans<<"\n";
}
return 0;
}
总结:
多多思考,多多积累定义状态。