有 m m m个字符串,每个字符串有一定的分值(可能为负数)
求出一个长 n n n的字符串 s s s使得它的价值最大,你只需要输出这个最大的价值.
价值定义为 ∑ i = 1 m c i i ∗ p o i n t i \sum\limits_{i=1}^m ci_i*point_i i=1∑mcii∗pointi
其中 c i i ci_i cii表示给定的第 i i i个字符串在 s s s中的出现次数, p o i n t i point_i pointi表示第 i i i个字符串的分值
对 m m m个字符串建 a c ac ac自动机,每个节点预处理一个 p o i n t i point_i pointi表示以这个节点结尾的所有串的分值和
当然这个 p o i n t point point要在 f a i l fail fail树上推标记累加
容易想到定义 f [ i ] [ j ] f[i][j] f[i][j]表示长度为 i i i的字符串在 j j j节点的最大分值
转移方程一目了然(满足 q q q节点到 j j j节点有边)
f [ i ] [ j ] = max q − > j { f [ i − 1 ] [ q ] + p o i n t j } f[i][j]=\max_{q->j}\{f[i-1][q]+point_j\} f[i][j]=q−>jmax{f[i−1][q]+pointj}
然而 n < = 1 0 9 n<=10^9 n<=109,一般来说可以使用矩阵快速幂,但这里是 max \max max运算
初始矩阵为,长度是自动机上的节点个数,也就是相当于 f [ 0 ] [ 0 ] , f [ 0 ] [ 1 ] . . . f [ 0 ] [ i d ] f[0][0],f[0][1]...f[0][id] f[0][0],f[0][1]...f[0][id]的值
[ 0 − i n f ⋯ − i n f ] \begin{bmatrix} 0 & -inf & \cdots & -inf \\ \end{bmatrix} [0−inf⋯−inf]
转移矩阵为 i d ∗ i d id*id id∗id的矩阵(下面用 n n n表示 i d id id)
S = [ x 11 x 12 ⋯ x 1 n x 21 x 22 ⋯ x 2 n ⋮ ⋮ ⋱ ⋮ x n 1 x n 2 ⋯ x n n ] S= \begin{bmatrix} x_{11} & x_{12} & \cdots & x_{1n} \\ x_{21} & x_{22} & \cdots & x_{2n} \\ \vdots & \vdots & \ddots & \vdots \\ x_{n1} & x_{n2} & \cdots\ & x_{nn} \\ \end{bmatrix} S=⎣⎢⎢⎢⎡x11x21⋮xn1x12x22⋮xn2⋯⋯⋱⋯ x1nx2n⋮xnn⎦⎥⎥⎥⎤
其中 x i , j x_{i,j} xi,j表示,若 i i i节点到 j j j节点有边, x i , j = m x j x_{i,j}=mx_j xi,j=mxj
若 i i i节点到 j j j节点无边, x i , j = − i n f x_{i,j}=-inf xi,j=−inf代表不能转移
这是一个外层 max \max max运算,内层 + + +运算的矩阵乘法
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int maxn = 1e6+10;
const ll inf = 1e17;
int n,m,mx;
struct rce
{
ll m[201][201];
rce(){ memset( m,-0x3f3f3f3f,sizeof m); }
};
rce operator * ( rce a, rce b )
{
rce ans;
for(int i=0;i<=mx;i++)
for(int j=0;j<=mx;j++)
for(int k=0;k<=mx;k++)
ans.m[i][j] = max( ans.m[i][j],a.m[i][k]+b.m[k][j] );
return ans;
}
int zi[maxn][30],fail[maxn],id = 1;
ll point[maxn];
char a[maxn];
void insert(char a[],int val)
{
int len = strlen( a+1 ), now = 0;
for(int i=1;i<=len;i++)
{
if( !zi[now][a[i]-'a'] ) zi[now][a[i]-'a'] = ++id;
now = zi[now][a[i]-'a'];
}
point[now] += val;
}
void make_fail()
{
queue<int>q;
for(int i=0;i<=25;i++)
if( zi[0][i] ) q.push( zi[0][i] );
while( !q.empty() )
{
int u = q.front(); q.pop();
point[u] += point[fail[u]];
for(int i=0;i<=25;i++)
{
int v = zi[u][i];
if( v )
fail[v] = zi[fail[u]][i], q.push( v );
else zi[u][i] = zi[fail[u]][i];
}
}
}
rce quick(rce x,int n)
{
rce ans = x;
for( ; n ; n>>=1,x=x*x )
if( n&1 ) ans = ans*x;
return ans;
}
int main()
{
cin >> n >> m;
for(int i=1;i<=m;i++)
{
int val;
scanf("%s%d",a+1,&val );
insert( a,val );
}
make_fail(); mx = id;
rce z;
for(int i=0;i<=id;i++)
for(int j=0;j<=25;j++)
{
int v = zi[i][j];
z.m[i][v] = point[v];
}
z = quick( z,n-1 );
ll ans = -inf;
for(int i=0;i<=mx;i++) ans = max( ans,z.m[0][i] );
cout << ans;
}