一个非常乱搞的题。
看一看题,发现合法的条件是两个很长的序列完全相等,看起来就很能乱搞了。我们考虑将序列看做一个多项式(允许有负次项),两个序列相等当且仅当它们对应的多项式相等,每次
+
+
+和
−
-
−操作就是给多项式加上或减去一个
x
k
x^k
xk。
设第
i
i
i次操作后的多项式为
P
i
P_i
Pi,那么枚举左端点
l
l
l,一个右端点合法的条件是
(
P
r
−
P
l
−
1
)
∗
x
−
d
l
−
1
=
P
n
(P_r-P_{l-1})*x^{-d_{l-1}}=P_n
(Pr−Pl−1)∗x−dl−1=Pn,其中
d
i
d_i
di是做完前
i
i
i个操作后指针的偏移量,移个项变成
P
n
∗
x
d
l
−
1
+
P
l
−
1
=
P
r
P_n*x^{d_{l-1}}+P^{l-1}=P_r
Pn∗xdl−1+Pl−1=Pr。
这里我们考虑对多项式随机代入若干个
x
x
x和模数后求值,这样可以把多项式映射到一个整数序列上,我们认为两个多项式相等当且仅当它们对应的证书序列相等。这样容易快速实现两个多项式相加,左移,右移,加上或减去一个
x
k
x^k
xk的操作。
那么只需要查询整数序列出现多少次即可,我拿了个map实现,时间复杂度
O
(
n
k
log
n
)
\mathcal O(nk\log n)
O(nklogn),
k
k
k是取的
x
x
x的个数,取
k
=
10
k=10
k=10就有非常高的正确率了。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
ll pow_mod(ll x,int k,ll MOD) {
ll ans=1;
while (k) {
if (k&1) ans=ans*x%MOD;
x=x*x%MOD;
k>>=1;
}
return ans;
}
const int prime[10]={998244353,989390509,807943519,934784771,814294163,847763881,988501229,831312439,887199301,993514421};
const ll H[10]={19260817,760277961,602494845,909561729,536778743,682353612,738823132,631427472,848965391,984035289};
int powd[10][1000005],invd[10][1000005];
void pre(int n) {
for(int i=0;i<10;i++) {
ll inv=pow_mod(H[i],prime[i]-2,prime[i]);
assert(inv*H[i]%prime[i]==1);
powd[i][0]=1;
for(int j=1;j<=n;j++) powd[i][j]=(ll)powd[i][j-1]*H[i]%prime[i];
invd[i][0]=1;
for(int j=1;j<=n;j++) invd[i][j]=(ll)invd[i][j-1]*inv%prime[i];
}
}
struct Data {
int num[10];
Data() {memset(num,0,sizeof(num));}
Data operator + (Data b) {
Data c;
for(int i=0;i<10;i++) c.num[i]=(num[i]+b.num[i])%prime[i];
return c;
}
void add(int x) {
for(int i=0;i<10;i++) num[i]=(num[i]+powd[i][x])%prime[i];
}
void dec(int x) {
for(int i=0;i<10;i++) num[i]=(num[i]-powd[i][x]+prime[i])%prime[i];
}
void rshift(int x) {
for(int i=0;i<10;i++) num[i]=(ll)num[i]*powd[i][x]%prime[i];
}
void lshift(int x) {
for(int i=0;i<10;i++) num[i]=(ll)num[i]*invd[i][x]%prime[i];
}
bool operator < (const Data & b) const {
for(int i=0;i<10;i++)
if (num[i]!=b.num[i]) return num[i]<b.num[i];
return 0;
}
};
map <Data,int> mp;
Data p[250005];
char str[250005];
int main() {
int n;
scanf("%d%s",&n,str+1);
pre(4*n);
int d=0;
for(int i=1;i<=n;i++) {
p[i]=p[i-1];
if (str[i]=='+') p[i].add(2*n+d);
else if (str[i]=='-') p[i].dec(2*n+d);
else if (str[i]=='>') d++;
else d--;
mp[p[i]]++;
}
ll ans=0;
d=0;
for(int i=1;i<=n;i++) {
Data t=p[n];
if (d>=0) t.rshift(d); else t.lshift(-d);
t=t+p[i-1];
if (mp.count(t)) ans+=mp[t];
if (str[i]=='>') d++;
else if (str[i]=='<') d--;
mp[p[i]]--;
}
printf("%lld\n",ans);
return 0;
}