题目传送门:https://ac.nowcoder.com/acm/contest/3002/F
基本思路
要求只经过一个黑点的路径数量,可以分两种情况:
1.起点和终点均是白点,其余路径中只有一个黑点。
2.起点和终点为一黑一白,其余路径中全部是白点。
把树分成若干个连通块构成,连通块有两种,一是全为白点,二是只有一黑点,其他全白点(特殊地,也可以只有一黑点,无白点)。这里的“连通块”,意义与图论中“含限制条件的最大连通分量”相同。
设dp[i][0]表示以 i 为根的子树(包括 i )全为白的连通块有多少个点。
dp[i][1]表示以 i 为根的子树(包括 i )一黑其他全白的连通块有多少个点。
或者这么认为:dp[i][0]表示i的子孙节点到i路径上无黑点的个数,dp[i][1]表示i的子孙节点到i路径上有一个黑点的个数。
我们要选出起点终点分布于两个不同种类连通块的所有情况,然后每次将答案加上这两个连通块中点的个数的乘积。
这题的关键在于怎么写递归的dfs函数,以及如何更新答案ans和dp值。
用dfs进行树的后序遍历,也就是对于每个点,先从左到右访问完它的所有子树后,再访问它自己。
先更新完所有子孙节点的dp值,再更新自己的dp值。
用u表示当前访问的节点编号,在向下递归的过程中(从根到叶子),进行dp的初始化:
if(s[u]=='W')dp[u][0]=1,dp[u][1]=0;//dp[u][0]=1,现在表示u自身为白,之后回溯再更新dp
else dp[u][0]=0,dp[u][1]=1;//dp[u][1]=1,现在表示u自身为黑,之后回溯再更新dp
在进行回溯的过程中,也就是从叶子到根,向上返回的过程中,进行ans的更新:
(u表示当前节点,v表示与u直接相连的孩子节点)
ans+=dp[u][0]*dp[v][1]+dp[u][1]*dp[v][0];
更新ans之后,进行dp的更新:(可以理解为将以v点为根的子树合并到u点上,之后u点的dp值是它和它的子树一起构成的)
if(s[u]=='W')dp[u][0]+=dp[v][0],dp[u][1]+=dp[v][1];
else dp[u][0]=0,dp[u][1]+=dp[v][0];//全白加上一黑,要把全白的连通块从0开始计数,所以要把dp[u][0]变成0
完整的dfs函数代码:(用vector存图)
void dfs(int u,int fa)//u表示当前访问的节点,fa是u的父亲节点
{
if(s[u]=='W')dp[u][0]=1,dp[u][1]=0;
else dp[u][0]=0,dp[u][1]=1;
for(int i=0;i<g[u].size();i++)
{
int v=g[u][i];//v是与u直接相连的孩子节点
if(v==fa)continue;//防止向上递归
dfs(v,u);
ans+=dp[u][0]*dp[v][1]+dp[u][1]*dp[v][0];
if(s[u]=='W')dp[u][0]+=dp[v][0],dp[u][1]+=dp[v][1];
else dp[u][0]=0,dp[u][1]+=dp[v][0];
}
}
AC代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e5+10;
char s[N];
int n,x,y;
ll dp[N][2],ans;//dp[i][0]表示全白 dp[i][1]表示一黑其他全白
vector<int>g[N];
void dfs(int u,int fa)
{
if(s[u]=='W')dp[u][0]=1,dp[u][1]=0;
else dp[u][0]=0,dp[u][1]=1;
//for(int i=0;i<g[u].size();i++)
for(int v:g[u])
{
//int v=g[u][i];
if(v==fa)continue;
dfs(v,u);
ans+=dp[u][0]*dp[v][1]+dp[u][1]*dp[v][0];
if(s[u]=='W')dp[u][0]+=dp[v][0],dp[u][1]+=dp[v][1];
else dp[u][0]=0,dp[u][1]+=dp[v][0];
}
}
int main()
{
ios::sync_with_stdio(false);
cin>>n>>s+1;
for(int i=1;i<=n-1;i++)
{
cin>>x>>y;
g[x].push_back(y);
g[y].push_back(x);
}
dfs(1,0);
printf("%lld\n",ans);
return 0;
}
新收获
关于遍历vector中的所有元素,也可以这么写:
for(int v:g[u])
{
//code...
}
等效于:
for(int i=0;i<g[u].size();i++)
{
int v=g[u][i];
//code...
}
(参考:C++11之for循环的新用法)