对于长度为x的平方串,只需要每隔x做一个关键点,然后对相邻关键点做lcp和lcs就可以找出每一个平方串。
sa或者hash+二分都是可以的
找出平方串的区间,下面要实现的操作就是实现区间中x和x+le合并。
用倍增维护一下并查集,开log个并查集,若x,y在第i个并查集里被并起来,意味着 $x..x+2^i-1$ 与 $y..y+2^i-1$ 是相等的。
每次并两个区间的时候,先拆成两个完整的2幂区间,然后合并一个2幂区间的时候,看他在当前并查集里是否并在一起,假如没有就先并起这层,然后分开两层,继续往下(有点像线段树那样)
只有在某一层成功合并了一次之后才会继续往下走两个子区间,而每一层的并查集至多成功合并 n-1 次,一共 log n 层,因此这个结构的复杂度是 O(n log n)(不计并查集本身的近似常数因子)。
#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
// N is the capacity of every fixed-size array in this file.
// W (1e9 + 7) is not referenced anywhere in the visible code.
const int N = 1e6 + 10, W = 1e9 + 7;
// NOTE: despite its name, `ll` is an alias for *unsigned* long long.
typedef unsigned long long ll;
// n  - length of the sequence of the current test case (read in main)
// T  - number of test cases still to be processed
// a  - the input sequence, 1-indexed
// fa - not referenced in the visible code (struct dsu declares its own fa)
int n, T, a[N], fa[N];
// lg[i] = floor(log2(i)); filled by a loop at the start of main.
// Entries that loop does not reach (including lg[0] and lg[1]) stay 0.
int lg[N];
/// Suffix array over a 1-indexed integer sequence of length `n` (the global),
/// together with the height array and a sparse table for O(1) LCP queries.
///
/// build() relies on the following, none of which it checks:
///  * every symbol satisfies 1 <= a[i] <= n: symbols are used directly as
///    indices into cnt[], which is only cleared and prefix-summed on 1..n,
///    and the value 0 acts as the "past the end" sentinel;
///  * a[0] == 0 and a[n+1 .. 2n] == 0, so character comparisons that run past
///    the end of a suffix stop immediately;
///  * reset() was called for the current n, so rk[] is 0 beyond n and
///    sa[n + 1] is 0.
struct SA{
    // sa[k]  - start position of the k-th smallest suffix
    // rk[p]  - rank of the suffix starting at p (inverse of sa)
    // h[k]   - LCP length of the suffixes ranked k and k + 1
    // cnt    - counting-sort buckets
    int sa[N], rk[N], h[N], cnt[N];
    // Scratch copies of sa / rk used while one doubling round is in progress.
    int nsa[N], nrk[N];
    // le[j][i] = min(h[j .. j + 2^i - 1]) (sparse table over h).
    int le[N][20];

    /// Zeroes sa, rk and h on 1..3n so that stale data from an earlier,
    /// larger test case cannot be read as a rank or a sentinel.
    void reset() {
        for(int i = 1; i <= n + n + n; i++) sa[i]=rk[i]=h[i]=0;
    }

    /// Length of the longest common prefix of the suffixes starting at
    /// positions a and b (both expected in 1..n).
    int lcp(int a,int b) {
        // A suffix shares its whole length with itself. Without this guard the
        // rank interval below would be empty and an unrelated table entry
        // would be returned.
        if (a == b) return n - a + 1;
        a = rk[a], b = rk[b]; if (a > b) swap(a, b);
        // h[k] links ranks k and k + 1, so the answer is min(h[a .. b - 1]).
        b--;
        int sz = lg[b - a + 1];
        return min(le[a][sz], le[b - (1 << sz) + 1][sz]);
    }

