题目
https://gmoj.net/senior/#main/show/6807
题解
转化题意,可以发现,这道题就是选择一个根,使得它的某个子树内包含所有颜色,求满足条件的子树的最大深度。
比赛时我的思路是删掉以某个儿子为根的子树(或以当前点为根的子树外的部分),结果发现这样子处理不了删掉以孙子为根的子树的情况。最终挂在这道题上了。
其实应该考虑另外的处理方式:删掉某棵子树或删掉某棵子树外的全部点。这里的删掉指的是不选择这些点,若答案不等于1,根就在这些点中。
分别考虑这两种情况:
- 当我删掉一个子树时,说明这个子树外包含所有的颜色。但是这个条件比较难判断,于是考虑这个子树不能被删去时满足什么条件,显然是这个子树外缺少某种颜色,即这种颜色全都在这个子树内。对于每一种颜色,把它们的lca求出来,lca到根的路径上的点就是不能删掉子树的点,用倍增lca O ( n log 2 n ) O(n\log_2n) O(nlog2n)地处理即可(但是这样子常数巨大,优化后面会讲);
- 当我删掉一个子树外的所有点时,说明这个子树内包含所有点。发现树上处理起来很麻烦(可以线段树合并,但是空间和时间都可能爆掉),就把它转化到序列上(按dfn序排序)。双指针 O ( n ) O(n) O(n)地扫描一下就行了(当然你喜欢的话也可以用主席树,但是可能会炸空间)。
理论上这样打就能过了,但是我常数太大TLE了……
发现跑得最慢的部分是求一堆点的lca那里,要不开#pragma GCC optimize("O3")
过这题势必要优化这个部分。
这里有一个定理:
∀
d
f
n
a
≤
d
f
n
b
≤
d
f
n
c
,
都有
l
c
a
(
a
,
c
)
=
l
c
a
(
a
,
b
,
c
)
\forall dfn_a\le dfn_b\le dfn_c,都有 lca(a,c)=lca(a,b,c)
∀dfna≤dfnb≤dfnc,都有lca(a,c)=lca(a,b,c)。
证明的话就是
l
c
a
(
a
,
c
)
lca(a,c)
lca(a,c)的子树中必定包含了dfn在
[
a
,
c
]
[a,c]
[a,c]中所有点,因此必定也是b的祖先。
有了这个定理就可以将这个部分优化到 O ( m log 2 n ) O(m\log_2n) O(mlog2n)了,足以通过这道题。如果常数太大还是过不了,可以用tarjan lca。
CODE
倍增lca版本,常数稍大:
#include<cstdio>
using namespace std;
#define M 2000005
#define N 1000005
#define C 100005
struct array{int fir[C],nex[N];}a;bool cover[N];
int fir[N],to[M],nex[M],col[N],las[C],b[C],right[N];
int f[N][20],g[N][2],son[N],h[N],dep[N],dfn[N],id[N],siz[N],cnt,s,m;
inline char gc()
{
static char buf[100005],*l=buf,*r=buf;
return l==r&&(r=(l=buf)+fread(buf,1,100005,stdin),l==r)?EOF:*l++;
}
inline void read(int &x)
{
char ch;while(ch=gc(),ch<'0'||ch>'9');x=ch-48;
while(ch=gc(),ch>='0'&&ch<='9') x=x*10+ch-48;
}
inline void inc(int x,int y)
{
to[++s]=y,nex[s]=fir[x],fir[x]=s;
to[++s]=x,nex[s]=fir[y],fir[y]=s;
}
inline void swap(int &x,int &y){int z=x;x=y,y=z;}
void dfs(int k)
{
id[++cnt]=k,dfn[k]=cnt;
dep[k]=dep[f[k][0]]+1,siz[k]=1;
for(int i=fir[k];i;i=nex[i]) if(to[i]!=f[k][0])
f[to[i]][0]=k,dfs(to[i]),siz[k]+=siz[to[i]];
}
inline int mymax(int x,int y){return x>y?x:y;}
inline int getlca(int u,int v)
{
if(dep[u]<dep[v]) swap(u,v);
for(int i=19;i>=0;--i)
if(dep[f[u][0]]>=dep[v])
u=f[u][0];
if(u==v) return u;
for(int i=19;i>=0;--i)
if(f[u][i]^f[v][i])
u=f[u][i],v=f[v][i];
return f[u][0];
}
int main()
{
freopen("tree.in","r",stdin);
freopen("tree.out","w",stdout);
int n,x,y,l,r,tot=0,ans=0;
read(n),read(m);
for(int i=1;i<=n;++i) read(col[i]),a.nex[i]=a.fir[col[i]],a.fir[col[i]]=i;
for(int i=1;i<n;++i) read(x),read(y),inc(x,y);
dep[1]=1,dfs(1);
for(int i=n,tmp,k;i>1;--i)
{
tmp=g[id[i]][0]+1,k=f[id[i]][0];
if(tmp>g[k][0]) g[k][1]=g[k][0],g[k][0]=tmp,son[k]=id[i];
else if(tmp>g[k][1]) g[k][1]=tmp;
}
for(int i=2,k,fa;i<=n;++i) k=id[i],fa=f[k][0],h[k]=mymax(h[fa],g[fa][k==son[fa]])+1;
for(int j=1;j<20;++j)
for(int i=1;i<=n;++i)
f[i][j]=f[f[i][j-1]][j-1];
for(int i=1,max,min;i<=m;++i) if(a.fir[i])
{
max=0,min=N;
for(int j=a.fir[i];j;j=a.nex[j])
{
if(dfn[j]>max) max=dfn[j];
if(dfn[j]<min) min=dfn[j];
}
cover[getlca(id[max],id[min])]=1;
}
for(int i=n;i>1;--i) cover[f[id[i]][0]]|=cover[id[i]];
for(int i=1;i<=m;++i) las[i]=n+1;
for(int i=n;i;--i) right[i]=las[col[id[i]]],las[col[id[i]]]=i;
for(int i=1;i<=n;++i) if(!cover[id[i]]) ans=mymax(ans,g[id[i]][0]+1);
b[col[id[1]]]=1,tot=r=1;
while(r<=n&&tot<m)
{
if(!b[col[id[++r]]]) ++tot;
++b[col[id[r]]];
}
for(l=1;l<=n;++l)
{
if(r<=l+siz[id[l]]-1) ans=mymax(ans,h[id[l]]);
if(!--b[col[id[l]]])
{
if(right[l]>n) break;
for(int i=r+1;i<=right[l];++i) ++b[col[id[i]]];
r=right[l];
}
}
printf("%d\n",ans+1);
return 0;
}
tarjan lca版本,代码稍长:
#include<cstdio>
using namespace std;
#define M 2000005
#define N 1000005
#define C 100005
struct array{int fir[C],nex[N];}a;bool cover[N];
struct query
{
int fir[N],nex[200005],to[200005],s;
inline void inc(int x,int y)
{
to[++s]=y,nex[s]=fir[x],fir[x]=s;
to[++s]=x,nex[s]=fir[y],fir[y]=s;
}
}qry;
int fir[N],to[M],nex[M],col[N],las[C],b[C],right[N];
int f[N],g[N][2],son[N],h[N],dep[N],dfn[N],id[N],siz[N],fa[N],cnt,s,m;
inline char gc()
{
static char buf[100005],*l=buf,*r=buf;
return l==r&&(r=(l=buf)+fread(buf,1,100005,stdin),l==r)?EOF:*l++;
}
inline void read(int &x)
{
char ch;while(ch=gc(),ch<'0'||ch>'9');x=ch-48;
while(ch=gc(),ch>='0'&&ch<='9') x=x*10+ch-48;
}
inline void inc(int x,int y)
{
to[++s]=y,nex[s]=fir[x],fir[x]=s;
to[++s]=x,nex[s]=fir[y],fir[y]=s;
}
inline void swap(int &x,int &y){int z=x;x=y,y=z;}
void dfs(int k)
{
id[++cnt]=k,dfn[k]=cnt;
dep[k]=dep[fa[k]]+1,siz[k]=1;
for(int i=fir[k];i;i=nex[i]) if(to[i]!=fa[k])
fa[to[i]]=k,dfs(to[i]),siz[k]+=siz[to[i]];
}
inline int mymax(int x,int y){return x>y?x:y;}
int getf(int k){return f[k]==k?k:f[k]=getf(f[k]);}
void getlca(int k)
{
for(int i=fir[k];i;i=nex[i]) if(to[i]!=fa[k])
getlca(to[i]),f[to[i]]=k;
for(int i=qry.fir[k];i;i=qry.nex[i]) if(f[qry.to[i]]!=qry.to[i])
cover[getf(qry.to[i])]=1;
}
int main()
{
freopen("tree.in","r",stdin);
freopen("tree.out","w",stdout);
int n,x,y,l,r,tot=0,ans=0;
read(n),read(m);
for(int i=1;i<=n;++i) read(col[i]),a.nex[i]=a.fir[col[i]],a.fir[col[i]]=i;
for(int i=1;i<n;++i) read(x),read(y),inc(x,y);
dep[1]=1,dfs(1);
for(int i=n,tmp,k;i>1;--i)
{
tmp=g[id[i]][0]+1,k=fa[id[i]];
if(tmp>g[k][0]) g[k][1]=g[k][0],g[k][0]=tmp,son[k]=id[i];
else if(tmp>g[k][1]) g[k][1]=tmp;
}
for(int i=2,k;i<=n;++i) k=id[i],h[k]=mymax(h[fa[k]],g[fa[k]][k==son[fa[k]]])+1;
for(int i=1,max,min;i<=m;++i) if(a.fir[i])
{
max=0,min=N;
for(int j=a.fir[i];j;j=a.nex[j])
{
if(dfn[j]>max) max=dfn[j];
if(dfn[j]<min) min=dfn[j];
}
if(id[max]^id[min]) qry.inc(id[max],id[min]);
else cover[id[max]]=1;
}
for(int i=1;i<=n;++i) f[i]=i;getlca(1);
for(int i=n;i>1;--i) cover[fa[id[i]]]|=cover[id[i]];
for(int i=1;i<=m;++i) las[i]=n+1;
for(int i=n;i;--i) right[i]=las[col[id[i]]],las[col[id[i]]]=i;
for(int i=1;i<=n;++i) if(!cover[id[i]]) ans=mymax(ans,g[id[i]][0]+1);
b[col[id[1]]]=1,tot=r=1;
while(r<=n&&tot<m)
{
if(!b[col[id[++r]]]) ++tot;
++b[col[id[r]]];
}
for(l=1;l<=n;++l)
{
if(r<=l+siz[id[l]]-1) ans=mymax(ans,h[id[l]]);
if(!--b[col[id[l]]])
{
if(right[l]>n) break;
for(int i=r+1;i<=right[l];++i) ++b[col[id[i]]];
r=right[l];
}
}
printf("%d\n",ans+1);
return 0;
}