神奇的树剖 + 利用一次询问多条路径 话不多说先放题面
Description
题目背景:
尊者神高达作为一个萌新,在升级路上死亡无数次后被一只大黄叽带回了师门。他加入师门后发现有无穷无尽的师兄弟姐妹,这几天新副本开了,尊者神高达的师门作为一个 pve师门,于是他们决定组织一起去开荒。
题目描述:
师门可以看做以 1 为根的一棵树,师门中的每一个人都有一定的装备分数。一共会有 q 个事件。每个事件可能是一次开荒,也可能是因为开荒出了好装备而导致一个人的装分出现了变化。对于一次开荒,会有 k 个人组织,由于师门的号召力很强,所以所有在组织者中任意两个人简单路径上的人都会参加。
Input
第一行 n ,q;
接下来 1 行 n 个数,代表每个人的分值;
接下来 n-1 行 u,v 代表一条边
接下来 q 行
Q 代表询问,接下来 k 个数代表组织的人数,读入为 0时停止读入。
C 代表修改,输入 x,w 代表将 x 的分值变为 w
Output
共 Q 的数量行,为开荒的人的总分值
Sample Input
4 4
10 5 2 2
1 2
2 3
2 4
Q 3 4 0
C 3 200
Q 3 4 0
Q 1 4 0
Sample Output
9
207
17
样例解释:
第一次询问,参加的人有 2,3,4 5+2+2=9
第一次修改,权值为 10 5 200 2
第二次询问,参加的人有 2,3,4 5+200+2=207
第三次询问,参加的人有 1,2,4 10+5+2=17
Data Constraint
数据范围:
20%的数据 n<=10000,q<=500;
另外 20%的数据 k=2
另外 20%的数据 没有修改操作
所有数据 n,q<=100000,所有询问 k 的和<=1000000
保证数据合法
这道题树剖很显然嘛 但我们一次要求多条路径
我的暴力想法—— laz 数组存多条路径走过的点 权值改为 1
为了剪枝 我把子节点有经过的点设为了 2 没有的就是 0 这样可以大大提高效率
找到 2 继续往下找 找到 1 直接返回线段树的值 找到 0 返回 0
于是就有了三十分 (预测没加优化只有 20 分 第三个点 700ms 卡过)
暴力代码
#include <algorithm>
#include <cstring>
#include <cstdio>
#define ll long long
using namespace std;
const int MAXN = 100010;
struct edge{int next,to;}e[MAXN << 1];
ll tr[MAXN << 3];
int siz[MAXN],dep[MAXN],fa[MAXN],son[MAXN],top[MAXN],id[MAXN],oid[MAXN];
int first[MAXN],v[MAXN],laz[MAXN << 3],tot;
inline int re() {
char q = getchar();
int x = 0;
while (q < '0' || q > '9') q = getchar();
while ('0' <= q && q <= '9')
x = (x << 3) + (x << 1) + q - (3 << 4),q = getchar();
return x;
}
inline void add(int x,int y) {
e[++tot].next = first[x];
e[tot].to = y;
first[x] = tot;
}
inline void dfs1(int p) {
dep[p] = dep[fa[p]] + 1;
++siz[p];
for (int a = first[p],b = e[a].to ; a ; a = e[a].next,b = e[a].to)
if (b == fa[p]) continue; else {
fa[b] = p;
dfs1(b);
siz[p] += siz[b];
if (siz[b] > siz[son[p]]) son[p] = b;
}
}
inline void dfs2(int p,int f) {
top[p] = f;
id[p] = ++tot;
oid[tot] = p;
if (!son[p]) return;
dfs2(son[p],f);
for (int a = first[p],b = e[a].to ; a ; a = e[a].next,b = e[a].to)
if (b != fa[p] && b != son[p]) dfs2(b,b);
}
inline void build(int l,int r,int len) {
if (l == r) {tr[len] = v[oid[l]]; return;}
int mid = (l + r) >> 1;
build(l,mid++,len << 1);
build(mid,r,len << 1 | 1);
tr[len] = tr[len << 1] + tr[len << 1 | 1];
}
inline void update(int l,int r,int len,int i,int j) {
if (l == r && l == i) {tr[len] = j; return;}
int mid = (l + r) >> 1;
if (i <= mid) update(l,mid,len << 1,i,j);
else update(++mid,r,len << 1 | 1,i,j);
tr[len] = tr[len << 1] + tr[len << 1 | 1];
}
inline void uplaze(int l,int r,int len,int i,int j) {
if (laz[len] & 1) return;
if (i <= l && r <= j) {laz[len] = 1; return;}
laz[len] = 2;
int mid = (l + r) >> 1;
if (i <= mid) uplaze(l,mid,len << 1,i,j);
if (mid < j) uplaze(++mid,r,len << 1 | 1,i,j);
}
inline ll rebuild(int l,int r,int len) {
if (!laz[len]) return 0;
if (laz[len] & 1) return tr[len];
int mid = (l + r) >> 1;
return rebuild(l,mid,len << 1) + rebuild(++mid,r,len << 1 | 1);
}
void lpcate(int x,int y) {
while (top[x] != top[y])
{
if (dep[top[x]] < dep[top[y]]) swap(x,y);
uplaze(1,siz[1],1,id[top[x]],id[x]);
x = fa[top[x]];
}
if (dep[x] > dep[y]) swap(x,y);
uplaze(1,siz[1],1,id[x],id[y]);
}
ll uplone(int l,int r,int len,int i) {
if (l == r && l == i) return tr[len];
int mid = (l + r) >> 1;
if (i <= mid) return uplone(l,mid,len << 1,i);
return uplone(++mid,r,len << 1 | 1,i);
}
int main()
{
freopen("kaihuang.in","r",stdin);
freopen("kaihuang.out","w",stdout);
int n = re(),q = re(),x,y;
for (int a = 1 ; a <= n ; ++ a) v[a] = re();
for (int a = 1 ; a < n ; ++ a) x = re(),y = re(),
add(x,y),add(y,x); tot = 0;
dfs1(1),dfs2(1,1);
build(1,n,1);
while (q--)
{
char w = getchar();
while (w != 'Q' && w != 'C') w = getchar();
if (w == 'Q')
{
memset(laz,0,sizeof(laz));
for (n = 1,v[n] = re() ; v[n] ; ++ n,v[n] = re());
for (int a = 2 ; a < n ; ++ a)
for (int b = 1 ; b < a ; ++ b) lpcate(v[b],v[a]);
if (n > 2) printf("%lld\n",rebuild(1,siz[1],1));
else printf("%lld\n",uplone(1,siz[1],1,id[v[1]]));
continue;
}
x = re(),y = re(),update(1,siz[1],1,id[x],y);
}
fclose(stdin);
fclose(stdout);
return 0;
}
好了好了做题怎么能不发正解呢
正解就是多条路径按 dfs 序 排然后求 lca 可使每个点经过两遍 两点 lca (可能除最顶上的) 经过三遍 这个手推可以发现 证明的话.....进点一次 跳到第二个点就两次 第 lca>该两点的点再跳就三次 每次减去一个 每个点就变两次了 最上面的次数就是 lca 为该点的节点对数 每次减去一个就为 0 了 所以要加上两个该点权
根据这个 我们可以直接求路径距离 然后减去相邻两点 lca 的权值 最后加上两个所有点最顶上的 lca (也是所有点的 lca) 的权值
其他部分不变 不过单点路径不用特判了(本来就不是那样做的).....话说我的暴力打的也是累
下放正解
#include <algorithm>
#include <cstring>
#include <cstdio>
#define ll long long
using namespace std;
const int MAXN = 100010;
struct edge{int next,to;}e[MAXN << 1];
ll tr[MAXN << 3],mx;
int siz[MAXN],dep[MAXN],fa[MAXN],son[MAXN],top[MAXN],id[MAXN],oid[MAXN];
int first[MAXN],v[MAXN],bot[MAXN],tot;
inline short cmp(int x,int y) {return id[x] < id[y];}
inline int r() {
char q = getchar(); int x = 0;
while (q < '0' || q > '9') q = getchar();
while ('0' <= q && q <= '9') x = (x << 3) + (x << 1) + q - (3 << 4),q = getchar();
return x;
}
inline void add(int x,int y) {
e[++tot].next = first[x];
e[tot].to = y;
first[x] = tot;
}
inline void dfs1(int p) {
dep[p] = dep[fa[p]] + 1;
++siz[p];
for (int a = first[p],b = e[a].to ; a ; a = e[a].next,b = e[a].to)
if (b == fa[p]) continue; else {
fa[b] = p;
dfs1(b);
siz[p] += siz[b];
if (siz[b] > siz[son[p]]) son[p] = b;
}
}
inline void dfs2(int p,int f) {
top[p] = f;
id[p] = ++tot;
oid[tot] = p;
if (!son[p]) return;
dfs2(son[p],f);
for (int a = first[p],b = e[a].to ; a ; a = e[a].next,b = e[a].to)
if (b != fa[p] && b != son[p]) dfs2(b,b);
}
inline void build(int l,int r,int len) {
if (l == r) {tr[len] = v[oid[l]]; return;}
int mid = (l + r) >> 1;
build(l,mid,len << 1);
build(++mid,r,len << 1 | 1);
tr[len] = tr[len << 1] + tr[len << 1 | 1];
}
inline void update(int l,int r,int len,int i,int j) {
if (l == r && l == i) {tr[len] = j; return;}
int mid = (l + r) >> 1;
if (i <= mid) update(l,mid,len << 1,i,j);
else update(++mid,r,len << 1 | 1,i,j);
tr[len] = tr[len << 1] + tr[len << 1 | 1];
}
inline ll get(int l,int r,int len,int i,int j) {
if (i <= l && r <= j) return tr[len];
int mid = (l + r) >> 1; ll ans = 0;
if (i <= mid) ans += get(l,mid,len << 1,i,j);
if (mid < j) ans += get(++mid,r,len << 1 | 1,i,j);
return ans;
}
inline ll lca(int x,int y) {
ll ans = 0;
while (top[x] != top[y])
{
if (dep[top[x]] < dep[top[y]]) swap(x,y);
ans += get(1,siz[1],1,id[top[x]],id[x]);
x = fa[top[x]];
}
if (dep[x] > dep[y]) swap(x,y);
return ans - v[x] + get(1,siz[1],1,id[x],id[y]);
}
inline int mxlca(int x,int y){
while (top[x] != top[y])
{
if (dep[top[x]] < dep[top[y]]) swap(x,y);
x = fa[top[x]];
}
return dep[x] < dep[y] ? x : y;
}
int main()
{
freopen("kaihuang.in","r",stdin);
freopen("kaihuang.out","w",stdout);
int n = r(),q = r(),x,y;
for (int a = 1 ; a <= n ; ++ a) v[a] = r();
for (int a = 1 ; a < n ; ++ a) x = r(),y = r(),
add(x,y),add(y,x); tot = 0;
dfs1(1),dfs2(1,1);
build(1,n,1);
while (q--)
{
char w = getchar();
while (w != 'Q' && w != 'C') w = getchar();
if (w == 'Q')
{
for (n = 1,bot[n] = r() ; bot[n] ; ++ n,bot[n] = r());
sort(bot + 1,bot + n,cmp);
ll ans = bot[1];
for (int a = 2 ; a < n ; ++ a) ans = mxlca(ans,bot[a]);
ans = v[ans] << 1;
for (int a = 2 ; a < n ; ++ a) ans += lca(bot[a - 1],bot[a]);
ans += lca(bot[1],bot[--n]);
printf("%lld\n",ans >> 1);
continue;
}
x = r(),y = r(),v[x] = y,update(1,siz[1],1,id[x],y);
}
fclose(stdin);
fclose(stdout);
return 0;
}