hdu 4929 Another Letter Tree(LCA+DP)
题意:有一棵树n个节点(n<=50000),树上每个节点上有一个字母。m个询问(m<=50000),每次询问一个(a,b),问a节点到b节点的点不重复路径组成的字符串中子序列为s0的情况有多少种,s0长度小于等于30(注意s0是已经给定的,而不是每次询问都会给出一个新的)。
解法:一个很直观的想法,求出lca(设其为w)后,枚举x,求出a到w的路径上,能匹配s0的x长度前缀的情况有多少种,令其为c[x]。再求出b到w的路径上能匹配s0的L-x(L表示s0的长度)长度后缀的情况有多少种,令其为d[l-x],那么将所有的c[x]*d[l-x](x属于[0,l])加起来,即为答案(当然这里要考虑w这个点,不能同时出现在两部分当中,处理方法是w这个位置两部分都不要,然后在考虑w这个位置一定被选进去,两种情况加起来即可)。然后问题的难点在于,考虑某个节点u时,如何处理出c[i]与d[i]。这里,我们需要预处理一个dp数组,dp[i][j][u]表示,从u节点到根的路径匹配了s0[i,j]这段子串的子序列有多少种。那么c[i]就等于u到根的路径匹配了s0的i长度前缀情况数,减去有长度a的前缀在a到w的路径上(因为我们先考虑的是w两边都不要,这里其实我们要的是a到w的前一个节点的路径,c[i]考虑的也是这条路经)的情况数,即c[a](这里,因为我们是从小到大递推c[i],而a又小于i,故在求c[i]之前,我们必然已经推出过了c[a],直接拿来用),乘以s0[a+1,i]的子串匹配在w到根的路径上的序列的情况数(这个就是前面预处理的dp数组,拿来用即可)。求d[i]亦是同样地方法,这里时间复杂度主要是在预处理上。整体时间复杂度为n*l*l,问题得解。
代码:
#pragma comment(linker, "/STACK:1024000000,1024000000")
#include<stdio.h>
#include<string.h>
#include<algorithm>
#include<vector>
using namespace std ;
void get_num ( int& n ) {
n = 0 ;
char c ;
while ( c = getchar () ) {
if ( c >= '0' && c <= '9' ) break ;
}
n = c - '0' ;
while ( c = getchar () ) {
if ( c < '0' ¦¦ c > '9' ) break ;
n = n * 10 + c -'0' ;
}
}
const int maxn = 50005 ;
const int mod = 10007 ;
short dp[2][33][33][maxn] ;
int c[33] , d[33] ;
char s[maxn] , s1[33] ;
vector<int> vec[maxn] ;
int p[20][maxn] , fa[maxn] , deep[maxn] ;
struct LCA {
void dfs ( int u ) {
if ( u == 1 ) fa[u] = 0 ;
p[0][u] = fa[u] ;
deep[u] = deep[fa[u]] + 1 ;
for ( int i = 1 ; i < 20 ; i ++ ) p[i][u] = p[i-1][p[i-1][u]] ;
int sz = vec[u].size () ;
for ( int i = 0 ; i < sz ; i ++ ) {
int v = vec[u][i] ;
if ( v == fa[u] ) continue ;
fa[v] = u ;
dfs ( v ) ;
}
}
int father_k ( int u , int k ) {
for ( int i = 0 ; i < 20 ; i ++ )
if ( k & ( 1 << i ) )
u = p[i][u] ;
return u ;
}
int query ( int a , int b ) {
if ( deep[a] > deep[b] ) swap ( a , b ) ;
b = father_k ( b , deep[b] - deep[a] ) ;
if ( a == b ) return a ;
for ( int i = 19 ; i >= 0 ; i -- ) {
if ( fa[a] == fa[b] ) break ;
if ( p[i][a] != p[i][b] ) {
a = p[i][a] ;
b = p[i][b] ;
}
}
return fa[a] ;
}
} lca ;
int l ;
void dfs ( int u , int x , int c ) {
for ( int i = x ; i <= l ; i ++ ) {
dp[c][x][i][u] += dp[c][x][i][fa[u]] ;
if ( dp[c][x][i][u] >= mod ) dp[c][x][i][u] -= mod ;
if ( s[u] == s1[i] )
dp[c][x][i][u] += dp[c][x][i-1][fa[u]] ;
if ( dp[c][x][i][u] >= mod ) dp[c][x][i][u] -= mod ;
}
int sz = vec[u].size () ;
for ( int i = 0 ; i < sz ; i ++ ) {
int v = vec[u][i] ;
if ( v == fa[u] ) continue ;
dfs ( v , x , c ) ;
}
}
void DP ( int n , int c ) {
for ( int i = 0 ; i <= l + 1 ; i ++ ) {
for ( int j = 0 ; j <= n ; j ++ ) {
for ( int k = 0 ; k <= i ; k ++ )
dp[c][k][i][j] = 0 ;
if (i) dp[c][i][i-1][j] = 1 ;
}
}
for ( int i = 1 ; i <= l ; i ++ )
dfs ( 1 , i , c ) ;
}
int main () {
int T , n , q ;
scanf ( "%d" , &T ) ;
while ( T -- ) {
scanf ( "%d%d" , &n , &q ) ;
for ( int i = 1 ; i <= n ; i ++ ) vec[i].clear () ;
for ( int i = 1 ; i < n ; i ++ ) {
int a , b ;
get_num (a) ;
get_num (b) ;
vec[a].push_back (b) ;
vec[b].push_back (a) ;
}
scanf ( "%s" , s + 1 ) ;
scanf ( "%s" , s1 + 1 ) ;
l = strlen ( s1 + 1 ) ;
lca.dfs ( 1 ) ;
reverse ( s1 + 1 , s1 + l + 1 ) ;
DP ( n , 0 ) ;
reverse ( s1 + 1 , s1 + l + 1 ) ;
DP ( n , 1 ) ;
while ( q -- ) {
int a , b , x , y ;
get_num (a) ;
get_num (b) ;
if ( a == b ) {
if ( l == 1 && s[a] == s1[1] ) puts ( "1" ) ;
else puts ( "0" ) ;
continue ;
}
int w = lca.query ( a , b ) ;
int ans = 0 ;
memset ( c , 0 , sizeof ( c ) ) ;
memset ( d , 0 , sizeof ( d ) ) ;
for ( int i = 0 ; i <= l ; i ++ ) {
c[i] = dp[0][l-i+1][l][a] ;
d[i] = dp[1][l-i+1][l][b] ;
// printf ( "d[%d] = %d\n" , i , d[i] ) ;
for ( int j = 0 ; j < i ; j ++ ) {
c[i] -= (c[j] * dp[0][l-i+1][l-j][w] % mod) ;
d[i] -= (d[j] * dp[1][l-i+1][l-j][w] % mod) ;
c[i] += mod ;
if ( c[i] >= mod ) c[i] -= mod ;
d[i] += mod ;
if ( d[i] >= mod ) d[i] -= mod ;
}
// printf ( "c[%d] = %d , d[%d] = %d\n" , i , c[i] , i , d[i] ) ;
}
for ( int i = 0 ; i <= l ; i ++ ) {
ans += c[i] * d[l-i] % mod ;
if ( ans >= mod ) ans -= mod ;
}
for ( int i = 0 ; i < l ; i ++ ) {
if ( s[w] == s1[i+1] ) {
ans += c[i] * d[l-i-1] % mod ;
if ( ans >= mod ) ans -= mod ;
}
}
printf ( "%d\n" , ans ) ;
}
}
}
/*
1000
12 1000
1 2
1 3
2 4
2 5
2 6
5 9
5 10
3 7
3 8
8 11
8 12
abbaabbababb ba
8 6
2 10
10 2
1 2 9 0
1 2 10 0
1 2 9 1
1 2 10 1
*/