给出两个字符串 A、B（长度分别为 n、m）和一个常数 x，问能否从 A 中选出不超过 x 个互不重叠的子串（段），按它们在 A 中的原有顺序拼接后恰好得到 B。
Solution
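设 f[i][j] 表示只使用 A 的前 i 个字符、已经选了 j 段时，能拼出的 B 的最长前缀长度。转移有两种：跳过 A 的第 i+1 个字符，即 f[i+1][j] ← f[i][j]；或者从 A 的第 i+1 个字符开始新的一段，令 len 为 A[i+1..] 与 B[f[i][j]+1..] 的最长公共前缀，则 f[i+len][j+1] ← f[i][j]+len（新的一段贪心地取到最长）。若存在 j ≤ x 使某个 f[i][j] = m，则答案为 YES。
把 A、一个唯一的分隔符和 B 拼接成一个串并建立后缀数组，再在 height 数组上用 ST 表做 RMQ，即可 O(1) 求出上述最长公共前缀。因此每次转移为 O(1)，总复杂度为 O((n+m)log(n+m) + n·x)。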
Code
#include <bits/stdc++.h>
const int N = 200005;

int n, m, x;          // |A|, |B| and the maximum number of segments
int f[100005][105];   // f[i][j]: longest prefix of B built from A[1..i] using j segments
int b[N];             // bucket counts for the counting sorts in get_sa
int c[N];             // scratch ordering / new ranks in get_sa
int d[N];             // scratch ordering in get_sa
int rank[N * 2];      // rank[i]: rank of suffix i; 2N long so rank[i + j] past the end reads 0
int sa[N];            // sa[k]: start position of the k-th smallest suffix
int rmq[N][21];       // sparse table of minima over height[]
int bin[21];          // bin[k] = 2^k
int lg[N];            // lg[i]: floor of log2(i), used to pick the sparse-table level
int s[N];             // working string: A (letters as 1..26), separator 100, then B
int height[N];        // height[k]: LCP of the suffixes sa[k - 1] and sa[k]
char s1[N];           // raw input string A, stored from index 1
char s2[N];           // raw input string B, stored from index 1
// Builds the suffix array of s[1..n] by prefix doubling with counting (radix) sorts.
//   n : length of the string stored in s[1..n]
//   m : size of the initial alphabet; every s[i] is expected to lie in 1..m
// Results (globals): rank[i] = 1-based rank of the suffix starting at position i,
//                    sa[k]   = start position of the k-th smallest suffix.
// Scratch globals: b (bucket counts), c and d (temporary orderings / ranks).
// Relies on index 0 of s, c, d and rank never being written here (they stay at
// their zero initialisation) so that they act as sentinels distinct from every
// real character (>= 1) and every real rank (>= 1).
void get_sa(int n,int m)
{
    // Clear ranks up to 2n so that rank[i + j] reads 0 ("empty second half")
    // for positions past the end of the string.
    for (int i = 1; i <= n * 2; i++)
        rank[i] = 0;
    // Initial counting sort of positions by their first character into c.
    for (int i = 1; i <= m; i++)
        b[i] = 0;
    for (int i = 1; i <= n; i++)
        b[s[i]]++;
    for (int i = 1; i <= m; i++)
        b[i] += b[i-1];
    for (int i = n; i >= 1; i--)
        c[b[s[i]]--] = i;
    // Dense initial ranks: positions with equal first characters share a rank.
    int t = 0;
    for (int i = 1; i <= n; i++)
    {
        if (s[c[i]] != s[c[i - 1]])
            t++;
        rank[c[i]] = t;
    }
    // Doubling step: sort by the pair (rank[i], rank[i + j]), then re-rank,
    // until every suffix has a distinct rank or j exceeds n.
    int j = 1;
    while (j <= n)
    {
        // Radix pass 1: order positions by the second key rank[i + j] into c.
        // b[0] counts positions whose second half is empty (rank 0). It is never
        // cleared, but every increment here is undone by a matching decrement
        // when c is filled, so it is back to 0 after the pass.
        for (int i = 1; i <= n; i++) b[i] = 0;
        for (int i = 1; i <= n; i++) b[rank[i + j]]++;
        for (int i = 1; i <= n; i++) b[i] += b[i - 1];
        for (int i = n; i >= 1; i--) c[b[rank[i + j]]--] = i;
        // Radix pass 2: stable sort of that order by the first key rank[i] into d.
        for (int i = 1; i <= n; i++) b[i] = 0;
        for (int i = 1; i <= n; i++) b[rank[i]]++;
        for (int i = 1; i <= n; i++) b[i] += b[i - 1];
        for (int i = n; i >= 1; i--) d[b[rank[c[i]]]--] = c[i];
        // Re-rank: a new rank starts whenever the (first, second) key pair
        // changes. New ranks are written into c and then copied back to rank.
        t = 0;
        for (int i = 1; i <= n; i++)
        {
            if (rank[d[i]] != rank[d[i - 1]] || rank[d[i]] == rank[d[i - 1]] && rank[d[i] + j] != rank[d[i - 1] + j])
                t++;
            c[d[i]] = t;
        }
        for (int i = 1; i <= n; i++)
            rank[i] = c[i];
        // All ranks distinct: the ordering is final.
        if (t == n)
            break;
        j <<= 1;
    }
    // Invert rank to get the suffix array.
    for (int i = 1; i <= n; i++)
        sa[rank[i]] = i;
}
// Kasai's algorithm: fills height[rank[p]] with the length of the longest common
// prefix between the suffix at p and the suffix ranked just before it
// (sa[rank[p] - 1]), for every p in 1..n of s[1..n].
// Uses the carried-over match length: moving from suffix p to p + 1 can shorten
// the match by at most one character. For the smallest suffix the predecessor
// is sa[0] == 0, and s[0] == 0 never equals a real character.
void get_height(int n)
{
    int h = 0;
    for (int pos = 1; pos <= n; ++pos)
    {
        h = h > 0 ? h - 1 : 0;
        const int prev = sa[rank[pos] - 1];
        for (; prev + h <= n && pos + h <= n; ++h)
        {
            if (s[prev + h] != s[pos + h])
                break;
        }
        height[rank[pos]] = h;
    }
}
// Builds a sparse table over height[1..n] so get_lcp can answer range-minimum
// queries in O(1). Also fills the lookup tables it needs:
//   lg[i]  = floor(log2(i)) for 1 <= i <= n
//   bin[k] = 2^k            for 0 <= k <= lg[n]
//   rmq[i][k] = min(height[i .. i + 2^k - 1])
// Declared void: it produces no value. The former `int` return type with no
// return statement was undefined behaviour.
void get_rmq(int n)
{
    // Integer floor(log2). The previous log(i) / log(2) is floating point and
    // can round an exact power of two down (e.g. to 2.999...). That would make
    // get_lcp combine two blocks that do not cover the whole query range.
    lg[1] = 0;
    for (int i = 2; i <= n; i++)
        lg[i] = lg[i >> 1] + 1;
    for (int i = 1; i <= n; i++)
        rmq[i][0] = height[i];
    bin[0] = 1;
    for (int i = 1; i <= lg[n]; i++)
        bin[i] = bin[i - 1] * 2;
    // Level j combines two adjacent level j-1 blocks.
    for (int j = 1; j <= lg[n]; j++)
        for (int i = 1; i + bin[j] - 1 <= n; i++)
            rmq[i][j] = std::min(rmq[i][j - 1], rmq[i + bin[j - 1]][j - 1]);
}
int get_lcp(int l,int r)
{
l = rank[l]; r = rank[r];
if (l > r)
std::swap(l,r);
l++;
int len = lg[r - l + 1];
return std::min(rmq[l][len], rmq[r - bin[len] + 1][len]);
}
// Reads T test cases. Each gives n, m, x followed by strings A (length n) and
// B (length m). Prints "YES" if B can be obtained by concatenating at most x
// non-overlapping substrings of A taken in their original order, else "NO".
// NOTE(review): assumes n <= 100004, x <= 104 and n + m + 1 < N so that f, s, s1,
// s2 and the suffix-array tables are not overrun; these limits are not checked.
int main()
{
    int T;
    // Stop cleanly on malformed or truncated input instead of running with an
    // uninitialised case count or stale n/m/x, which could index out of bounds.
    if (scanf("%d",&T) != 1)
        return 0;
    while (T--)
    {
        if (scanf("%d%d%d",&n,&m,&x) != 3 || scanf("%s%s", s1 + 1, s2 + 1) != 2)
            break;
        // s[1..n+m+1] = A (letters as 1..26), a unique separator 100, then B.
        // The separator keeps an LCP between an A-suffix and a B-suffix inside A.
        for (int i = 1; i <= n; i++)
            s[i] = s1[i] - 'a' + 1;
        s[n + 1] = 100;
        for (int i = 1; i <= m; i++)
            s[i + n + 1] = s2[i] - 'a' + 1;
        get_sa(n + m + 1, 100);
        get_height(n + m + 1);
        get_rmq(n + m + 1);
        // Reset the DP states this case reads. Row 0 and column 0 are never
        // written with anything but 0.
        for (int i = 1; i <= n; i++)
            for (int j = 1; j <= x; j++)
                f[i][j] = 0;
        // f[i][j]: longest prefix of B formed from A[1..i] using j segments.
        for (int j = 0; j < x; j++)
            for (int i = 0; i < n; i++)
            {
                // B is already complete; also avoids querying a B position past its end.
                if (f[i][j] == m)
                    break;
                // Skip A[i + 1].
                f[i + 1][j] = std::max(f[i + 1][j],f[i][j]);
                // Start a new segment at A[i + 1] against B[f[i][j] + 1], which is
                // s[f[i][j] + n + 2]; greedily take the longest match.
                int len = get_lcp(i + 1, f[i][j] + 1 + n + 1);
                f[i + len][j + 1] = std::max(f[i + len][j + 1], f[i][j] + len);
            }
        int ans = 0;
        for (int i = 1; i <= n; i++)
            for (int j = 1; j <= x; j++)
                ans = std::max(ans,f[i][j]);
        puts(ans == m ? "YES" : "NO");
    }
    return 0;
}