洛谷传送门
BZOJ传送门
题目描述
小凸和小方相约玩密室逃脱,这个密室是一棵有 n n n 个节点的完全二叉树,每个节点有一个灯泡。点亮所有灯泡即可逃出密室。每个灯泡有个权值 a i a_i ai,每条边也有个权值 b i b_i bi。点亮第一个灯泡不需要花费,之后每点亮一个新的灯泡 v v v 的花费,等于上一个被点亮的灯泡 u u u 到这个点 v v v 的距离 D u , v D_{u,v} Du,v,乘以这个点的权值 a v a_v av。在点灯的过程中,要保证任意时刻所有被点亮的灯泡必须连通,在点亮一个灯泡后必须先点亮其子树所有灯泡才能点亮其他灯泡。请告诉他们,逃出密室的最少花费是多少。
输入输出格式
输入格式:
第 1 1 1 行包含 1 1 1 个数 n n n,代表节点的个数
第 2 2 2 行包含 n n n 个数,代表每个节点的权值 a i a_i ai。 ( i = 1 , 2 , n ) (i = 1, 2, n) (i=1,2,n)
第 3 3 3 行包含 n − 1 n - 1 n−1 个数,代表每条边的权值 b i b_i bi,第 i i i 号边是由第 ( i + 1 ) / 2 (i+1)/2 (i+1)/2 号点连向第 i + 1 i + 1 i+1 号点的边。 ( i = 1 , 2 , n − 1 ) (i = 1, 2, n−1) (i=1,2,n−1)
输出格式:
输出包含 1 1 1 个数,代表最少的花费。
输入输出样例
输入样例#1:
3
5 1 2
2 1
输出样例#1:
5
说明
对于 10 % 10\% 10% 的数据, 1 ≤ n ≤ 10 1 \leq n \leq 10 1≤n≤10
对于 20 % 20\% 20% 的数据, 1 ≤ n ≤ 20 1 \leq n \leq 20 1≤n≤20
对于 30 % 30\% 30% 的数据, 1 ≤ n ≤ 2000 1 \leq n \leq 2000 1≤n≤2000
对于 100 % 100\% 100% 的数据, 1 ≤ n ≤ 2 ∗ 1 0 5 , 1 ≤ a i , b i ≤ 1 0 5 1 \leq n \leq 2 * 10^5, 1 \leq a_i, b_i \leq 10^5 1≤n≤2∗105,1≤ai,bi≤105
解题分析
如果我们当前点亮了 i i i号点的灯, 只会有三种选择:
-
如果 i i i号点是一个叶节点, 返回其某个祖先并到其另一个子树中继续开灯。
-
继续开子树中的灯
-
返回 1 1 1号点(最后一步)
所以可以从下往上 D P DP DP:设 f [ i ] [ k ] f[i][k] f[i][k]表示从 i i i到 i i i的 k k k级祖先并把这一个子树全部点亮的最小花费, g [ i ] [ k ] g[i][k] g[i][k]表示从 i i i到 i i i的 k k k级祖先的另一个儿子的最小花费, 然后就可以一波预处理出每个子树的代价。 最后枚举起点就可以一路往上跳统计答案了。
因为深度为 l o g ( N ) log(N) log(N), 总复杂度为 O ( N l o g ( N ) ) O(Nlog(N)) O(Nlog(N))。
代码如下:
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cstdlib>
#include <cctype>
#include <algorithm>
#define R register
#define IN inline
#define W while
#define gc getchar()
#define MX 200500
#define ll long long
template <class T>
IN void in(T &x)
{
x = 0; R char c = gc;
for (; !isdigit(c); c = gc);
for (; isdigit(c); c = gc)
x = (x << 1) + (x << 3) + c - 48;
}
template <class T> IN T max(T a, T b) {return a > b ? a : b;}
template <class T> IN T min(T a, T b) {return a < b ? a : b;}
int n;
ll ans = 1e18;
int fat[MX][20], ls[MX], rs[MX];
ll f[MX][20], g[MX][20], dis[MX][20], w[MX];
//f: i号点已经点亮, 到达k级祖先的最优解
//g: i号点已经点亮, 到达k级祖先的另一个儿子的最优解
IN int oth(R int now, R int k) {return (now >> k - 1) ^ 1;}
int main(void)
{
in(n); R int i, j;
for (i = 1; i <= n; ++i) in(w[i]), fat[i][0] = i;
for (i = 2; i <= n; ++i) fat[i][1] = i >> 1, in(dis[i][1]);
for (i = 1; i <= n; ++i)
{
if ((i << 1) <= n) ls[i] = i << 1;
if ((i << 1 | 1) <= n) rs[i] = i << 1 | 1;
}
for (i = 2; i <= 18; ++i)
for (j = 2; j <= n; ++j)
{
fat[j][i] = fat[fat[j][i - 1]][1];
if (!fat[j][i]) continue;
dis[j][i] = dis[j][i - 1] + dis[fat[j][i - 1]][1];
}
for (i = n; i; --i)
{
if (!ls[i]) for (j = 1; fat[i][j - 1]; ++j)
g[i][j] = (dis[i][j] + dis[oth(i, j)][1]) * w[oth(i, j)];
else if (!rs[i]) for (j = 1; fat[i][j - 1]; ++j)
g[i][j] = dis[ls[i]][1] * w[ls[i]] + g[ls[i]][j + 1];
else for (j = 1; fat[i][j - 1]; ++j)
g[i][j] = min(dis[ls[i]][1] * w[ls[i]] + g[ls[i]][1] + g[rs[i]][j + 1], dis[rs[i]][1] * w[rs[i]] + g[rs[i]][1] + g[ls[i]][j + 1]);
}
for (i = n; i; --i)
{
if (!ls[i]) for (j = 1; fat[i][j - 1]; ++j)
f[i][j] = dis[i][j] * w[fat[i][j]];
else if (!rs[i]) for (j = 1; fat[i][j - 1]; ++j)
f[i][j] = dis[ls[i]][1] * w[ls[i]] + f[ls[i]][j + 1];
else for (j = 1; fat[i][j - 1]; ++j)
f[i][j] = min(dis[ls[i]][1] * w[ls[i]] + g[ls[i]][1] + f[rs[i]][j + 1], dis[rs[i]][1] * w[rs[i]] + g[rs[i]][1] + f[ls[i]][j + 1]);
}
for (i = 1; i <= n; ++i)
{
ll sum = f[i][1];
for (j = 1; fat[i][j]; ++j)
{
int tar = oth(i, j);
if (tar > n) sum += dis[fat[i][j]][1] * w[fat[i][j + 1]];
else sum += dis[tar][1] * w[tar] + f[tar][2];
}
ans = min(ans, sum);
}
printf("%lld\n", ans);
}