    /// Builds sa, rk, h and le for the sequence a[1..n].
    void build(int *a) {
        // Initial order: counting sort by the first symbol.
        for(int i = 1; i <= n; i++) cnt[i] = 0;
        for(int i = 1; i <= n; i++) rk[i] = a[i], cnt[a[i]] ++;
        for(int i = 1; i <= n; i++) cnt[i] += cnt[i - 1];
        for(int i = 1; i <= n; i++) sa[cnt[a[i]]--] = i;

        // Prefix doubling: rk currently orders suffixes by their first m
        // symbols; each round refines that to the first 2m symbols.
        for(int m = 1; m <= n; m <<= 1) {
            for(int i = 1; i <= n; i++) cnt[i] = 0;
            for(int i = 1; i <= n; i++) cnt[rk[i]]++;
            for(int i = 1; i <= n; i++) cnt[i] += cnt[i - 1];
            // Buckets are filled from the back. Walking the previous order
            // backwards visits second halves in descending order, so suffix
            // x = sa[i] - m lands behind every bucket-mate with a smaller
            // second half.
            for(int i = n; i; i--) {
                int x = sa[i] - m;
                if (x > 0) nsa[cnt[rk[x]] --] = x;
            }
            // Suffixes with an empty second half are placed last, i.e. at the
            // front of their bucket.
            for(int i = n - m + 1; i <= n; i++) nsa[cnt[rk[i]]--] = i;
            for(int i = 1; i <= n; i++) sa[i] = nsa[i];
            // Re-rank: a new rank starts whenever either half differs from the
            // predecessor. For i == 1 this reads sa[0], rk[0] and nrk[0],
            // which are never written and therefore 0.
            for(int i = 1; i <= n; i++) {
                nrk[sa[i]] = nrk[sa[i - 1]] + (rk[sa[i - 1]] != rk[sa[i]] || rk[sa[i - 1] + m] != rk[sa[i] + m]);
            }
            for(int i = 1; i <= n; i++) rk[i] = nrk[i];
        }

        // Height array, processed in text order so each value can start from
        // the previous one minus 1. For the last-ranked suffix nex is
        // sa[n + 1] == 0, and the a[0] == 0 sentinel makes the scan stop.
        for(int i = 1; i <= n; i++) {
            int nex = sa[rk[i] + 1];
            h[rk[i]] = max(0, h[rk[i - 1]] - 1);
            while (a[nex + h[rk[i]]] == a[i + h[rk[i]]])
                h[rk[i]]++;
        }

        // Sparse table over h. Cells whose window extends past n mix in data
        // from beyond n; lcp() never reads those cells for valid positions.
        for(int i = 1; i <= n; i++) {
            for(int j = 0; j < 20; j++) le[i][j] = -1e9;
        }
        for(int i = 1; i <= n; i++) le[i][0] = h[i];
        for(int i = 1; i < 20; i++) {
            for(int j = 1; j <= n; j++) {
                le[j][i] = min(le[j][i - 1], le[j + (1 << (i - 1))][i - 1]);
            }
        }
    }
} suf, pre;
// b - the sequence a reversed (main stores b[n - i + 1] = a[i]).
int b[N];
// w - n / 2, the number of per-length values read for each test case.
int w;
// px[i] = (value read for length i, i). main sorts this array by value and
// afterwards uses .second as the length to process.
pair<int,int> px[N];
struct dsu{
int fa[N];
int gf(int x) {
return fa[x] == 0 ? x : fa[x] = gf(fa[x]);
}
void merge(int a,int b) {fa[gf(a)] = gf(b);}
void reset() {
for(int i = 1; i <= n + n; i++) fa[i] = 0;
}
} d[20];
int qcnt;
void qmerge(int x,int y,int sz) {
if (d[sz].gf(x) != d[sz].gf(y)) {
d[sz].merge(x,y);
if (sz == 0) qcnt++;
else {
qmerge(x, y, sz - 1);
qmerge(x + (1 << sz - 1), y + (1 << sz - 1), sz - 1);
}
}
}
/// Pairs up the ranges [al, ar] and [bl, br] element by element.
/// The range is covered by two, possibly overlapping, blocks of length
/// 2^lg[ar - al + 1]: one aligned to the left ends and one to the right ends.
/// Each block is handed to qmerge. Both ranges are expected to have the same
/// length (br is only used to position the right-aligned block).
void qjmerge(int al,int ar,int bl,int br) {
    const int level = lg[ar - al + 1];
    const int span = 1 << level;
    qmerge(al, bl, level);
    qmerge(ar - span + 1, br - span + 1, level);
}
int main(){
freopen("endless.in","r",stdin);
freopen("endless.out","w",stdout);
for(int i = 2; i <= 3e5; i++) lg[i] = lg[i >> 1] + 1;
int z = 0;
for(cin>>T;T;T--){
if ((++z) == 2318) {
int kkk = T;
}
scanf("%d",&n), w = n / 2;
for(int i = 0; i < 20; i++) d[i].reset();
for(int i = 1; i <= n; i++) scanf("%d",&a[i]), b[n - i + 1] = a[i];
for(int i = n + 1; i <= n + n; i++) a[i] = b[i] = 0;
for(int i = 1; i <= w; i++) {
scanf("%d", &px[i].first);
px[i].second = i;
}
suf.reset(), pre.reset();
pre.build(a);
suf.build(b);
sort(px + 1, px + 1 + w);
long long ans = 0;
for(int e = 1; e <= w; e++) {
int len = px[e].second;
qcnt = 0;
for(int ka = 1; ka + len <= n; ka += len) {
int kb = ka + len;
int la = min(len, suf.lcp(n - ka + 1, n - kb + 1));
int lb = min(len, pre.lcp(ka, kb));
if (la + lb < len + 1) continue;
int zb = ka - la + 1, yb = kb + lb - 1;
yb -= len;
qjmerge(zb, yb, zb + len, yb + len);
}
ans += (ll)qcnt * px[e].first;
}
printf("%lld\n",ans);
}
}