非常值得一做的树形dp.
题意:给定一棵树(n<=300),你可以给每个节点等概率地染成A,B,C三种颜色之一,对于树上的一条边,若其两个端点的颜色不一样,则断开这条边.最后对于一个特定的颜色,X为点数为奇数的联通块个数,Y是点数为偶数的联通块个数,其得分为max(0,X-Y).问最后得分的期望乘上3^n mod 1e9+7的值.
解法:注意到颜色的对称性,我们只需要求出每个颜色的期望再乘上3就可以,而期望就是所有的情况除以3^n(情况种数),所以dp出所有可能的状态的方法数即可.
dp状态:dp[i][j][k],i代表对应的点,j有3个取值,0代表不取当前的点,1代表取当前的点并且当前点所在联通块的个数为奇数个,2代表为偶数个,k代表x-y的值.(注意dp表示的个数并不包括当前根节点的状态,因为当前根节点的状态还要需要其父亲)
状态转移:
dp[u][0][x+y]=dp[u][0][x]*dp[v][0][y]+dp[u][0][x]*dp[v][1][y-1]+dp[u][0][x]*dp[v][2][y+1]
dp[u][1][x+y]=dp[u][1][x]*dp[v][0][y]+dp[u][1][x]*dp[v][2][y]+dp[u][2][x]*dp[v][1][y]
dp[u][2][x+y]=dp[u][2][x]*dp[v][0][y]+dp[u][1][x]*dp[v][1][y]+dp[u][2][x]*dp[v][2][y]
注意dp枚举的过程以及其优化.
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
const int MOD = 1000000007;
const int base = 160;
vector<int> g[305];
int dp[305][3][505];
int t[3][505];
int low[305], high[305];
int n;
void add(int &x, int y)
{
x += y;
while (x < MOD) {
x += MOD;
}
while (x >= MOD) {
x -= MOD;
}
}
int make_mul(int a, int b, int c, int d, int e, int f)
{
int v = 0;
add(v, 1ll * a * b % MOD);
add(v, 1ll * c * d % MOD);
add(v, 1ll * e * f % MOD);
return v;
}
void dfs(int u, int fu)
{
dp[u][0][base] = 2;
dp[u][1][base] = 1;
low[u] = 0, high [u] = 0;
int size = (int)g[u].size();
for (int i = 0; i < size; ++ i) {
int v = g[u][i];
if (v == fu) continue;
dfs(v, u);
memset(t, 0, sizeof(t));
for (int x = low[u]; x <= high[u]; ++ x) {
for (int y = low[v] - 1; y <= high[v] + 1; ++ y) {
if (x + y > n || x + y < -n) continue;
add(t[0][x + y + base], make_mul(dp[u][0][x + base], dp[v][0][y + base], dp[u][0][x + base], dp[v][1][y - 1 + base], dp[u][0][x + base], dp[v][2][y + 1 + base]));
add(t[1][x + y + base], make_mul(dp[u][1][x + base], dp[v][0][y + base], dp[u][1][x + base], dp[v][2][y + base], dp[u][2][x + base], dp[v][1][y + base]));
add(t[2][x + y + base], make_mul(dp[u][2][x + base], dp[v][0][y + base], dp[u][1][x + base], dp[v][1][y + base], dp[u][2][x + base], dp[v][2][y + base]));
if (t[0][x + y + base] != 0) {
if (x + y < low[u]) low[u] = x + y;
if (x + y > high[u]) high[u] = x + y;
}
if (t[1][x + y + base] != 0) {
if (x + y < low[u]) low[u] = x + y;
if (x + y > high[u]) high[u] = x + y;
}
if (t[2][x + y + base] != 0) {
if (x + y < low[u]) low[u] = x + y;
if (x + y > high[u]) high[u] = x + y;
}
}
}
for (int j = low[u]; j <= high[u]; ++ j) {
dp[u][0][j + base] = t[0][j + base];
dp[u][1][j + base] = t[1][j + base];
dp[u][2][j + base] = t[2][j + base];
}
}
}
int main()
{
while (scanf("%d", &n) == 1) {
for (int i = 1; i <= n; ++ i) {
g[i].clear();
}
memset(dp, 0, sizeof(dp));
int a, b;
for (int i = 1; i < n; ++ i) {
scanf("%d%d", &a, &b);
g[a].push_back(b);
g[b].push_back(a);
}
dfs(1, 0);
int ans = 0;
for (int i = -1; i <= high[1]; ++ i) {
add(ans, 1ll * max(i, 0) * dp[1][0][i + base] % MOD);
add(ans, 1ll * max(i + 1, 0) * dp[1][1][i + base] % MOD);
add(ans, 1ll * max(i - 1, 0) * dp[1][2][i + base] % MOD);
}
ans = (int)(3ll * ans % MOD);
printf("%d\n", ans);
}
return 0;
}