题目
一棵树,每个节点上是左括号或者右括号,定义S(x,y)为树上从x走到y,经过点上符号所形成的字符串。
定义f(x,y)表示对S(x,y)进行划分,最多能划分成多少个连续的合法括号序列
每次询问有多少对点满足f(x,y)=k(k>0)
题解
我们先考虑这样一个问题,对于已知的一个括号序列,怎么求f
那么首先我们当然是要看它是不是本身是一个合法的括号序列(也即是说,前缀和没有<0的)
然后,我们让(=1,)=-1,求个前缀和,那么f就是这个前缀和中0出现的次数
在此基础上,我们又可以拓展出一种方法,定义a[s][i]表示前缀和为s,s出现了i次的方案数(注意还要保证没有出现过大于s的前缀和,不然就成功地重复了),b[s][j]为前缀和为-s
那么可有
a
n
s
[
k
]
+
=
∑
s
∑
i
+
j
−
1
=
k
a
[
s
]
[
i
]
⋅
b
[
s
]
[
j
]
ans[k]+=\sum_s\sum_{i+j-1=k}a[s][i]\cdot b[s][j]
ans[k]+=∑s∑i+j−1=ka[s][i]⋅b[s][j]
但是这个其实并不完全对,s=0的时候是要特判的(因为不会有一个新的贡献)
不过大致上是一个卷积的形式了(可以通过b位移一位但是b[0]不动来造出真正的卷积)
然后对于这种所有路径统计的题,我们上点分治套FFT
具体一点,我们在一次分治中,要做以下工作:
从该点出发进行一次dfsa得到有关的a数组信息
从每个邻接点出发进行一次dfsb得到有关的b数组信息
代码中用的邻接表,其实vector应该也可以
然后把a和b卷起来
聪明的读者一定已经发现了,这样会出现一种重复
如图中红色的路径被计算进答案了,但这显然不能被算进答案中,所以我们还应该对每个邻接点进行计算减去这种路径的贡献
似乎就没啥别的要注意的了…注意清零
然后是复杂度分析。
每次FFT时的次数是这个前缀和出现的次数,显然,这个次数每增加1,括号序列长度就会消耗2
也就是说,对于一个点处理时的FFT的最高次数之和是O(N)级别的
也就是说FFT的复杂度之和是O(NlogN)级别的
所以
T
(
n
)
=
2
T
(
n
2
)
+
O
(
n
l
o
g
n
)
T(n)=2T(\frac{n}{2})+O(nlogn)
T(n)=2T(2n)+O(nlogn)
所以
T
(
n
)
=
O
(
n
l
o
g
2
n
)
T(n)=O(nlog^2n)
T(n)=O(nlog2n)
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
typedef long long ll;
struct Complex{
double r,i;
Complex(){}
Complex(double _r,double _i):r(_r),i(_i){}
};
Complex operator +(Complex a,Complex b){
return Complex(a.r+b.r,a.i+b.i);
}
Complex operator -(Complex a,Complex b){
return Complex(a.r-b.r,a.i-b.i);
}
Complex operator *(Complex a,Complex b){
return Complex(a.r*b.r-a.i*b.i,a.r*b.i+a.i*b.r);
}
const int N=500005;
const int INF=0x3f3f3f3f;
const double pi=acos(-1.0);
void FFT(Complex *p,int n,int dir){
for(int i=1,j=0;i<n-1;i++){
for(int s=n;j^=s>>=1,~j&s;);
if(i<j)
swap(p[i],p[j]);
}
for(int m=1;m<n;m<<=1){
int m2=m<<1;
double r=pi/m*dir;
Complex wn(cos(r),sin(r));
for(int i=0;i<n;i+=m2){
Complex w(1,0);
for(int j=0;j<m;j++){
Complex t1=p[i+j],t2=p[i+j+m];
p[i+j]=t1+w*t2;
p[i+j+m]=t1-w*t2;
w=w*wn;
}
}
}
if(dir==-1){
for(int i=0;i<n;i++)
p[i].r/=n;
}
}
int res[N];
Complex A[N],B[N];
void Mul(int n){
int len=1;
while(len<2*n)
len<<=1;
for(int j=n;j<len;j++)
A[j]=B[j]=Complex(0,0);
FFT(A,len,1);
FFT(B,len,1);
for(int i=0;i<len;i++)
A[i]=A[i]*B[i];
FFT(A,len,-1);
for(int i=0;i<len;i++)//!
res[i]=(int)(A[i].r+0.5);
}
struct node{
int u,v,nxt;
}edge[N*2],data[N*2];
int head[N],mcnt=1;
void add_edge(int u,int v){
mcnt++;
edge[mcnt].u=u;
edge[mcnt].v=v;
edge[mcnt].nxt=head[u];
head[u]=mcnt;
}
int dcnt;
void add_data(int &head,int val){
dcnt++;
data[dcnt].v=val;
data[dcnt].nxt=head;
head=dcnt;
}
int sz[N],totsz,f[N];
int root;
bool vis[N];
void Findroot(int u,int fa){
sz[u]=1;
f[u]=0;
for(int i=head[u];i;i=edge[i].nxt){
int v=edge[i].v;
if(!vis[i]&&v!=fa){
Findroot(v,u);
sz[u]+=sz[v];
f[u]=max(f[u],sz[v]);
}
}
f[u]=max(f[u],totsz-sz[u]);
if(f[u]<f[root])
root=u;
}
int val[N];
int n;
int ans[N];
int md;
int h[N];
int ga[N],gb[N];
void dfsa(int u,int fa,int s,int mx,int cnt){
s+=val[u];
if(mx<s)
mx=s,cnt=1;
else if(mx==s)
cnt++;
if(mx==s&&s>=0){
md=max(md,s);
h[s]=max(h[s],cnt);
add_data(ga[s],cnt);
}
for(int i=head[u];i;i=edge[i].nxt){
int v=edge[i].v;
if(vis[i]||v==fa)
continue ;
dfsa(v,u,s,mx,cnt);
}
}
void dfsb(int u,int fa,int s,int mn,int cnt){
s+=val[u];
if(mn>s)
mn=s,cnt=1;
else if(mn==s)
cnt++;
if(mn==s&&s<=0){
md=max(md,-s);
h[-s]=max(h[-s],cnt);
if(s==0)
add_data(gb[-s],cnt);//!
else
add_data(gb[-s],cnt-1);
}
for(int i=head[u];i;i=edge[i].nxt){
int v=edge[i].v;
if(vis[i]||v==fa)
continue ;
dfsb(v,u,s,mn,cnt);
}
}
void calc(int dir){
for(int i=0;i<=md;i++){
if(ga[i]&&gb[i]){
int m=h[i];
for(int j=0;j<=2*h[i];j++){
A[j]=B[j]=Complex(0.0,0.0);
//aa[j]=bb[j]=0;
}
for(int j=ga[i];j;j=data[j].nxt){
int v=data[j].v;
A[v].r=A[v].r+1;
//aa[v]++;
}
for(int j=gb[i];j;j=data[j].nxt){
int v=data[j].v;
B[v].r=B[v].r+1;
//bb[v]++;
}
Mul(2*m);
for(int j=0;j<=2*m;j++)
ans[j]+=dir*res[j];
}
ga[i]=gb[i]=h[i]=0;
}
}
void Solve(int u){
dcnt=0,md=0;
dfsb(u,0,0,INF,0);
for(int i=head[u];i;i=edge[i].nxt){
int v=edge[i].v;
if(vis[i])
continue ;
dfsa(v,u,0,-INF,0);
}
for(int i=gb[0];i;i=data[i].nxt)
ans[data[i].v]++;
calc(1);
for(int i=head[u];i;i=edge[i].nxt){
int v=edge[i].v;
if(vis[i])
continue ;
dcnt=0;
md=0;
dfsa(v,u,0,-INF,0);
dfsb(v,u,val[u],val[u],1);//!
calc(-1);
}
for(int i=head[u];i;i=edge[i].nxt){
int v=edge[i].v;
if(vis[i])
continue ;
vis[i]=vis[i^1]=true;
f[0]=totsz=sz[v];
root=0;
Findroot(v,u);
Solve(root);
}
}
int main()
{
scanf("%d",&n);
for(int i=1;i<n;i++){
int u,v;
scanf("%d%d",&u,&v);
add_edge(u,v);
add_edge(v,u);
}
for(int i=1;i<=n;i++){
char s[5];
scanf("%s",s);
val[i]=s[0]=='('?1:-1;
}
f[0]=totsz=n;
root=0;
Findroot(1,0);
Solve(root);
int m;
scanf("%d",&m);
for(int i=1;i<=m;i++){
int p;
scanf("%d",&p);
printf("%d\n",ans[p]);
}
}