4598: [Sdoi2016]模式字符串
Time Limit: 10 Sec Memory Limit: 128 MBSubmit: 80 Solved: 37
[ Submit][ Status][ Discuss]
Description
给出n个结点的树结构T,其中每一个结点上有一个字符,这里我们所说的字符只考虑大写字母A到Z,再给出长度为m
的模式串s,其中每一位仍然是A到z的大写字母。Alice希望知道,有多少对结点<u,v>满足T上从u到V的最短路径
形成的字符串可以由模式串S重复若干次得到?这里结点对<u,v>是有序的,也就是说<u,v>和<v,u>需要被区分.
所谓模式串的重复,是将若干个模式串S依次相接(不能重叠).例如当S=PLUS的时候,重复两次会得到PLUSPLUS,
重复三次会得到PLUSPLUSPLUS,同时要注恿,重复必须是整数次的。例如当S=XYXY时,因为必须重复整数次,所以X
YXYXY不能看作是S重复若干次得到的。
Input
每一个数据有多组测试,
第一行输入一个整数C,表示总的测试个数。
对于每一组测试来说:
第一行输入两个整数,分别表示树T的结点个数n与模式长度m。结点被依次编号为1到n,
之后一行,依次给出了n个大写字母(以一个长度为n的字符串的形式给出),依次对应树上每一个结点上的字符(
第i个字符对应了第i个结点).
之后n-1行,每行有两个整数u和v表示树上的一条无向边,之后一行给定一个长度为m的由大写字母组成的字符串,
为模式串S。
1<=C<=10,3<=N<=10000003<=M<=1000000
Output
给出C行,对应C组测试。每一行输出一个整数,表示有多少对节点<u,v>满足从u到v的路径形成的字符串恰好是模
式串的若干次重复.
Sample Input
1
11 4
IODSSDSOIOI
1 2
2 3
3 4
1 5
5 6
6 7
3 8
8 9
6 10
10 11
SDOI
11 4
IODSSDSOIOI
1 2
2 3
3 4
1 5
5 6
6 7
3 8
8 9
6 10
10 11
SDOI
Sample Output
5
HINT
数据文件太过巨大,仅提供前三组数据测试.
Source
统计树上所有路径,自然想到点分治,然后是字符串比较,可以用hash的办法
O(Tnlogn)
点分治又打残了。。。记得下传的时候fa要改成当前的重心啊。。。。。惨
还有特判m == 1的时候。。不知原题有没有这样的数据
#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
#include<vector>
#include<stack>
using namespace std;
typedef unsigned long long UL;
typedef long long LL;
const UL p = 233;
const int maxn = 1E6 + 10;
int n,m,T,Len,O,Max,siz[maxn],c1[maxn],c2[maxn];
char s[maxn],ID[maxn];
LL Ans;
UL mi[maxn],h1[maxn],h2[maxn];
bool Huge[maxn];
vector <int> v[maxn];
stack <int> stk;
void Dfs2(int x,int fa,int tot)
{
int ma = 0; siz[x] = 1;
for (int i = 0; i < v[x].size(); i++)
{
int to = v[x][i];
if (to == fa || Huge[to]) continue;
Dfs2(to,x,tot);
ma = max(ma,siz[to]);
siz[x] += siz[to];
}
ma = max(ma,tot - siz[x]);
if (ma < Max) Max = ma,O = x;
}
void Calc(int x,int fa,int L,UL A,UL B)
{
if (L > Len) return;
if (L + 1 <= Len)
{
A += mi[L]*(UL)(ID[x]);
if (A == h1[L+1])
{
int res = ((L+1) % m == 0)?m:(L+1) % m;
Ans += 1LL*c2[m - res];
}
}
B += mi[L-1]*(UL)(ID[x]);
if (B == h2[Len - L + 1])
{
int res = (L % m == 0)?m:L % m;
Ans += 1LL*c1[m - res];
}
for (int i = 0; i < v[x].size(); i++)
{
int to = v[x][i];
if (to == fa || Huge[to]) continue;
Calc(to,x,L + 1,A,B);
}
}
void Add(int x,int fa,int L,UL A,UL B)
{
if (L > Len) return;
if (L + 1 <= Len)
{
A += mi[L]*(UL)(ID[x]);
if (A == h1[L+1])
{
int res = (L+1) % m; ++c1[res];
if (c1[res] == 1) stk.push(res);
}
}
B += mi[L-1]*(UL)(ID[x]);
if (B == h2[Len - L + 1])
{
int res = L % m; ++c2[res];
if (c2[res] == 1) stk.push(res);
}
for (int i = 0; i < v[x].size(); i++)
{
int to = v[x][i];
if (to == fa || Huge[to]) continue;
Add(to,x,L + 1,A,B);
}
}
void Dfs(int x,int fa,int tot)
{
Max = maxn; Dfs2(x,fa,tot);
int o = O; Huge[o] = 1; c2[0] = 1;
if (ID[o] == s[1])
{
if (m > 1) ++c1[1],stk.push(1);
else ++c1[0],stk.push(0);
}
for (int i = 0; i < v[o].size(); i++)
{
int to = v[o][i];
if (to == fa || Huge[to]) continue;
Calc(to,o,1,ID[o],0);
Add(to,o,1,ID[o],0);
}
while (!stk.empty())
{
int g = stk.top();
c1[g] = c2[g] = 0;
stk.pop();
}
for (int i = 0; i < v[o].size(); i++)
{
int to = v[o][i];
if (to == fa || Huge[to]) continue;
Dfs(to,o,siz[to]);
}
}
void Clear()
{
for (int i = 1; i <= n; i++)
v[i].clear(),Huge[i] = 0; Ans = 0;
}
int getint()
{
char ch = getchar(); int ret = 0;
while (ch < '0' || '9' < ch) ch = getchar();
while ('0' <= ch && ch <= '9')
ret = ret*10 + ch - '0',ch = getchar();
return ret;
}
void Solve()
{
n = getint(); m = getint();
scanf("%s",ID + 1); Clear();
for (int i = 1; i < n; i++)
{
int x = getint(),y = getint();
v[x].push_back(y);
v[y].push_back(x);
}
scanf("%s",s + 1);
if (m == 1)
{
for (int i = 1; i <= n; i++)
if (ID[i] == s[1]) Ans += 1LL;
}
int ti = n / m; Len = ti*m; h2[Len+1] = 0;
for (int i = 1; i < ti; i++)
for (int j = 1; j <= m; j++)
s[j + i*m] = s[j];
for (int i = 1; i <= Len; i++)
h1[i] = h1[i-1]*p + (UL)(s[i]);
for (int i = Len; i; i--)
h2[i] = h2[i+1]*p + (UL)(s[i]);
Dfs(1,0,n); printf("%lld\n",Ans);
}
int main()
{
#ifdef DMC
freopen("DMC.txt","r",stdin);
#endif
T = getint();
mi[0] = 1; for (int i = 1; i < maxn; i++) mi[i] = mi[i-1]*p;
while (T--) Solve();
return 0;
}