题目大意:
给
k
k
k个字符串
t
1
,
t
2
,
.
.
.
t
k
t_1,t_2,...t_k
t1,t2,...tk,
t
i
t_i
ti有权值
c
i
c_i
ci.令
F
(
T
,
t
)
F(T,t)
F(T,t)表示字符串
T
T
T中包含多少个
t
t
t,
G
(
T
)
=
∑
i
=
1
k
F
(
T
,
t
i
)
∗
c
i
G(T)=\sum_{i=1}^kF(T,t_i)*c_i
G(T)=∑i=1kF(T,ti)∗ci。
现在给出一个字符串
S
S
S,
S
S
S中有最多14个位置是未知的,你可以在这些位置上填互不相同的字母
a
−
n
a-n
a−n,求
G
(
S
)
G(S)
G(S)最大可以是多少。
∑
∣
t
i
∣
≤
1000
,
∣
S
∣
≤
5
e
4
,
−
1
0
6
≤
c
i
≤
1
0
6
\sum |t_i|\le 1000, |S|\le5e4,-10^6\le c_i \le 10^6
∑∣ti∣≤1000,∣S∣≤5e4,−106≤ci≤106
解题思路
注意到未知的位置较少,且必须要填互不相同的字母,这提示我们用状压DP去写。
而统计一些模板字符在一个字符串里面出现的次数和贡献,可以使用ac自动机求出。在这题中的障碍是那些未知的位置。
注意到
∑
∣
t
i
∣
≤
1000
\sum|t_i|\le 1000
∑∣ti∣≤1000,AC自动机最多有1000个结点。未知位置最多有14个,所以原本的串
S
S
S最多被分成15段已知的固定的串。
我们令
n
x
t
[
u
]
[
i
]
nxt[u][i]
nxt[u][i]表示ac自动机的结点
u
u
u跑一遍
S
S
S的第
i
i
i段串之后变成了结点
n
x
t
[
u
]
[
i
]
nxt[u][i]
nxt[u][i]。令
s
u
m
[
u
]
[
i
]
sum[u][i]
sum[u][i]表示这个过程中得到的贡献。
我们用
d
p
[
u
]
[
m
a
s
k
]
,
(
假
设
m
a
s
k
中
的
1
的
个
数
为
c
n
t
)
dp[u][mask],(假设mask中的1的个数为cnt)
dp[u][mask],(假设mask中的1的个数为cnt)表示:
处理完前
c
n
t
cnt
cnt个未知位置,使用的字符集合为
m
a
s
k
mask
mask,当前位置为第
c
n
t
+
1
cnt+1
cnt+1段的最后一个字母,在ac自动机上的位置为结点
u
u
u的情况下,得到的G的最大值.
它的转移如图表示:
先枚举当前使用的字符集合mask,然后枚举上一段的结尾走到了ac自动机的u,根据第cnt个位置填什么字符来转移:
转移的时候有三段贡献:
- 前面的dp值
- 从上一段最后一个位置走到第cnt个’?’(填了i)得到的贡献
- 走到cnt+1段的最后一个位置的贡献
dp[ nxt[ch[u][i]][num] ][mask] =max(dp[ nxt[ch[u][i]][num] ][mask], dp[u][mask^(1<<i)]+cost[ch[u][i]]+sum[ch[u][i]][num]);
ac代码:
#include<bits/stdc++.h>
#define ll long long
#define lowbit(x) ((x)&(-(x)))
#define fors(i, a, b) for(int i = (a); i < (b); ++i)
using namespace std;
const int maxn = 4e5 + 50;
int ch[maxn][15], fail[maxn];
ll cost[maxn], rt, tot = 0;
void ins(char *s, int val){
int p = rt;
while(*s){
int x = *s - 'a';
if(!ch[p][x]) {
ch[p][x] = ++tot;
}
p = ch[p][x];
s++;
}
cost[p] += val;
}
queue<int> q;
void get_fail()
{
while(q.size()) q.pop();
for(int i = 0; i < 15; ++i)
if(ch[rt][i]) q.push(ch[rt][i]), fail[ch[rt][i]] = rt;
else ch[rt][i] = rt;
while(q.size()){
int cur = q.front(); q.pop();
for(int i = 0; i < 15; ++i){
if(ch[cur][i]) {
fail[ ch[cur][i] ] = ch[ fail[cur] ][i];
q.push(ch[cur][i]);
cost[ch[cur][i]] += cost[ fail[ ch[cur][i] ] ];
}
else ch[cur][i] = ch[fail[cur]][i];
}
}
}
char t[1050];
void init(){
tot = 0; rt = ++tot;
int n; scanf("%d", &n);
fors(i, 0, n){
int x;
scanf("%s%d", t, &x); ins(t, x);
}
get_fail();
}
char s[maxn];
int pos[20], cnt = 0;
int nxt[1050][17];
ll sum[1050][17];
ll dp[1050][1<<14];
int cal(int x){int res = 0; while(x) res++, x-=lowbit(x); return res;}
void sol(){
scanf("%s", s);
int n = strlen(s);
pos[cnt++] = -1;
fors(i, 0, n) if(s[i] == '?') pos[cnt++] = i;
pos[cnt] = n;
fors(i, 0, cnt){
fors(u, 1, tot+1){
int p = u;
fors(j, pos[i]+1, pos[i+1]){
p = ch[p][s[j]-'a'];
sum[u][i] += cost[p];
}nxt[u][i] = p;
}
}
memset(dp, 0xcf, sizeof dp);
dp[nxt[rt][0]][0] = sum[rt][0];
ll ans = -1e18;
if(cnt == 1) ans = sum[rt][0];//if there is no "?"
fors(mask, 1, (1<<14)){
int num = cal(mask);
if(num > cnt-1) continue;
fors(u, 1, tot+1){
fors(i, 0, 14){
if(mask>>i&1){
dp[ nxt[ch[u][i]][num] ][mask] =
max(dp[ nxt[ch[u][i]][num] ][mask], dp[u][mask^(1<<i)]+cost[ch[u][i]]+sum[ch[u][i]][num]);
if(num == cnt-1) {
ans = max(ans, dp[ nxt[ch[u][i]][num] ][mask]);
}
}
}
}
}
cout<<ans<<endl;
}
int main()
{
init();
sol();
}