题目描述:
你有一个长度为nnn的字符串,其中仅含'0','1','2'三个字符。你希望知道,这个字符串有多少个子串,满足该子串的'0','1','2'个数相等。
算法:哈希+前缀和算法。
'0'--->1, '1'--->3000010(要取大一点!防止冲突), '2'---> -(1+3000010);
使得'0'+'1'+'2'==0,可以说考虑的是三者相对关系,这样会比较好统计。
#include<bits/stdc++.h>
using namespace std;
const int N=3000010;
typedef long long LL;
typedef pair<int,int>PII;
typedef pair<PII,int>PIII;
#define ios ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
#define endl '\n'
int cnt[5];
int main()
{
ios;
int T;
cin>>T;
while(T--)
{
int n;
string str;
cin>>n>>str;
unordered_map<LL,int>s;
LL sum=0,res=0;
s[0]=1;
int a=1,b=N,c=-(a+b);
for(int i=0;i<str.size();i++)
{
if(str[i]=='0') sum+=a;
else if(str[i]=='1') sum+=b;
else sum+=c;
res+=s[sum];
s[sum]++;
}
cout<<res<<endl;
}
return 0;
}
比赛时我的构造是:'0'--->1, '1'--->N '2'--->N*N
这样其实考虑的是三个字符各自的个数,于是我取了个余hh,没想到过了,相当于减去k个('0'+'1'+'2')的值
主要代码如下:
LL a=1,b=N,c=(LL)N*N,d=a+b+c;
for(int i=0;i<str.size();i++)
{
if(str[i]=='0') sum+=a;
else if(str[i]=='1') sum+=b;
else sum+=c;
LL k=sum%d;
res+=s[k];
s[k]++;
}
还看到了其他的处理方法,map里套pair,pair中两个值记录了cnt['0']-cnt['2'],cnt['1']-cnt['2'],就是固定’2‘,考虑其他字符相对'2'的个数。