题目大意是:给出m个模式串,每个模式串都有一个价值,出现一次就累加一次价值,输出价值最大的字符串,若价值相同,就输出最短的一个,若长度也相同,就输出字典序最小的一个。
我们可以把字符串的价值加在trie的节点上,用dp[i][j]表示长度为i,走到j号节点最大价值是多少。插入时要把模式串翻转,以便找出字典序最小
#include<stdio.h>
#include<string.h>
#include<algorithm>
#include<queue>
#define ll __int64
using namespace std ;
ll dp[111][1111] ;
int pre[111][1111] , n , m ;
int c[26][1111] ;
int tot ;
int d[1111] ;
int ty[1111] ;
int fail[1111] ;
queue<int> Q ;
struct AC_auto
{
int new_node ( int t )
{
int i ;
d[tot] = 0 ;
ty[tot] = t ;
fail[tot] = 0 ;
for ( i = 0 ; i < 26 ; i ++ ) c[i][tot] = 0 ;
return tot ++ ;
}
void init () { tot = 0 ; while ( !Q.empty () ) Q.pop () ; new_node ( 0 ) ; }
void insert ( char *s , int v )
{
int now = 0 ;
for ( ; *s ; s ++ )
{
int k = (*s) - 'a' ;
if ( !c[k][now] ) c[k][now] = new_node ( k ) ;
now = c[k][now] ;
}
d[now] = v ;
}
void get_fail ()
{
int i , j , u = 0 ;
for ( i = 0 ; i < 26 ; i ++ )
if ( c[i][u] ) Q.push ( c[i][u] ) ;
while ( !Q.empty () )
{
u =Q.front () ;
Q.pop () ;
for ( i = 0 ; i < 26 ; i ++ )
{
if ( !c[i][u] )
{
c[i][u] = c[i][fail[u]] ;
continue ;
}
int e = c[i][u] ;
j = fail[u] ;
if ( j && !c[i][j] ) j = fail[j] ;
fail[e] = c[i][j] ;
d[e] += d[fail[e]] ;
Q.push ( e ) ;
}
}
}
void work ()
{
int i , j , k ;
for ( i = 0 ; i <= n ; i ++ )
for ( j = 0 ; j < tot ; j ++ )
dp[i][j] = -1 , pre[i][j] = 0 ;
dp[0][0] = 0 ;
for ( i = 0 ; i < n ; i ++ )
for ( j = 0 ; j < tot ; j ++ )
{
if ( dp[i][j] == -1 ) continue ;
for ( k = 0 ; k < 26 ; k ++ )
{
int e = c[k][j] ;
if ( dp[i+1][e] < ( ll ) dp[i][j] + d[e] )
{
dp[i+1][e] = ( ll ) dp[i][j] + d[e] ;
pre[i+1][e] = j ;
}
if ( dp[i+1][e] == ( ll ) dp[i][j] + d[e] )
{
int p = j , q = pre[i+1][e] ;
int t = i ;
while ( ty[p] == ty[q] && t > 0 )
{
p = pre[t][p] , q = pre[t][q] ;
t -- ;
}
if ( ty[p] < ty[q] ) pre[i+1][e] = j ;
}
}
}
}
void ans ()
{
int i , j ;
int ans = -1 ;
int px = -1 , py = -1 ;
for ( i = 0 ; i <= n ; i ++ )
for ( j = 0 ; j < tot ; j ++ )
{
if ( dp[i][j] == -1 ) continue ;
if ( dp[i][j] > ans )
{
ans = dp[i][j] ;
px = i , py = j ;
}
else if ( dp[i][j] == ans && i == px )
{
int p = py , q = j ;
int t = i ;
while ( ty[p] == ty[q] && t > 0 )
{
p = pre[t][p] , q = pre[t][q] ;
t -- ;
}
if ( ty[p] > ty[q] ) py = j ;
}
}
while ( px > 0 )
{
printf ( "%c" , ty[py] + 'a' ) ;
py = pre[px][py] ;
px -- ;
}
puts ( "" ) ;
}
} ac ;
char s[111][15] ;
int v[111] ;
int main ()
{
int i , j , cas ;
scanf ( "%d" , &cas ) ;
while ( cas -- )
{
scanf ( "%d%d" , &n , &m ) ;
ac.init () ;
for ( i = 0 ; i < m ; i ++ )
{
scanf ( "%s" , s[i] ) ;
reverse ( s[i] , s[i] + strlen ( s[i] ) ) ;
}
for ( i = 0 ; i < m ; i ++ ) scanf ( "%d" , v + i ) ;
for ( i = 0 ; i < m ; i ++ )
ac.insert ( s[i] , v[i] ) ;
ac.get_fail () ;
ac.work () ;
ac.ans () ;
}
}