题目大意
NiroBC 姐姐是个活泼的少女,她十分喜欢爬树,而她家门口正好有一棵果树,正好满足了她爬树的需求。
这颗果树有N个节点,节点标号 1…N。每个节点长着一个果子,第i个节点上的果子颜色为 Ci 。
NiroBC姐姐每天都要爬树,每天都要选择一条有趣的路径 (u,v) 来爬。
一条路径被称作有趣的,当且仅当这条路径上的果子的颜色互不相同。
(u,v) 和 (v,u) 被视作同一条路径。特殊地,(i,i) 也被视作一条路径,这条路径只含 i 一个果子,显然是有趣的。
NiroBC姐姐想知道这颗树上有多少条有趣的路径。
数据范围
解题思路
考虑部分分怎么做。
一种很显然的思路,就是存下每个点往上走最多能走多远,设这个数组为up[]。
对于一条链的情况,就维护2个指针,用个set存一下就好了。
对于
n≤3000
n
≤
3000
,先
O(n2)
O
(
n
2
)
搞定以1为根的,然后
O(1)
O
(
1
)
换根,暴力重构up.
但是忘记set中的操作的复杂度带
log
l
o
g
。超时。
考虑直接枚举根,能走就走。。。。。。
AC?
有几个难点。
①颜色很多,只要满足其中一种限制条件就不合法。
②点对的路径可能包含了不止2个同样颜色的点。
③路径与路径之间的包含关系。(我认为这个地方卡住我了)
考虑任意一对相同颜色的点对
(u,v)
(
u
,
v
)
,若其中一个点是另外一个点的祖先,如图①。(情况A)
图①
找到
p∈son[u]
p
∈
s
o
n
[
u
]
且
p
p
是的祖先,如果点对
(a,b)
(
a
,
b
)
满足其中一个在
v
v
的子树,另一个不在的子树内,那么这个点对是不合法的。
如果
u
u
和谁也不是谁的祖先(情况B),那么如果点对
(a,b)
(
a
,
b
)
满足其中一个在
u
u
的子树,另一个在的子树内,那么这个点对是不合法的。
那么设计什么算法,能够快速判断出合法点对的个数?
即最后求出来的点对是满足所有相同颜色点对限制的?
考虑黑色代表不合法,白色代表合法。假设有一个
n∗n
n
∗
n
的网格,格子
(i,j)
(
i
,
j
)
可以看作
dfs
d
f
s
序为
i,j
i
,
j
的点对,他们是否合法。
如果满足情况A,那么可以将两个矩形染黑,如果满足情况B,那么可以将一个矩形染黑。(不要忘记每个矩形再作一个关于
y=x
y
=
x
对称的矩形!!!)
然后再用线段树+扫描线扫一遍答案。
矩形会有重叠?这意味着不能直接维护sum。
可以维护最小值
“0”
“
0
”
的个数!这就解决了问题了。
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<vector>
#define N 100010
#define LL long long
#define fo(i,a,b) for(i=a;i<=b;i++)
#define fd(i,a,b) for(i=a;i>=b;i--)
using namespace std;
struct notee{
int to,next;
};notee edge[N*2];
int tot,head[N];
struct note{
int sum,mi,la;
};note tr[N*4];
struct note1{
int x,l,r,z;
};note1 a[N*80],b[N*80];
int i,j,jj,k,l,n,m;
int u,v,Tot;
int opl,opr,opz;
int dfn[N],tar[N],dep[N],fa[N][17],T;
int col[N],c[N],cnt,wz,wz1;
int siz[N];
LL ans;
vector<int>v1[N];
int read(){
int fh=1,rs=0;char ch;
while((ch<'0'||ch>'9')&&(ch^'-'))ch=getchar();
if(ch=='-')fh=-1,ch=getchar();
while(ch>='0'&&ch<='9')rs=(rs<<3)+(rs<<1)+(ch^'0'),ch=getchar();
return fh*rs;
}
int Min(int x,int y){return x<y?x:y;}
void update(int ps){
tr[ps].mi=Min(tr[ps<<1].mi,tr[(ps<<1)|1].mi);
tr[ps].sum=(tr[ps<<1].mi==tr[ps].mi?tr[ps<<1].sum:0)+(tr[(ps<<1)|1].mi==tr[ps].mi?tr[(ps<<1)|1].sum:0);
}
void downld(int ps,int l,int r){
if(tr[ps].la){
int wz=(l+r)>>1;
tr[ps<<1].mi+=tr[ps].la;
tr[(ps<<1)|1].mi+=tr[ps].la;
tr[ps<<1].la+=tr[ps].la;
tr[(ps<<1)|1].la+=tr[ps].la;
tr[ps].la=0;
}
}
void change(int ps,int l,int r){
if(l>=opl&&r<=opr){
tr[ps].mi+=opz;
tr[ps].la+=opz;
return;
}
downld(ps,l,r);
int wz=(l+r)>>1;
if(opl<=wz)change(ps<<1,l,wz);
if(opr>wz)change((ps<<1)|1,wz+1,r);
update(ps);
}
void build(int ps,int l,int r){
if(l==r){
tr[ps].mi=0;
tr[ps].sum=1;
return;
}
int wz=(l+r)>>1;
build(ps<<1,l,wz);
build((ps<<1)|1,wz+1,r);
update(ps);
}
void lb(int x,int y){
edge[++tot].to=y;
edge[tot].next=head[x];
head[x]=tot;
}
void dg(int x){
dfn[x]=++T;tar[T]=x;
siz[x]=1;
int i;
for(i=head[x];i;i=edge[i].next)
if(fa[x][0]^edge[i].to){
dep[edge[i].to]=dep[x]+1;
fa[edge[i].to][0]=x;
dg(edge[i].to);
siz[x]+=siz[edge[i].to];
}
}
bool cmp(note1 x,note1 y){return x.x<y.x;}
int get(int x,int y){
int i;
fd(i,16,0)if(dep[fa[x][i]]>dep[y])x=fa[x][i];
return x;
}
void Add(int X1,int X2,int Y1,int Y2){
a[++Tot].x=X1;
a[Tot].l=Y1;a[Tot].r=Y2;
a[Tot].z=1;
b[Tot].x=Y1;
b[Tot].l=X1;b[Tot].r=X2;
b[Tot].z=1;
a[++Tot].x=X2+1;
a[Tot].l=Y1;a[Tot].r=Y2;
a[Tot].z=-1;
b[Tot].x=Y2+1;
b[Tot].l=X1;b[Tot].r=X2;
b[Tot].z=-1;
}
int main(){
freopen("tree.in","r",stdin);
freopen("tree.out","w",stdout);
n=read();
fo(i,1,n)col[i]=read(),c[i]=col[i];
sort(c+1,c+n+1);
cnt=unique(c+1,c+n+1)-c-1;
fo(i,1,n)col[i]=lower_bound(c+1,c+cnt+1,col[i])-c;
fo(i,1,n)v1[col[i]].push_back(i);
fo(i,1,n-1){
u=read(),v=read();
lb(u,v);lb(v,u);
}
dep[1]=1;
dg(1);
fo(j,1,16)fo(i,1,n)fa[i][j]=fa[fa[i][j-1]][j-1];
fo(i,1,cnt){
k=v1[i].size()-1;
fo(j,0,k-1)fo(jj,j+1,k){
u=v1[i][j];v=v1[i][jj];
if(dep[u]>dep[v])swap(u,v);
if(dfn[v]>=dfn[u]&&dfn[v]<=dfn[u]+siz[u]-1){
u=get(v,u);
if(dfn[u]>1)Add(1,dfn[u]-1,dfn[v],dfn[v]+siz[v]-1);
if(dfn[u]+siz[u]-1<n)Add(dfn[u]+siz[u],n,dfn[v],dfn[v]+siz[v]-1);
}else Add(dfn[u],dfn[u]+siz[u]-1,dfn[v],dfn[v]+siz[v]-1);
}
}
sort(a+1,a+Tot+1,cmp);
sort(b+1,b+Tot+1,cmp);
build(1,1,n);
wz=1;wz1=1;
fo(i,1,n){
for(;wz<=Tot&&a[wz].x<=i;wz++){
opl=a[wz].l;
opr=a[wz].r;
opz=a[wz].z;
change(1,1,n);
}
for(;wz1<=Tot&&b[wz1].x<=i;wz1++){
opl=b[wz1].l;
opr=b[wz1].r;
opz=b[wz1].z;
change(1,1,n);
}
if(!tr[1].mi)ans+=1ll*tr[1].sum;
}
ans=(ans-n)/2+n;
printf("%lld",ans);
return 0;
}