题意简述
一天,SY 终于搞到了 L u a \rm Lua Lua 的 U 盘!
他看见里面有个叫 “ 绅士 ” 的压缩包,毫不犹豫地点开看,发现压缩包设了密码,SY 只能看到里面数十万的文件夹,却看不到一个文件,只有根目录下的一个 txt。SY 欣喜若狂,因为哪怕是 L u a \rm Lua Lua 的一个 txt 都是宝藏!他点开看,结果令它哭笑不得,是一道题,但是给出了密码的线索:
- 该压缩包里共有 N N N 个文件夹,彼此之间的包含关系形成了一棵树,根不重要。每个文件夹开头都有个数字,表示该文件夹里的主色调颜色编号。
- 求不存在相同颜色结点的路径条数, ( u , v ) (u,v) (u,v) 和 ( v , u ) (v,u) (v,u) 被视作一条路径, ( u , u ) (u,u) (u,u) 也是一种合法路径。
- 同种颜色的文件夹不超过 20 个。
- N ≤ 1 e 5 , 1 ≤ C o l o r i ≤ N N\leq 1e5\;,\;1\leq Color_i\leq N N≤1e5,1≤Colori≤N.
最终得到的答案就是该压缩包的密码!SY 的编译出了问题,只好请你来帮忙做这个解密游戏。
题解
这个转换太妙了,必须记下来:
考虑这样的一个问题:
给定一棵 N 个点的树,给出 M 组限制,每个限制形如“ai 与 bi 不可同时出现在路径上”,问有多少条合法路径。
如果一条路径 (u, v) 因为第 i 组限制而不合法,就说明路径 (u, v) 上同时出现了 ai 与 bi 两个点。
如果 ai 是 bi 的祖先,设 p 是 ai 到 bi 的路径上与 ai 距离为 1 的点。此时,u 和 v 两个点,其中一个不在 p 的子树内,另一个在 bi 的子树内。
如果 ai 和 bi 谁也不是谁的祖先,那么 u 和 v 两个点,一个在 ai 的子树内,另一个在 bi 的子树内。
如果把一条路径 (u, v) 映射成二维平面上的某一个点,横坐标是 u 的 DFS序,纵坐标是 v 的 DFS 序,那么,一组限制让平面上的一个(谁也不是谁的祖先时)或两个(ai 是 bi 的祖先时)矩形内的所有路径不合法了。
这时,问题转换成了:有一个 N × N 的二维平面,上面有 O(M) 个矩形,问有多少个格子,不在任何一个矩形内部。这个问题可以用线段树 + 扫描线在O(M log N) 的时间复杂度内解决。
嗯,回到这道题目。要求路径上的点颜色互不相同,所以同色的点对不能同时出现在路径上。由于同一种颜色出现次数不超过 20 次,所以 M = O(20N),那么,整道题的时间复杂度就是 O(20N log N)。
这里笔者介绍一种好用的线段树方案:维护区间最小值以及该最小值的数量,每次加矩阵就把一段区间加 1 就完了,询问相当于问这个区间最小值是不是 0 ,如果是就加上最小值数量。懒标记甚至可以永久化。
CODE
#include<set>
#include<map>
#include<queue>
#include<vector>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
#define MAXN 200005
#define DB double
#define LL long long
#define ENDL putchar('\n')
#define lowbit(x) (-(x) & (x))
#define FI first
#define SE second
#define makepair(x,y) (pair<int,int>){(x),(y)}
LL read() {
LL f = 1,x = 0;char s = getchar();
while(s < '0' || s > '9') {if(s=='-')f = -f;s = getchar();}
while(s >= '0' && s <= '9') {x=x*10+(s-'0');s = getchar();}
return f * x;
}
const int MOD = 998244353;
int n,m,i,j,s,o,k;
int Abs(int x) {return x < 0 ? -x:x;}
struct it{
LL nm,ct; it(){nm=1e17;ct=0;}
it(LL N,LL C){nm=N;ct=C;}
}tre[MAXN<<3];
it bing(it a,it b) {
if(a.nm != b.nm) return a.nm < b.nm ? a:b;
return it(a.nm,a.ct + b.ct);
}
it operator + (it a,LL b) {a.nm += b;return a;}
LL lz[MAXN<<3];int M;
void maketree(int n) {M=1;while(M<n+2)M<<=1;}
void addp(int x,it y) {
int s = M+x;tre[s] = y; s >>= 1;
while(s) tre[s] = bing(tre[s<<1],tre[s<<1|1])+lz[s],s >>= 1;
return ;
}
void addtree(int l,int r,LL ad) {
if(l > r) return ;
int s = M+l-1,t = M+r+1;
while(s || t) {
if(s<M) tre[s] = bing(tre[s<<1],tre[s<<1|1])+lz[s];
if(t<M) tre[t] = bing(tre[t<<1],tre[t<<1|1])+lz[t];
if((s>>1) ^ (t>>1)) {
if(!(s&1)) tre[s^1] = tre[s^1] + ad,lz[s^1] += ad;
if(t & 1) tre[t^1] = tre[t^1] + ad,lz[t^1] += ad;
}
s >>= 1;t >>= 1;
}
return ;
}
it findtree(int l,int r) {
if(l > r) return it();
int s = M+l-1,t = M+r+1;
it ls = it(),rs = it();
while(s || t) {
if(s < M) ls = ls + lz[s];
if(t < M) rs = rs + lz[t];
if((s>>1) ^ (t>>1)) {
if(!(s&1)) ls = bing(ls,tre[s^1]);
if(t & 1) rs = bing(rs,tre[t^1]);
}
s >>= 1;t >>= 1;
}return bing(ls,rs);
}
int cl[MAXN];
vector<int> bu[MAXN];
vector<int> g[MAXN];
int d[MAXN],f[MAXN][20],dfn[MAXN],rr[MAXN],tim;
void dfs(int x,int ff) {
d[x] = d[f[x][0] = ff] + 1;
for(int i = 1;i <= 17;i ++) f[x][i] = f[f[x][i-1]][i-1];
dfn[x] = ++ tim;
for(int i = 0;i < (int)g[x].size();i ++) {
int y = g[x][i];
if(y != ff) {
dfs(y,x);
}
}
rr[x] = tim;
return ;
}
vector<pair<int,int> > c[MAXN];
void addmat(int l,int r,int d,int u) {
c[l].push_back(makepair(d,u));
c[r+1].push_back(makepair(-d,-u));
return ;
}
int main() {
freopen("tree.in","r",stdin);
freopen("tree.out","w",stdout);
n = read();
for(int i = 1;i <= n;i ++) {
cl[i] = read();
}
for(int i = 1;i < n;i ++) {
s = read();o = read();
g[s].push_back(o);
g[o].push_back(s);
}
dfs(1,0);
maketree(tim);
for(int i = 1;i <= n;i ++) {
addp(i,it(0,1));
int co = cl[i];
for(int j = 0;j < (int)bu[co].size();j ++) {
int y = bu[co][j],x = i;
if(dfn[x] < dfn[y]) swap(x,y);
if(dfn[y] <= dfn[x] && rr[y] >= dfn[x]) {
int p = x;
for(int k = 17;k >= 0;k --) {
if(d[f[p][k]] > d[y]) p = f[p][k];
}
addmat(1,dfn[p]-1,dfn[x],rr[x]);
addmat(dfn[x],rr[x],rr[p]+1,n);
}
else {
addmat(dfn[y],rr[y],dfn[x],rr[x]);
}
}
bu[co].push_back(i);
}
LL ans = 0;
for(int i = 1;i <= n;i ++) {
for(int j = 0;j < (int)c[i].size();j ++) {
int ll = c[i][j].FI,rr = c[i][j].SE;
if(ll > 0) addtree(ll,rr,1);
else addtree(-ll,-rr,-1);
}
it as = findtree(i,n);
if(as.nm == 0) ans += as.ct;
}
printf("%lld\n",ans);
return 0;
}