题目描述
L发明了一种与树有关的游戏(友情提醒:树是一个没有环的连通图):他从树中删除任意数量(可以为0)的边,计算删除后所有连通块大小的乘积,L将得到这么多的分数。你的任务就是对于一颗给定的树,求出L能得到的最大分数。
输入输出格式
输入格式:
第一行一个整数n,表示树的节点个数。
接下来n-1行,每行两个整数a[i],bi,表示a[i]与b[i]之间连边。
保证输入的图是一棵树。
输出格式:
输出一个整数,表示L能得到的最大分数。
输入输出样例
样例1: 5 1 2 2 3 3 4 4 5 样例2: 8 1 2 1 3 2 4 2 5 3 6 3 7 6 8 样例3: 3 1 2 1 3
样例1: 6 样例2: 18 样例3: 3
说明
【数据范围】
对于10%的数据,1<=n<=5;
对于30%的数据,1<=n<=100;
另有30%的数据,保证数据是一条链。
对于100%的数据,1<=n<=700;
对树形DP又有了更深一层的认识。。。。
方程还是很好想的
设f[u][j]为当前u这个点的子树内,分给儿子们j个点(也即自己留下siz[u] - j个作为一个联通块)的乘积
于是有f[u][j] = max{f[u][j],f[v][k] + f'[j - k]} (0 <= k <= j < siz[u]) (f'数组为之前的儿子所计算出的f[u])
特别的,对于f[u][siz[u]]有f[u][siz[u]] = f[u][j] * (siz[u] * j) (0 <= j < siz[u])
嗯。。这是蒟蒻的初步想法
飞速打完交了一遍以后。。。。发现他居然爆了long long。。。。。。。。。。
然后打高精。。。。。
接着T了无数遍。。。。。
然后看大爷的题解。。。。
发现大爷只是比我多了一个优化???????
于是引入了一个关于树形DP基本的复杂度的证明:
对于上述方程,我们修改一下,把j - k和k的位置对调一下,变成
f[u][j] = max{f[u][j],f[v][j - k] + f'[k]} (0 <= j < siz[u] , 0 <= k <= min(j , siz[u] - siz[v]))
显然他仍然是与原方程等价的,但是复杂度却完全不同
原方程的上限复杂度显然是O(n ^ 3)的(不计高精度),
而新方程的实质是对于一个当前u大小为j的子树中,去找不属于v的一棵大小为k的子树
那么这时单次DP的复杂度为O(siz[v] * (siz[u] - siz[v])),
这个优化可能不是很明显,但是我们这样考虑:由于两个子树是不相交的,那么也就可以看做是两个子树内所有的点两两求一遍lca,并且贡献只在lca处计算一次的复杂度
于是总的复杂度为O(n ^ 2)
代码:
#include<cstdio>
#include<cmath>
#include<queue>
#include<stack>
#include<vector>
#include<algorithm>
#include<cstring>
using namespace std;
typedef long long LL;
const int INF = 2147483647;
const int maxn = 710;
const int r = 1000000000;
struct data{
int tot;
LL m[20]; data(){memset(m,0,sizeof(m)); tot = 1;}
data operator * (data b) const
{
data ret;
for (int i = 1; i <= tot; i++)
{
for (int j = 1; j <= b.tot; j++)
{
ret.m[i + j] += (ret.m[i + j - 1] + m[i] * b.m[j]) / r;
ret.m[i + j - 1] = (ret.m[i + j - 1] + m[i] * b.m[j]) % r;
}
}
ret.tot = tot + b.tot - 1;
while (ret.m[ret.tot + 1])
{
ret.tot++;
ret.m[ret.tot + 1] += ret.m[ret.tot] / r;
ret.m[ret.tot] = ret.m[ret.tot] % r;
}
return ret;
}
data operator & (LL b) const
{
data ret;
for (int i = 1; i <= tot; i++)
{
ret.m[i + 1] += (ret.m[i + 1] + m[i] * b) / r;
ret.m[i] = (ret.m[i] + m[i] * b) % r;
}
ret.tot = tot;
while (ret.m[ret.tot + 1])
{
ret.tot++;
ret.m[ret.tot + 1] += ret.m[ret.tot] / r;
ret.m[ret.tot] = ret.m[ret.tot] % r;
}
return ret;
}
};
vector<int> e[maxn];
LL n,siz[maxn],s[maxn];
data f[maxn][maxn],ff[maxn],g[maxn];
inline LL getint()
{
LL ret = 0,f = 1;
char c = getchar();
while (c < '0' || c > '9')
{
if (c == '-') f = -1;
c = getchar();
}
while (c >= '0' && c <= '9')
ret = ret * 10 + c - '0',c = getchar();
return ret * f;
}
inline data hmax(data a,data b)
{
if (a.tot < b.tot) return b;
if (b.tot < a.tot) return a;
for (int i = a.tot; i >= 1; i--)
{
if (a.m[i] > b.m[i]) return a;
if (a.m[i] < b.m[i]) return b;
}
return a;
}
inline void dp(int u,int fa)
{
siz[u]++;
for (int i = 0; i < e[u].size(); i++)
{
int v = e[u][i];
if (v == fa) continue;
dp(v,u);
siz[u] += siz[v];
for (int j = 0; j <= siz[u]; j++) ff[j] = f[u][j];
for (int j = 0; j <= siz[u] - 1; j++)
for (int k = 0; k <= min(1ll * j,siz[u] - siz[v]); k++)
f[u][j] = hmax(f[u][j],f[v][j - k] * ff[k]);
}
for (LL j = 0; j <= siz[u] - 1; j++)
f[u][siz[u]] = hmax(f[u][siz[u]],f[u][j] & (siz[u] - j));
}
inline void init()
{
for (int i = 1; i <= n; i++)
f[i][0].m[1] = 1;
}
inline void print(data a)
{
int cnt,pos;
for (int i = 1; i <= a.tot; i++)
{
cnt = 0; pos = (i - 1) * 9;
while (a.m[i])
{
s[++pos] = a.m[i] % 10;
a.m[i] /= 10;
}
}
for (int i = pos; i >= 1; i--) printf("%d",s[i]);
}
int main()
{
n = getint();
for (int i = 1; i <= n - 1; i++)
{
int u = getint(),v = getint();
e[u].push_back(v); e[v].push_back(u);
}
init();
dp(1,0);
print(f[1][n]);
return 0;
}