hdu 7055 Yiwen with Sqc
本篇题解非正解(超时了),正解传送门
题意:
字符串的子串当中每个字母的出现个数(用 s [ ′ a ′ ] [ l , r ] s['a'] [l, r] s[′a′][l,r] 来表示),要求的就是每个字母的所有 s [ l , r ] s[l, r] s[l,r] 的平方和
分析:
要求所有的子串, C n 2 C_{n}^2 Cn2 ,暴力的话是 O ( n 2 ) O(n^2) O(n2) 的算法,优化 O ( n 2 ) O(n^2) O(n2) 的算法我最近刚学了一个 F F T FFT FFT (还热乎着呢…),那么此题能否用 F F T FFT FFT 优化呢?
众所周知, F F T FFT FFT 是优化多项式乘法的,考虑如何将这道题转换成多项式乘法的形式
这道题是让求任意区间的子串当中某一字母的个数的平方和,求区间和要用前缀和的形式,然后再 s [ r ] − s [ l − 1 ] s[r] - s[l-1] s[r]−s[l−1] ,便是区间 [ l , r ] [l, r] [l,r] 的贡献值
等会,我发现了什么!!! s [ r ] − s [ l − 1 ] s[r] - s[l-1] s[r]−s[l−1] ,这是啥,敲重点,这是差值欸
那么,差值又能干什么
学过FFT的都知道,差值(也就是减法)能转换成加法,然后加法又能转换成乘法,系数相乘就能用 F F T FFT FFT 优化了
那么这题就要去考虑什么作为多项式的次数,什么又作为多项式的系数,题目要计算的是每一差值出现的次数,差值是由两个不同的 s [ i ] s[i] s[i] 产生的,显然 s [ i ] s[i] s[i] 是作为次数的,系数便是 s [ i ] s[i] s[i] 的个数
剩下便是FFT的板子了,还是有几个坑点要注意的
Code:
(代码是正解,但是标程要 O ( n ) O(n) O(n) 的算法,可怜我改了两三个小时的代码 T_T )
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N = (1<<20)+5, mo=998244353;
const double PI=acos(-1);
struct Complex
{
double x, y;
Complex operator+(const Complex &o) const{return{x+o.x,y+o.y};}
Complex operator-(const Complex &o) const{return{x-o.x,y-o.y};}
Complex operator*(const Complex &o) const{return{x*o.x-y*o.y,x*o.y+y*o.x};}
}A[N],B[N];
int rev[N];
void init(int k)
{
int s=1<<k;
for(int i=1;i<s;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
}
void fft(Complex *a,int n,int inv)
{
for(int i=0;i<n;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int len=1;len<n;len<<=1)
{
Complex Wn=Complex({cos(PI/len),inv*sin(PI/len)});
for(int i=0;i<n;i+=len*2)
{
Complex w=Complex({1,0});
for(int j=0;j<len;j++,w=w*Wn)
{
Complex x=a[i+j],y=w*a[i+j+len];
a[i+j]=x+y,a[i+j+len]=x-y;
}
}
}
if(inv==-1) for(int i=0;i<n;i++) A[i].x = A[i].x/n+0.5; // 精度
}
// --------FFT
const int M=1e5+5;
char s[M];
int dp[33][M];
signed main()
{
int t;
scanf("%lld",&t);
while(t--)
{
scanf("%s",s+1);
int n=strlen(s+1);
//memset(dp,0,sizeof(dp));
for(int i=1;i<=n;i++)
{
for(int j=0;j<26;j++)
{
dp[j][i] = dp[j][i-1];
}
dp[s[i]-'a'][i]++;
}
int ans=0,sum=2,k=1;
while(sum <= n+n) k++,sum <<= 1;
init(k);
for(int num=0;num<26;num++)
{
//memset(A,0,sizeof(A));
//memset(B,0,sizeof(B));
// 无力的挣扎,试图省点时间
for(int i=0;i<sum;i++) A[i]=B[i]=Complex({0,0});
B[n].x=1;
// 这步de了一小时才想明白,傻了
for(int j=1;j<=n;j++)
{
A[dp[num][j]].x++;
B[n-dp[num][j]].x++;
}
if((int)A[0].x==n) continue; // 也是为了...
//if(B[n].x < 1) B[n].x=1;
//for(int i=0;i<=n;i++) cout<<A[i].x<<' '; cout<<endl;
//for(int i=0;i<=n;i++) cout<<B[i].x<<' '; cout<<endl;
fft(A,sum,1); fft(B,sum,1);
for(int i=0;i<sum;i++) A[i] = A[i]*B[i];
fft(A,sum,-1);
//for(int i=1;i<=n;i++) cout<<(int)A[i+n].x<<' '; cout<<endl;
for(int i=1;i<=n;i++) ans = (ans+(int)A[i+n].x*i*i)%mo;
//cout<<char('a'+num)<<' '<<ans<<endl;
}
printf("%lld\n",ans);
}
return 0;
}
B [ n ] . x = 1 B[n].x = 1 B[n].x=1 为啥要有这个初始化,因为所有的区间 [ 0 , i ] [0, i] [0,i] 都要考虑一下,不能忘了,就是这个 0 0 0 ,用了各种办法来处理都不行,比赛的时候思路太乱
能过样例的那一刻巨开心,结果…
这题加深了我对FFT问题的理解,主要是如何转换方面,多项式间的乘法,首先要先能转换到乘法或加法,然后就是考虑什么作为多项式次数,什么作为系数,还有就是边界问题的处理。
然后这题我还信了网上的鬼话, N T T NTT NTT 跑的比 F F T FFT FFT 快,下面是 N T T NTT NTT 的代码(其实速度上没多大区别)
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N = (1<<20)+5, mo=998244353;
int rev[N];
void init(int k)
{
int s=1<<k;
for(int i=1;i<s;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
}
int binpow(int a,int b)
{
int ans=1;
while(b)
{
if(b&1) ans = ans*a%mo;
a = a*a%mo;
b >>= 1;
}
return ans;
}
void ntt(int *a, int n, int inv)
{
for(int i = 0; i < n; i++)
if(i < rev[i]) swap(a[i], a[rev[i]]);
for(int mid = 1; mid < n; mid <<= 1) {
int Wn = binpow(3, (mo-1)/(mid<<1));
if(inv == -1) Wn = binpow(Wn, mo-2);
for(int j = 0; j < n; j += (mid << 1)) {
int w = 1;
for(int k = 0; k < mid; k++, w = (w * Wn) % mo) {
int x = a[j + k], y = w * a[j + k + mid] % mo;
a[j + k] = (x + y) % mo,
a[j + k + mid] = (x - y + mo) % mo;
}
}
}
if(inv == -1)
{
int fg=binpow(n, mo-2);
for(int i=0;i<n;i++) a[i] = a[i]*fg%mo;
}
}
const int M=1e5+5;
char s[M];
int dp[33][M], a[N], b[N];
signed main()
{
//freopen("1.in","r",stdin);
//freopen("1.out","w",stdout);
int t;
scanf("%lld",&t);
while(t--)
{
scanf("%s",s+1);
int n=strlen(s+1);
//memset(dp,0,sizeof(dp));
for(int i=1;i<=n;i++)
{
for(int j=0;j<26;j++)
{
dp[j][i] = dp[j][i-1];
}
dp[s[i]-'a'][i]++;
}
int ans=0,sum=2,k=1;
while(sum <= n+n) k++,sum <<= 1;
init(k);
for(int num=0;num<26;num++)
{
//memset(A,0,sizeof(A));
//memset(B,0,sizeof(B));
for(int i=0;i<sum;i++) a[i]=b[i]=0;
b[n] = 1;
for(int j=1;j<=n;j++)
{
a[dp[num][j]]++;
b[n-dp[num][j]]++;
}
if(a[0]==n) continue;
//if(B[n].x < 1) B[n].x=1;
//for(int i=0;i<=n;i++) cout<<A[i].x<<' '; cout<<endl;
//for(int i=0;i<=n;i++) cout<<B[i].x<<' '; cout<<endl;
ntt(a,sum,1); ntt(b,sum,1);
for(int i=0;i<sum;i++) a[i] = a[i]*b[i];
ntt(a,sum,-1);
//for(int i=1;i<=n;i++) cout<<(int)A[i+n].x<<' '; cout<<endl;
for(int i=1;i<=n;i++) ans = (ans+a[i+n]*i*i)%mo;
//cout<<char('a'+num)<<' '<<ans<<endl;
}
printf("%lld\n",ans);
}
return 0;
}