题目链接:
https://acm.hdu.edu.cn/showproblem.php?pid=7055
题目大意:
给一个字符串,对于每个小写字母ch,问所有区间内ch出现次数的平方和,其中|s|<
因为答案很大,所以对998244353取模(答案很大,忍一下
解题思路:
朴素算法,像公式那样,遍历每个区间,相加。但是因为长度是1e5肯定会TLE
对于这种区间问题,我们往往采用差分或者前缀和
这题赛场上就想到了用前缀和
对于一个字母ch,我们定义 s[i]为到i为止,ch出现的个数
那么每个区间[,r]的答案就是
我们将平方和拆开化简
就会得到
然后时间复杂度就会变成O(26n)
(可以统计字母出现的种类tot,时间复杂度优化到O(tot*n))赛场上卡常过了,赛后T了
代码如下:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e5 + 3;
const int MOD = 998244353;
ll a[27][N];
int b[27], cc[27];
int main() {
int T;
scanf ("%d", &T);
while (T--) {
string s;
cin >> s ;
int len = s.length();
memset (a, 0, sizeof (a) );
memset (b, 0, sizeof (b) );
for (int i = 0; i < len; i++) {
b[s[i] - 'a' + 1] = 1;
a[s[i] - 'a' + 1][i]++;
}
int tot = 0;
for (int i = 1; i <= 26; i++)
if (b[i])
cc[++tot] = i;
ll ans = 0;
ll tt = 0, t = 0;
for (int k = 1; k <= tot; k++) {
tt = 0, t = 0;
int c = cc[k];
for (int i = 0; i < len; i++) {
a[c][i] = a[c][i - 1] + a[c][i];
t = (ll) (t + a[c][i] * a[c][i] % MOD) % MOD;
tt =
(ll) (tt + a[c][i]) % MOD;
}
ans += ( (len + 1) * t % MOD - tt * tt % MOD + MOD) % MOD;
ans %= MOD;
}
printf ("%lld\n", ans);
}
return 0;
}
我们可以想一想,怎么样改进
我们的思路是求前缀和的和,和前缀和平方和。
我们可以发现一段前缀和的值,在一个区间内是一样的,并且这个区间长度是可以知道的。而前缀和的平方和其实就是在区间内当前字母的个数*当前这一段的长度。
对于字a在字符串ababa举例:
i | 0 | 1 | 2 | 3 | 4 | 5 | 6 |
字母 | \ | a | b | a | b | a | \ |
维持长度 | 2 | 2 | 1 | ||||
s[i] | 0 | 1 | 1 | 2 | 2 | 3 |
对于前缀和的和
我们可以用 1 * 2 + 2 * 2 + 3 * 1 = 9 算
对于前缀和的平方和
我们可以用 来算
这上面的 2 2 1 其实是这一个ch对前缀和影响的长度
用a[i]来保存序列中第i个ch出现的位置
用len[i]表示第i个ch维持的长度
令区间内出现ch的次数为cnt
len[cnt]=n+1-a[cnt]
因此我们的答案就变成了
这样跑的就非常的快
代码实现:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e5 + 3;
const int MOD = 998244353;
char s[N];
void work() {
scanf ("%s", s);
int len = strlen (s);
vector<int>a[27];
for (int i = 1; i <= 26; i++) a[i].push_back (0);
for (int i = 0; i < len; i++) {
a[s[i] - 'a' + 1].push_back (i + 1);
}
ll ans = 0;
for (int i = 1; i <= 26; i++) {
a[i].push_back (len + 1);
ll t1 = 0, t2 = 0;
for (int j = 1; j + 1 < a[i].size(); j++) {
t1 += (ll)j * (a[i][j + 1] - a[i][j]) % MOD;
t2 += (ll)j * j % MOD * (a[i][j + 1] - a[i][j]) % MOD;
t1 %= MOD;
t2 %= MOD;
}
ans += ( (ll) (len + 1) * t2 % MOD - t1 * t1 % MOD + MOD ) % MOD;
ans %= MOD;
}
printf ("%lld\n", ans);
}
int main() {
int T;
scanf ("%d", &T);
while (T--)
work();
return 0;
}