题意:有n个单词,每个单词长度为k,顺时针将它们写成一个圆圈串。现在知道g个长度为k的单词,是否可以从这g个单词中选择n个形成这个圆圈串?如果有多个答案,任意输出一个。
思路
可以发现,如果枚举第一个串的开始位置,由于输入的g个串,长度都为k,那么每个串的位置就固定了。那么我只要知道,在主串上那一段位置的字符串,是否存在在g个串中,然后如果每个都存在,那么就是符合的。这一部分可以用字符串hash做到O(1)判断。
复杂度的话,枚举第一个串的复杂度是k,一共需要匹配n个字符串,所以总复杂度是O(nk)
要注意不能用自然溢出写,因为可以造出数据卡掉自然溢出,这里我是用双哈希写的。
字符串Hash
hash[i] = (hash[i-1]*p + idx(i)) % mod;
字符串子串Hash
hash[l,r] = (hash[r] - hash[l-1]*p^(r-l+1) + mod) % mod;
在这里,p和mod是两个质数。
AC代码
#include <cstdio>
#include <cmath>
#include <cctype>
#include <bitset>
#include <algorithm>
#include <cstring>
#include <utility>
#include <string>
#include <iostream>
#include <map>
#include <set>
#include <vector>
#include <queue>
#include <stack>
using namespace std;
#pragma comment(linker, "/STACK:1024000000,1024000000")
#define eps 1e-10
#define inf 0x3f3f3f3f
#define pii pair<int, int>
typedef long long LL;
const double Pi = acos(-1.0);
const int maxn = 2e6 + 5;
const int seed = 1007;
const int mod[2] = {1000000007, 1000000009};
char s[maxn];
int Hash[2][maxn], ans[maxn];
int f[2][maxn], vis[maxn], tim;
map<pii, int>arr;
// hash[i] = (hash[i-1]*p + idx(i)) % mod;
// hash[l,r] = (hash[r] - hash[l-1]*p^(r-l+1)) % mod;
void init() {
for(int i = 0; i < 2; ++i) {
f[i][0] = 1;
for(int j = 1; j < maxn; ++j) {
f[i][j] = (LL)f[i][j-1] * seed % mod[i];
}
}
}
pii getHash(char *s) {
int res[2] = {0, 0};
int n = strlen(s);
for(int i = 0; i < 2; ++i) {
for(int j = 0; j < n; ++j) {
res[i] = ((LL)res[i] * seed + s[j]) % mod[i];
}
}
return pii(res[0], res[1]);
}
pii getHash(int l, int r) {
int res[2];
for(int i = 0; i < 2; ++i) {
res[i] = (Hash[i][r] - (LL)Hash[i][l-1] * f[i][r-l+1] % mod[i] + mod[i]) % mod[i];
}
return pii(res[0], res[1]);
}
void getPre(char *s, int k) {
int n = strlen(s+1);
for(int i = 0; i < 2; ++i) {
Hash[i][0] = 0;
for(int j = 1; j <= n; ++j) {
Hash[i][j] = ((LL)Hash[i][j-1] * seed + s[j]) % mod[i];
}
for(int j = n+1, t = 1; t < k ; ++t, ++j)
Hash[i][j] = ((LL)Hash[i][j-1] * seed + s[t]) % mod[i];
}
}
bool check(int st, int n, int k) {
tim++;
int l = st, r = st + k - 1;
for(int i = 1; i <= n; i++, l += k, r += k) {
pii p = getHash(l, r);
if(!arr.count(p)) return 0;
int val = arr[p];
if(vis[val] == tim) return 0; //判重
vis[val] = tim;
ans[i] = val;
}
printf("YES\n");
for(int i = 1; i <= n; ++i) {
printf("%d%c", ans[i], i == n ? '\n' : ' ');
}
return 1;
}
int main() {
init();
int n, k;
while(scanf("%d%d", &n, &k) == 2) {
tim = 1;
arr.clear();
scanf("%s", s+1);
getPre(s, k);
int g;
scanf("%d", &g);
for(int i = 1; i <= g; ++i) {
scanf("%s", s);
pii p = getHash(s);
arr[p] = i;
}
bool ok = 0;
for(int i = 1; i <= k; ++i) {
if(check(i, n, k)) {
ok = 1;
break;
}
}
if(!ok) printf("NO\n");
}
return 0;
}
如有不当之处欢迎指出!