考场上写的是一个常数较大的 O ( n ln n ) O(n \ln n) O(nlnn),但是犯了一些让自己惊讶的致命错误,挂了一堆分。以下是考场代码,并注释错误。
#include <bits/stdc++.h>
using namespace std;
#define int long long //1. 这个地方 4 倍常数,你不知道要卡时间吗
const int N = 1e6 + 5, Base = 27, p = 1e9 + 7; // 2. 你这个数组大小 2^20 绝对 RE 了 啊
int n;
unsigned long long hs[N], w[N], lty[N];
int sum[N][27]; //a - z
//int geths(int l, int r) {
// return (hs[r] - hs[l - 1]) * ksm(w[l], p - 2) % p;
//}
string s;
int orz[N], num[N], cnt[27], ac[N][27];
int ksm(int a, int b) {
int ans = 1;
while (b) {
if (b & 1) ans = (ans * a) % p;
a = (a * a) % p; b >>= 1;
}
return ans % p;
}
unsigned long long geths(int l, int r) {
if (l == 0) return hs[r];
return (hs[r] - hs[l - 1] + p) % p * ksm(w[l], p - 2) % p; // 3. 你为啥要在这里用 ksm 带 log 1e9 的常数?你白预处理了?
}
signed main() {
freopen("string.in", "r", stdin);
freopen("string.out", "w", stdout);
int Q; scanf("%lld", &Q);
w[0] = 1;
for (int i = 1; i <= 1e6; i++) w[i] = w[i - 1] * 23 % p, lty[i] = ksm(w[i], p - 2); // 4. 你这个 1e6 log 1e9 的预处理怎么想的?
//5. 预处理 O(n) 都不写,连指数的基本运算都不会了?
while (Q--) {
memset(sum, 0, sizeof(sum)); memset(orz, 0, sizeof(orz));
memset(cnt, 0, sizeof(cnt)); memset(num, 0, sizeof(num));
memset(ac, 0, sizeof(ac)); //6. 这样预处理是不是用些冗余?
cin >> s;
int len = s.size();
sum[0][s[0] - 'a']++;
for (int i = 1; i < len; i++) {
sum[i][s[i] - 'a']++;
for (int j = 0; j < 26; j++) sum[i][j] += sum[i - 1][j];
}
hs[0] = s[0] - 'a';
for (int i = 1; i < len; i++) hs[i] = (hs[i - 1] + w[i] * (s[i] - 'a') % p) % p;
for (int i = 0; i < len; i++) {
for (int j = 0; j < 26; j++)
if ((sum[len - 1][j] - sum[i][j]) & 1) orz[i]++;
}
int lyf = 0;
for (int i = 0; i < len; i++) {
cnt[s[i] - 'a']++;
if (cnt[s[i] - 'a'] & 1) lyf++;
else lyf--;
num[i] = lyf;
//cout << lyf << endl;
}
for (int j = 0; j <= 26; j++) {
ac[0][j] = (num[0] <= j); //7. 就这行代码你调那么长时间?数组下标从 1 开始不香吗?
for (int i = 1; i < len; i++)
ac[i][j] = ac[i - 1][j] + (num[i] <= j);
}
int ans = 0;
for (int i = 1; i < len; i++) {
unsigned long long t = geths(0, i);
for (int j = 0; j + i < s.size() - 1 /*No Error*/; j += i + 1) {
if (geths(j, j + i) == t) {
ans += ac[i - 1][orz[j + i]];
}
else break;
}
}
printf("%lld\n", ans);
}
return 0;
}
/*
3
nnrnnr
zzzaab
mmlmmlo
5
kkkkkkkkkkkkkkkkkkkk
lllllllllllllrrlllrr
cccccccccccccxcxxxcc
ccccccccccccccaababa
ggggggggggggggbaabab
*/
/*#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N = 1e6 + 5, Base = 23, p = 1e9 + 7;
int n, hs[N], w[N], lty[N];
int sum[N][26]; //a - z
int ksm(int a, int b) {
int ans = 1;
while (b) {
if (b & 1) ans = (ans * a) % p;
a = (a * a) % p; b >>= 1;
}
return ans % p;
}
int geths(int l, int r) {
return (hs[r] - hs[l]) * ksm(w[l], p - 2) % p;
}
string s;
int orz[N];
signed main() {
cin >> s; int len = s.size();
w[0] = 1;
while (1){
int l, r, x, y; cin >> l >> r >> x >> y;
cout << geths(l, r) << " " << geths(x, y) << endl;
}
return 0;
}
/*
3
nnrnnr
zzzaab
mmlmmlo
5
kkkkkkkkkkkkkkkkkkkk
lllllllllllllrrlllrr
cccccccccccccxcxxxcc
ccccccccccccccaababa
ggggggggggggggbaabab
*/
改了之后只有 84 的代码
#include <bits/stdc++.h>
using namespace std;
#define int short
const int N = 1048577, p = 1e9 + 7;
int n;
long long hs[N], w[N], lty[N];
int sum[N][27];
char s[N];
int orz[N], num[N], cnt[27], ac[N][27];
long long ksm(long long a, long long b) {
long long ans = 1;
while (b) {
if (b & 1) ans = (ans * a) % p;
a = (a * a) % p; b >>= 1;
}
return ans % p;
}
long long geths(int l, int r) {
if (l == 0) return hs[r];
return (hs[r] - hs[l - 1] + p) % p * lty[l] % p;
}
signed main() {
int Q; scanf("%d", &Q);
w[0] = 1; lty[0] = 1;
long long wrnm = ksm(23, p - 2);
for (int i = 1; i <= 1048576; i++) w[i] = w[i - 1] * 23 % p, lty[i] = lty[i - 1] * wrnm % p;
while (Q--) {
memset(sum, 0, sizeof(sum));
memset(orz, 0, sizeof(orz));
memset(cnt, 0, sizeof(cnt));
memset(num, 0, sizeof(num));
memset(ac, 0, sizeof(ac));
scanf("%s", s);
int len = strlen(s);
sum[0][s[0] - 'a']++;
for (int i = 1; i < len; i++) {
sum[i][s[i] - 'a']++;
for (int j = 0; j < 26; j++) sum[i][j] += sum[i - 1][j];
}
hs[0] = s[0] - 'a';
for (int i = 1; i < len; i++) hs[i] = (hs[i - 1] + w[i] * (s[i] - 'a') % p) % p;
for (int i = 0; i < len; i++) {
for (int j = 0; j < 26; j++)
if ((sum[len - 1][j] - sum[i][j]) & 1) orz[i]++;
}
int lyf = 0;
for (int i = 0; i < len; i++) {
cnt[s[i] - 'a']++;
if (cnt[s[i] - 'a'] & 1) lyf++;
else lyf--;
num[i] = lyf;
}
for (int j = 0; j <= 26; j++) {
ac[0][j] = (num[0] <= j);
for (int i = 1; i < len; i++)
ac[i][j] = ac[i - 1][j] + (num[i] <= j);
}
long long ans = 0;
for (int i = 1; i < len; i++) {
long long t = geths(0, i);
for (int j = 0; j + i < len - 1; j += i + 1) {
if (geths(j, j + i) == t) {
ans += ac[i - 1][orz[j + i]];
}
else break;
}
}
printf("%lld\n", ans);
}
return 0;
}