题解
看的出来,相邻两层间的答案相互独立,只需要统计每一层最终的结果之和即可
统计的方法:树形dp
假设 v是u的儿子节点:
d
p
u
=
∑
max
(
1
,
d
p
v
−
(
d
e
p
t
h
v
−
d
e
p
t
h
u
)
)
dp_u=\sum\max(1,dp_v-(depth_v-depth_u))
dpu=∑max(1,dpv−(depthv−depthu))
如果v是叶子节点:
d
p
v
=
a
v
dp_v=a_v
dpv=av
总共 2 e 5 2e^5 2e5 个点,要是每一层都要统计到根节点的最终答案,每个点都要经过好几遍很明显会TLE
此时用虚树缩图
虚树入门 : 单调栈的应用 — 笛卡尔树与虚树
虚树板子
O ( n + ∑ k log k ) O(n+\sum k\log k) O(n+∑klogk)
vector<int> vt[N];//虚树
vector<int> p[N];//存放同一深度的节点
void link(int u, int v) { //u是v的父节点
vt[u].push_back(v);
vt[v].push_back(u);
}
void build(int depth) {
//单调栈 栈中保存的节点表示根节点到当前要处理的点的路径上面的点 保证栈里只有一颗树(不存在两颗同级的子树)
stack<int> st;
//p[depth]里记录深度相同的所有节点
//根据dfn排序
sort(p[depth].begin(), p[depth].end(),[](const int a, const int b) { return dfn[a] < dfn[b]; });
//先把根节点加入到栈里
st.push(root);
vt[root].clear();//清除之前建的虚树
int tmp;//临时变量 用于存储栈顶元素
for (int v:p[depth]) {//准备加入的点 v
vt[v].clear();
if (st.empty()) {
st.push(v);
continue;
}
int u = lca(v, st.top());//栈顶元素与v的最近公共祖先 u
// 此时u要么是栈顶元素 要么是栈顶元素的祖先
// 反正不可能是v 因为dfn[v]>dfn[栈顶元素] 树上dfs一遍就知道了
tmp = st.top();//取出栈顶元素 后面的st.top()就是栈顶元素下面的那个元素了
st.pop();
// 如果v是栈顶元素的子树里的节点 (不一定是儿子节点哦)
// 那么dep[st.top()] < dep[u=tmp] (dfn[tmp]>dfn[st.top()]
// 如果v不是栈顶元素子树里的节点
// 那么其有可能是栈顶元素的父辈里的某一个的子树里的节点 即dep[st.top()] >= dep[最近公共祖先]
// 那么v后面那个要加入的点p[depth][i+1] 也一定不是栈顶元素子树里的节点
// 所以栈顶元素就不需要了 此时需要退栈连边(约定所有的连边都在退栈时发生)
// 将栈顶元素与其下面那个元素建边 将栈顶元素退栈
while (!st.empty() && dep[st.top()] >= dep[u]) {
link(st.top(), tmp);//建边
tmp = st.top();
st.pop();
}
// 如果u=tmp(栈顶元素) 说明u本身就已经在栈里了 此时还不需要建边 再把tmp(u)放回去即可
// 如果u!=tmp
// 那么肯定存在那么一个关系 st.top()是u的父辈里的某一个节点 u是栈顶元素tmp的父辈里的某一个节点
if (u != tmp) {
vt[u].clear();//清除之前建的虚树 注意代码的位置 千万别把本次建的虚树已建好的边给删了
link(u, tmp);
st.push(u);
} else st.push(tmp);
st.push(v);// u->v是肯定的 暂时不需要建边 先把v节点塞进去
}
//最后 将栈里的所有节点连边即可
tmp = st.top();
st.pop();
while (!st.empty()) {
link(st.top(), tmp);
tmp = st.top();
st.pop();
}
}
当然 清图的部分也可以在建完图后dfs一遍删边
void clearVt(int u) { //清空虚树
for (int v:vt[u]) {
clearVt(v);
}
vt[u].clear();
}
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 2e5 + 10;
vector<int> e[N];
ll a[N];
int n, root;
/*-----------------lca板子-----------------*/
int Log[N];//log2(x)
int dfn[N], cnt = 0;//各节点的编号
int fa[N][20];
int dep[N];//树的深度
void dfs(int u, int f) {
dfn[u] = ++cnt;
dep[u] = dep[f] + 1;
fa[u][0] = f;
for (int i = 1; (1 << i) <= n; ++i) {
fa[u][i] = fa[fa[u][i - 1]][i - 1];
}
for (int i = 0; i < e[u].size(); ++i) {
int v = e[u][i];
if (v != f)
dfs(v, u);
}
}
int lca(int u, int v) {
if (dep[u] < dep[v]) swap(u, v);
while (dep[u] > dep[v])
u = fa[u][Log[dep[u] - dep[v]]];
if (u == v) return v;
for (int i = Log[dep[u]]; i >= 0; --i) {
if (fa[u][i] != fa[v][i]) {
u = fa[u][i];
v = fa[v][i];
}
}
return fa[u][0];
}
/*------------建虚树-----------------*/
vector<int> vt[N];//虚树
vector<int> p[N];//存放同一深度的节点
void link(int u, int v) { //u是v的父节点
vt[u].push_back(v);
//vt[v].push_back(u); //这道题目里只需要知道儿子节点就行
}
void build(int depth) {
//单调栈 栈中保存的节点表示根节点到当前要处理的点的路径上面的点 保证栈里只有一颗树(不存在两颗同级的子树)
stack<int> st;
//p[depth]里记录深度相同的所有节点
//根据dfn排序
sort(p[depth].begin(), p[depth].end(),
[](const int a, const int b) { return dfn[a] < dfn[b]; });
//先把根节点加入到栈里
st.push(root);
vt[root].clear();//清除之前建的虚树
int tmp;//临时变量 用于存储栈顶元素
for (int v:p[depth]) {//准备加入的点 v
vt[v].clear();
if (st.empty()) {
st.push(v);
continue;
}
int u = lca(v, st.top());//栈顶元素与v的最近公共祖先 u
// 此时u要么是栈顶元素 要么是栈顶元素的祖先
// 反正不可能是v 因为dfn[v]>dfn[栈顶元素] 树上dfs一遍就知道了
tmp = st.top();//取出栈顶元素 后面的st.top()就是栈顶元素下面的那个元素了
st.pop();
// 如果v是栈顶元素的子树里的节点 (不一定是儿子节点哦)
// 那么dep[st.top()] < dep[u=tmp] (dfn[tmp]>dfn[st.top()]
// 如果v不是栈顶元素子树里的节点
// 那么其有可能是栈顶元素的父辈里的某一个的子树里的节点 即dep[st.top()] >= dep[最近公共祖先]
// 那么v后面那个要加入的点p[depth][i+1] 也一定不是栈顶元素子树里的节点
// 所以栈顶元素就不需要了 此时需要退栈连边(约定所有的连边都在退栈时发生)
// 将栈顶元素与其下面那个元素建边 将栈顶元素退栈
while (!st.empty() && dep[st.top()] >= dep[u]) {
link(st.top(), tmp);//建边
tmp = st.top();
st.pop();
}
// 如果u=tmp(栈顶元素) 说明u本身就已经在栈里了 此时还不需要建边 再把tmp(u)放回去即可
// 如果u!=tmp
// 那么肯定存在那么一个关系 st.top()是u的父辈里的某一个节点 u是栈顶元素tmp的父辈里的某一个节点
if (u != tmp) {
vt[u].clear();//清除之前建的虚树
link(u, tmp);
st.push(u);
} else st.push(tmp);
st.push(v);// u->v是肯定的 暂时不需要建边 把v节点塞进去
}
//最后 将栈里的所有节点连边即可
tmp = st.top();
st.pop();
while (!st.empty()) {
link(st.top(), tmp);
tmp = st.top();
st.pop();
}
}
ll dp[N];
//每一次dp 都是在统计同一深度节点之间产生的代价
void getdp(int u) {
dp[u] = 0;
if (vt[u].size() == 0) {
dp[u] = a[u];
return;
}
for (int i = 0; i < vt[u].size(); ++i) {
int v = vt[u][i];
getdp(v);
if (dp[v] != 0) //注意子节点一个也没有时 不应该被统计进去
dp[u] += max(1ll, dp[v] - (dep[v] - dep[u]));
}
}
int main() {
ios::sync_with_stdio(0);
cin >> n >> root;
for (int i = 1; i <= n; ++i) {
cin >> a[i];
}
for (int i = 1, u, v; i < n; ++i) {
cin >> u >> v;
e[u].push_back(v);
e[v].push_back(u);
}
//init the constant of lca
Log[0] = -1, Log[1] = 0;
for (int i = 2; i <= N; ++i) {
Log[i] = Log[i / 2] + 1;
}
dfs(root, 0);
for (int i = 1; i <= n; ++i) {
p[dep[i]].push_back(i);//把同一深度的节点放一起
}
ll ans = 0;
//特殊处理根节点
if (a[root] > 1) ans += (a[root] - 1);
else ans += a[root];
for (int i = 2; i <= n; ++i) {//枚举每一层
if (p[i].size() == 0) continue;
//深度相同的节点建虚树
build(i);
getdp(root);
if (dp[root] > 1) ans += dp[root] - 1;
else ans += dp[root];
}
cout << ans << endl;
return 0;
}