题目描述
给出一个有 $2^n$ 个叶子节点的完全二叉树。每个叶子节点可以选择黑白两种颜色。
对于每个非叶子节点左子树中的叶子节点 $i$ 和右子树中的叶子节点 $j$ :
如果 $i$ 和 $j$ 的颜色都为当前节点子树中颜色较多(相等视为白色)的那个,则不需要付出代价;
都为较小的那个则需要付 $2f[i][j]$ 的代价;
否则需要付 $f[i][j]$ 。
求最小代价。
输入
输出
你的程序只需要向输出文件输出一个整数,表示NS中学支付给网络公司的最小总费用。(单位:元)
样例输入
2
1 0 1 0
2 2 10 9
10 1 2
2 1
3
样例输出
8
题解
暴力+树形背包dp
先Orz一发CQzhangyu
首先观察付出代价的方式,可以改看作为:对于 $i$ ,选了颜色较少的那个则需要付出 $f[i][j]$ 的代价。这样我们就把两个点之间的代价转化为了单个点的代价。
然后由于这么多状态难以统计,因此需要提前计算代价。
设 $f[i][j]$ 表示点 $i$ 为根的子树内有 $j$ 个叶子节点选了黑色的最小总代价。那么这是一个树形背包问题,递归左右子树后跑背包合并即可。
然而这里有一个非常大的问题:代价的类型是与颜色较少还是较多有关的。
CQzhangyu给出的解法是:暴力枚举这两种情况,分别留下有用部分。由于是完全二叉树,因此有递归式:$T(1)=\log n,T(n)=4T(\frac n2)+O(n^2)$ ,不考虑 $T(1)$ 时根据主定理解得 $T(n)=O(n^2\log n)$ ,单独考虑 $T(1)$ ,1被考虑了 $n^2$ 次,因此也是 $O(n^2\log n)$ 。
因此直接暴力的复杂度是对的。
时间复杂度 $O(2^{2n}·n)$ 。
#include <cstdio>
#include <cstring>
#include <algorithm>
#define N 1030
#define ls x << 1
#define rs x << 1 | 1
using namespace std;
typedef long long ll;
int n , m , a[N] , t[N];
ll c[N] , w[N][N] , v[N][N] , f[N << 1][N];
void init(int x , int d)
{
if(!d) return;
int i , j;
init(ls , d - 1) , init(rs , d - 1);
for(i = x << d ; i < ((x << 1) + 1) << (d - 1) ; i ++ )
for(j = ((x << 1) + 1) << (d - 1) ; j < (x + 1) << d ; j ++ )
v[i - m][x] += w[i - m][j - m] , v[j - m][x] += w[i - m][j - m];
}
void dfs(int x , int d)
{
int i , j;
if(!d)
{
f[x][a[x - m]] = 0;
f[x][a[x - m] ^ 1] = c[x - m];
for(i = x >> 1 ; i ; i >>= 1) f[x][t[i]] += v[x - m][i];
return;
}
memset(f[x] , 0x3f , sizeof(ll) * ((1 << d) + 1));
t[x] = 1 , dfs(ls , d - 1) , dfs(rs , d - 1);
for(i = 0 ; i <= 1 << (d - 1) ; i ++ )
for(j = 0 ; j <= (1 << (d - 1)) - i ; j ++ )
f[x][i + j] = min(f[x][i + j] , f[ls][i] + f[rs][j]);
t[x] = 0 , dfs(ls , d - 1) , dfs(rs , d - 1);
for(i = 1 ; i <= 1 << (d - 1) ; i ++ )
for(j = (1 << (d - 1)) - i + 1 ; j <= 1 << (d - 1) ; j ++ )
f[x][i + j] = min(f[x][i + j] , f[ls][i] + f[rs][j]);
}
int main()
{
int i , j;
ll ans = 1ll << 62;
scanf("%d" , &n) , m = 1 << n;
for(i = 0 ; i < m ; i ++ ) scanf("%d" , &a[i]);
for(i = 0 ; i < m ; i ++ ) scanf("%lld" , &c[i]);
for(i = 0 ; i < m ; i ++ )
for(j = i + 1 ; j < m ; j ++ )
scanf("%lld" , &w[i][j]);
init(1 , n);
dfs(1 , n);
for(i = 0 ; i <= m ; i ++ ) ans = min(ans , f[1][i]);
printf("%lld\n" , ans);
return 0;
}