Description
NiroBC 姐姐是个活泼的少女,她十分喜欢爬树,而她家门口正好有一棵果树,正好满足了她爬树的需求。
这颗果树有N个节点,节点标号 1…N。每个节点长着一个果子,第i个节点上的果子颜色为 Ci 。
NiroBC姐姐每天都要爬树,每天都要选择一条有趣的路径 (u,v) 来爬。
一条路径被称作有趣的,当且仅当这条路径上的果子的颜色互不相同。
(u,v) 和 (v,u) 被视作同一条路径。特殊地,(i,i) 也被视作一条路径,这条路径只含 i 一个果子,显然是有趣的。
NiroBC姐姐想知道这颗树上有多少条有趣的路径。
Input
第一行,一个整数 N,表示果树的节点个数。
第二行,N 个整数 C1 ,C2 ,…,CN ,表示 N 个果子的颜色。
接下来 N−1 行,每行两个整数 ui ,vi ,表示 ui和vi 之间有一条边
数据保证这N−1条边构成一棵树。
Output
一个整数,表示有趣的路径的数量。
Sample Input
输入1:
3
1 2 3
1 2
1 3
输入2:
5
1 1 2 3 3
1 2
1 3
2 4
2 5
Sample Output
输出1:
6
样例解释:有 (1,1),(1,2),(1,3),(2,2),(2,3),(3,3) 共 6 条有趣的路径。
输出2:
8
样例解释:有 (1,1),(1,3),(2,2),(2,4),(2,5),(3,3),(4,4),(5,5) 共 8 条有趣的路径。
数据范围
题解
观察数据范围,也许会有提示,
同种颜色不会超过20次,
从这里得到了启发,
考虑一个颜色相同的点对能带了哪些限制,分两种情况:
1、两个点a,b存在祖先关系,就认为a是b的祖先,
设p为b到a的路径上面,距离a的距离为1 的点,
那么在b的子树里面的点,就无法到达p子树以外的点,
因为如果是要到达p子树以外的点,必然要跨过a点,就不满足题意。
2、如果两个点a,b不存在祖先关系,
那么a子树里面的点就不能到达b子树里面,
因为这样会同时跨过a,b两个颜色相同的点,不符合题意。
知道了每个限制,现在就要考虑如何求。
一个节点的子树在DFS序上面是连续的,
所以,以DFS序为横纵坐标建立坐标系,一个限制就变成了一个矩形,
现在问题就变为了:在一个二维平面里面,
有一些矩形,矩形可以重叠,
求整个二维平面上面没有被矩形覆盖的点的数量。
这就是一个扫描线的问题了。
code
#include <queue>
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <string.h>
#include <cmath>
#include <math.h>
#include <time.h>
#define ll long long
#define N 100003
#define M 103
#define db double
#define P putchar
#define G getchar
#define inf 998244353
#define pi 3.1415926535897932384626433832795
using namespace std;
char ch;
void read(int &n)
{
n=0;
ch=G();
while((ch<'0' || ch>'9') && ch!='-')ch=G();
ll w=1;
if(ch=='-')w=-1,ch=G();
while('0'<=ch && ch<='9')n=(n<<3)+(n<<1)+ch-'0',ch=G();
n*=w;
}
int max(int a,int b){return a>b?a:b;}
int min(int a,int b){return a<b?a:b;}
ll abs(ll x){return x<0?-x:x;}
ll sqr(ll x){return x*x;}
void write(ll x){if(x>9) write(x/10);P(x%10+'0');}
struct node
{
int l,r,x,v;
}cz[100*N];
int n,c[N],t[N],x,y;
int nxt[N*2],to[N*2],lst[N],tot;
int dfn[N],id,f[17][N],dep[N],g[N],pos;
int opl,opr,ops,opx,op,lazy[4*N],s[N*4];
int nx[N],ls[N];
long long ans;
bool cmp(node a,node b)
{
return a.x<b.x;
}
void ins(int x,int y)
{
nxt[++tot]=lst[x];
to[tot]=y;
lst[x]=tot;
}
void dfs(int x,int fa)
{
dfn[x]=++id;f[0][x]=fa;dep[x]=dep[fa]+1;
for(int i=lst[x];i;i=nxt[i])
if(to[i]!=fa)dfs(to[i],x);
g[x]=id;
}
int lca(int x,int y)
{
if(dep[y]>dep[x])swap(x,y);
for(int i=16;i>=0;i--)
if(dep[f[i][x]]>=dep[y])x=f[i][x];
if(x==y)return x;
for(int i=16;i>=0;i--)
if(f[i][x]!=f[i][y])x=f[i][x],y=f[i][y];
return f[0][x];
}
void get(int x,int y)
{
if(dep[y]>dep[x])swap(x,y);
for(int i=16;i>=0;i--)
if(dep[f[i][x]]>dep[y])x=f[i][x];
pos=x;
}
void pre()
{
for(int j=1;j<17;j++)
for(int i=1;i<=n;i++)
f[j][i]=f[j-1][f[j-1][i]];
}
void find(int x,int l,int r)
{
if(lazy[x])return;
if(opl<=l && r<=opr)
{
ops+=s[x];
return;
}
int m=(l+r)>>1;
if(opl<=m)find(x<<1,l,m);
if(m<opr)find(x<<1|1,m+1,r);
}
void work(int x,int l,int r)
{
if(opl<=l && r<=opr)
{
if(opx==1)lazy[x]++,s[x]=0;
if(opx==2)
{
lazy[x]--;
if(lazy[x]==0)s[x]=(l==r?1:(lazy[x<<1]?0:s[x<<1])+(lazy[x<<1|1]?0:s[x<<1|1]));
}
return;
}
int m=(l+r)>>1;
if(opl<=m)work(x<<1,l,m);
if(m<opr)work(x<<1|1,m+1,r);
s[x]=(lazy[x<<1]?0:s[x<<1])+(lazy[x<<1|1]?0:s[x<<1|1]);
}
void build(int x,int l,int r)
{
s[x]=r-l+1;
if(l==r)return;
int m=(l+r)>>1;
build(x<<1,l,m);
build(x<<1|1,m+1,r);
}
void in(int x,int l,int r,int v)
{
if(l>r)return;tot++;
cz[tot].x=x;
cz[tot].l=l;
cz[tot].r=r;
cz[tot].v=v;
}
void jx(int x,int y,int xx,int yy)
{
if(x>xx)swap(x,xx);
if(y>yy)swap(y,yy);
if(xx>n || yy>n)return;
in(x,y,yy,1);
in(xx+1,y,yy,2);
}
int main()
{
freopen("tree.in","r",stdin);
freopen("tree.out","w",stdout);
read(n);
for(int i=1;i<=n;i++)
read(c[i]),nx[i]=ls[c[i]],ls[c[i]]=i;
for(int i=1;i<n;i++)
read(x),read(y),ins(x,y),ins(y,x);
dfs(1,0);build(1,1,n);pre();
tot=0;
for(int i=1;i<=n;i++)
for(int j=ls[i];j;j=nx[j])
for(int k=nx[j];k;k=nx[k])
{
x=lca(j,k);
if(x==j)
{
get(j,k);
jx(dfn[k],1,g[k],dfn[pos]-1);
jx(1,dfn[k],dfn[pos]-1,g[k]);
jx(dfn[k],g[pos]+1,g[k],n);
jx(g[pos]+1,dfn[k],n,g[k]);
}
else
if(x==k)
{
get(j,k);
jx(dfn[j],1,g[j],dfn[pos]-1);
jx(1,dfn[j],dfn[pos]-1,g[j]);
jx(dfn[j],g[pos]+1,g[j],n);
jx(g[pos]+1,dfn[j],n,g[j]);
}
else
{
jx(dfn[j],dfn[k],g[j],g[k]);
jx(dfn[k],dfn[j],g[k],g[j]);
}
}
sort(cz+1,cz+1+tot,cmp);
pos=1;ans=n;
for(int i=1;i<=n;i++)
{
for(;cz[pos].x==i;pos++)
{
opx=cz[pos].v;
opl=cz[pos].l;
opr=cz[pos].r;
work(1,1,n);
}
ans+=s[1];
}
printf("%lld\n",ans>>1);
return 0;
}