Description
Solution
平方串如何找是经典套路,假设要找长度为
2
L
2L
2L的平方串,则将序列分成每
L
L
L一段,最后不足
L
L
L也成一段。
假设当前到了第
i
i
i段,
i
i
i与
i
−
1
i-1
i−1的最长公共后缀为红线部分,
i
i
i与
i
+
1
i+1
i+1的最长公共前缀为蓝线部分,那么
[
l
,
r
]
[l,r]
[l,r]中所有长度等于
2
L
2L
2L的连续子序列都是平方串。考虑这些平方串,最后形成的是
∀
i
∈
[
l
,
r
−
L
]
\forall i\in[l,r-L]
∀i∈[l,r−L],
i
i
i与
i
+
L
i+L
i+L都有连边。
考虑kruskal的过程,将权值从小到大加入平方串,开
l
o
g
log
log个并查集表示,第
i
i
i个并查集的
x
x
x与
y
y
y连通表示
[
x
,
x
+
2
i
)
[x,x+2^i)
[x,x+2i)与
[
y
,
y
+
2
i
)
[y,y+2^i)
[y,y+2i)对应有连边。上述区间可以拆成两个RMQ区间的合并。
合并时,如果当前
i
i
i号并查集已经连通,就返回,否则将它们连通,递归处理
i
−
1
i-1
i−1号并查集,当到
0
0
0号并查集且没有联通时,答案就加上边权。
Code
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#define fo(i,j,k) for(int i=j,o=k;i<=o;++i)
#define fd(i,j,k) for(int i=j;i>=k;--i)
using namespace std;
typedef long long ll;
const int N=3e5+10;
const int mo=1e9+7;
int a[N],n;
ll z[N],nz[N];
ll s[N];
ll get(int l,int r){
return (s[r]-s[l-1]+mo)*nz[l-1]%mo;
}
int qpow(int x,int y){
int s=1;
for(;y;y>>=1,x=(ll)x*x%mo) if(y&1) s=(ll)s*x%mo;
return s;
}
struct node{
int x,y;
node(){}
node(int _x,int _y) {x=_x,y=_y;}
}w[N],b[N];
int tot=0;
bool cmp(node x,node y){
return x.x<y.x;
}
int lcp(node x,node y){
if(a[x.x]!=a[y.x]) return 0;
int l=1,r=min(x.y-x.x+1,y.y-y.x+1);
for(;l+1<r;){
int mid=(l+r)>>1;
get(x.x,x.x+mid-1)!=get(y.x,y.x+mid-1)?r=mid:l=mid;
}
if(get(x.x,x.x+r-1)==get(y.x,y.x+r-1)) l=r;
return l;
}
int lcs(node x,node y){
if(a[x.y]!=a[y.y]) return 0;
int l=1,r=min(x.y-x.x+1,y.y-y.x+1);
for(;l+1<r;){
int mid=(l+r)>>1;
get(x.y-mid+1,x.y)!=get(y.y-mid+1,y.y)?r=mid:l=mid;
}
if(get(x.y-r+1,x.y)==get(y.y-r+1,y.y)) l=r;
return l;
}
int f[19][N];
int find(int p,int x){
return x==f[p][x]?x:f[p][x]=find(p,f[p][x]);
}
int lg[N];
int L,val;
ll ans;
void merge(int x,int y,int t){
int u=find(t,x),v=find(t,y);
if(u==v) return;
f[t][v]=u;
if(!t) {ans+=val;return;}
merge(x,y,t-1),merge(x+(1<<(t-1)),y+(1<<(t-1)),t-1);
}
void ins(int l,int r){
if(r-l+1<2*L) return;
int t=lg[r-L-l+1];
merge(l,l+L,t),merge(r-L-(1<<t)+1,r-(1<<t)+1,t);
}
void solve(){
if(L==3){
int gg=1;
++gg;
}
tot=0;
fo(i,1,n/L)
b[++tot]=node((i-1)*L+1,i*L);
if(n%L) b[++tot]=node(n/L*L+1,n);
fo(i,2,tot){
int l=b[i].x-lcs(b[i],b[i-1]),r=b[i].y+(i<tot?lcp(b[i],b[i+1]):0);
ins(l,r);
}
}
int main()
{
freopen("endless.in","r",stdin);
freopen("endless.out","w",stdout);
nz[0]=z[0]=1;
fo(i,1,N-1) z[i]=z[i-1]*107%mo;
nz[1]=qpow(107,mo-2);
fo(i,2,N-1) nz[i]=nz[i-1]*nz[1]%mo;
fo(i,2,N-1) lg[i]=lg[i>>1]+1;
int T;
scanf("%d",&T);
for(;T--;){
scanf("%d",&n);
fo(i,1,n) scanf("%d",&a[i]),s[i]=(s[i-1]+a[i]*z[i])%mo;
fo(i,1,n/2) scanf("%d",&w[i].x),w[i].y=i;
fo(i,1,n)
fo(j,0,lg[n]) f[j][i]=i;
ans=0;
sort(w+1,w+n/2+1,cmp);
fo(i,1,n/2){
L=w[i].y,val=w[i].x;
solve();
}
printf("%lld\n",ans);
}
}