题意
给你 N N N个大写字母组成的单词和一串大写非元音字母所组成的字符串( L ≤ 3 × 1 0 5 L\leq 3\times 10^5 L≤3×105)。已知字符串 L L L是由给出单词的非元音字母组成的,求将这个字符串还原回正常单词以后,元音字母数量最多的一个原串。数据保证 ∑ i = 0 n l e n ( s i ) ≤ 1 0 5 \sum_{i=0}^{n} len(s_i) \leq 10^5 ∑i=0nlen(si)≤105。
解题思路
由于模式串
L
L
L和匹配串的数据很大,所以普通匹配肯定是不行的。可以考虑
A
C
AC
AC自动机或者Hash,但是我
A
C
AC
AC自动机写挂了所以改用
H
a
s
h
Hash
Hash了。同时匹配结束以后我们还需要用
D
P
DP
DP来获取元音字母最多的还原串。
假设我们已经获取每个位置上的匹配串,假设位置
i
i
i的所有匹配串为
m
i
m_i
mi,每个匹配串的元音字母数量为
v
j
v_j
vj,匹配长度为
l
j
l_j
lj。那么可得
D
P
DP
DP公式:
d
p
i
=
max
j
∈
m
i
{
d
p
i
−
l
j
+
v
j
+
v
j
}
dp_i=\max_{j \in m_i}\{dp_{i-l_j+v_j}+v_j\}
dpi=j∈mimax{dpi−lj+vj+vj}
那么
d
p
L
dp_L
dpL就是最多的元音数量。因为题目要输出这个还原串,所以我们在
D
P
DP
DP的时候还要记录每一个状态是由哪一个匹配串转移而来的,这样就可以从后往前构造原串。
如果我们考虑这个
D
P
DP
DP的时间复杂度,会发现它很有可能是
O
(
N
L
)
O(NL)
O(NL)的,显然会超时。但是仔细一想我们会发现,对于每一个位置,长度为
x
x
x的匹配串有且仅能有一个,那么我们其实只需要枚举长度就好了。对于一个总长度不超过
N
N
N的单词集合,长度的种类最多只有
O
(
N
)
O(\sqrt{N})
O(N)种。因为
∑
i
=
1
k
i
=
N
\sum_{i=1}^k{i}=N
∑i=1ki=N可得
k
(
k
+
1
)
2
=
N
\frac{k(k+1)}{2}=N
2k(k+1)=N,解得
k
=
2
N
+
0.25
−
0.5
k=\sqrt{2N+0.25}-0.5
k=2N+0.25−0.5。
如果每个位置只有最多
O
(
N
)
O(\sqrt{N})
O(N)个匹配串,那么总复杂度
O
(
L
N
)
O(L\sqrt{N})
O(LN)就可以接受了。
对于维护
N
\sqrt{N}
N个长度的
H
a
s
h
Hash
Hash我们也可以动态的
O
(
L
)
O(L)
O(L)时间来实现,只要在超出长度的时候
h
a
s
h
−
p
o
w
[
l
e
n
]
hash - pow[len]
hash−pow[len]就可以了。
时间复杂度
O ( L N ) O(L\sqrt{N}) O(LN)
代码
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <stack>
#include <vector>
typedef long long ll;
typedef unsigned long long ull;
using namespace std;
const int INF = 2147483647;
const int INF2 = 0x3f3f3f3f;
const ll INF64 = 1e18;
const double INFD = 1e30;
const double EPS = 1e-15;
const double PI = 3.14159265;
const ll MOD = 1e9 + 7;
template <typename T>
inline T read() {
T X = 0, w = 0;
char ch = 0;
while (!isdigit(ch)) {
w |= ch == '-';
ch = getchar();
}
while (isdigit(ch)) X = (X << 3) + (X << 1) + (ch ^ 48), ch = getchar();
return w ? -X : X;
}
const int MAXN = 300005;
int n, m, k;
int CASE;
// 单词字符串信息
struct Node2 {
string str;
ull hash1;
ll hash2;
int cnt;
};
// 长度离散化结构
struct Discretization {
vector<int> xp;
vector<int>::iterator xend;
int size() const { return xend - xp.begin(); }
void add(int val) { xp.push_back(val); }
void discretize() {
sort(xp.begin(), xp.end());
xend = unique(xp.begin(), xp.end());
}
int get(int val) {
return lower_bound(xp.begin(), xend, val) - xp.begin() + 1;
}
int get2(int num) { return xp[num - 1]; }
};
Discretization dis;
char text[MAXN];
// 字符串hash匹配表
vector<int> hashList[MAXN * 2];
// 原单词信息
Node2 strList[MAXN];
// 每个位置匹配的串
vector<int> matches[MAXN];
int dp[MAXN];
int opt[MAXN];
string removeVowel(const string& str, int& cnt) {
string res;
cnt = 0;
for (auto c : str) {
if (c == 'A' || c == 'E' || c == 'I' || c == 'O' || c == 'U') {
cnt++;
} else {
res.push_back(c);
}
}
return res;
}
// 我用了双hash,一个自然溢出一个模数
ull powh[MAXN];
ll powh2[MAXN];
struct Hasher {
ull hashB1;
ll hashB2;
int cnt1;
void init() {
powh[0] = powh2[0] = 1;
for (int i = 1; i < MAXN; i++) {
powh[i] = powh[i - 1] * 277LL;
powh2[i] = powh2[i - 1] * 1999LL % MOD;
}
}
Hasher() {}
void clear() {
hashB1 = 0;
hashB2 = 0;
}
void removel(char c, int len) {
hashB1 -= powh[len] * (ull)c;
hashB2 -= (powh2[len] * c) % MOD;
hashB2 %= MOD;
if (hashB2 < 0) hashB2 += MOD;
}
void insertr(char c) {
hashB1 *= powh[1];
hashB1 += c;
hashB2 = (hashB2 * powh2[1]) % MOD;
hashB2 += c;
hashB2 %= MOD;
}
};
Hasher hasher;
Hasher dynHash[505];
map<string, int> sss;
int main() {
#ifdef LOCALLL
freopen("in", "r", stdin);
freopen("out", "w", stdout);
#endif
scanf("%d", &n);
hasher.init();
int tot = 1;
for (int i = 1; i <= n; i++) {
scanf("%s", text);
string s(text);
int cnt;
string rem = removeVowel(s, cnt);
if (!sss.count(rem) || cnt > sss[rem]) {
sss[rem] = cnt;
hasher.clear();
for (auto a : rem) {
hasher.insertr(a);
}
strList[tot] = {s, hasher.hashB1, hasher.hashB2, cnt};
hashList[hasher.hashB1 % (MAXN * 2)].push_back(tot);
tot++;
dis.add(rem.size());
}
}
dis.discretize();
scanf("%s", text + 1);
n = strlen(text + 1);
for (int i = 1; i <= n; i++) {
dp[i] = -1;
int sz = dis.size();
for (int j = sz; j >= 1; j--) {
// 动态维护每个长度的hash
dynHash[j].insertr(text[i]);
int len = dis.get2(j);
if (i > len) {
dynHash[j].removel(text[i - len], len);
}
if (i >= len) {
int loc = dynHash[j].hashB1 % (MAXN * 2);
// hash匹配
for (auto v : hashList[loc]) {
Node2& nd = strList[v];
if (nd.hash2 == dynHash[j].hashB2) {
matches[i].push_back(v);
}
}
}
}
}
dp[0] = 0;
for (int i = 1; i <= n; i++) {
for (auto a : matches[i]) {
Node2& nd = strList[a];
int p = i - nd.str.size() + nd.cnt;
if (dp[p] >= 0) {
int nx = dp[p] + nd.cnt;
if (dp[i] < nx) {
dp[i] = nx;
opt[i] = a;
}
}
}
}
vector<int> res;
int k = n;
// 构造原串
while (k > 0) {
res.push_back(opt[k]);
int d = strList[opt[k]].str.size() - strList[opt[k]].cnt;
k -= d;
}
for (int i = res.size() - 1; i >= 0; i--) {
printf("%s", strList[res[i]].str.c_str());
if (i != 0) {
printf(" ");
}
}
printf("\n");
return 0;
}