【题解】P6793 [SNOI2020] 字符串(SAM)
D e s c r i p t i o n \rm Description Description
有两个长度为 n n n 的由小写字母组成的字符串 a , b a,b a,b,取出他们所有长为 k k k 的子串(各有 n − k + 1 n-k+1 n−k+1 个),这些子串分别组成集合 A , B A,B A,B 。现在要修改 A A A 中的串,使得 A A A 和 B B B 完全相同。可以任意次选择修改 A A A 中一个串的一段后缀,花费为这段后缀的长度。总花费为每次修改花费之和,求总花费的最小值。
输入格式
第一行两个整数
n
,
k
n,k
n,k 表示字符串长度和子串长度;
第二行一个小写字母字符串
a
a
a;
第三行一个小写字母字符串
b
b
b。
输出格式
输出一行一个整数表示总花费的最小值。
输入输出样例
输入
5 3
aabaa
ababa
输出
3
对于所有数据,
1
≤
k
≤
n
≤
1.5
×
1
0
5
1≤k≤n≤1.5×10^5
1≤k≤n≤1.5×105 。
S o l u t i o n \rm Solution Solution
由于修改的是子串的后缀,也就是前缀是相同的,问题可以转化为对集合 A A A 和 B B B 的子串进行匹配,使其公共前缀的长度和最大,即 ∑ i = 1 n l c p ( A i , B i ) \sum\limits_{i=1}^n lcp(A_i,B_i) i=1∑nlcp(Ai,Bi) 最大。
S
A
M
\rm SAM
SAM只能处理后缀,可以用反串建立
S
A
M
\rm SAM
SAM处理。
两个子串的最长公共后缀等于其在
p
a
r
e
n
t
\rm parent
parent
t
r
e
e
\rm tree
tree (也有人称为后缀链接)上的最近公共祖先(
L
C
A
\rm LCA
LCA )。
为了方便处理,将两个字符串
a
,
b
a,b
a,b 拼起来,记为
s
s
s ,用
s
s
s 的反串建立
S
A
M
\rm SAM
SAM ,再在
p
a
r
e
n
t
\rm parent
parent
t
r
e
e
\rm tree
tree 上区分
a
,
b
a,b
a,b 进行匹配。
记 s u m sum sum 为当前公共前缀的长度和,当匹配至节点 i i i 时,有 s 1 i s1_i s1i 个 a a a 的 l e n len len大于等于 k k k 的子串(由于求的是最长公共后缀,所以 l e n len len 不需要等于 k k k)、 有 s 2 i s2_i s2i 个 b b b 的 l e n len len 大于等于 k k k 的子串没有在 i i i 的儿子节点被匹配,有 m i n ( s 1 i , s 2 i ) min(s1_i,s2_i) min(s1i,s2i) 对 a , b a,b a,b 子串的 L C A \rm LCA LCA 为节点 i i i ,更新 s u m sum sum ,剩下的部分则上传到父亲节点继续匹配。
答案为长度为
k
k
k 的子串个数减去最大的公共前缀长度和,即
(
k
∗
(
n
−
k
+
1
)
−
s
u
m
(k*(n-k+1) -sum
(k∗(n−k+1)−sum 。
C o d e \rm Code Code
#include<bits/stdc++.h>
using namespace std;
long long n,k;
int t[1000005],a[1000005];
char s[500005];
struct node{
int nex[26];
int fa;
long long len,sum[2];
node(){memset(nex,0,sizeof(nex)); len=0;}
}d[1000005];
int tot=1,las=1;
long long ans;
void add(int c,int val,int jud)
{
int p=las,np=las=++tot; d[np].sum[jud]=val;
d[np].len=d[p].len+1;
for(; p && !d[p].nex[c]; p=d[p].fa) d[p].nex[c]=np;
if(!p) d[np].fa=1;
else
{
int q=d[p].nex[c];
if(d[q].len == d[p].len+1) d[np].fa=q;
else
{
int nq=++tot;
memcpy(d[nq].nex,d[q].nex,sizeof(d[nq].nex));
d[nq].fa=d[q].fa;
d[nq].len=d[p].len+1;
d[q].fa=d[np].fa=nq;
for(; p && d[p].nex[c]==q; p=d[p].fa) d[p].nex[c]=nq;
}
}
}
void tsort() //拓扑排序
{
for(int i=1; i<=tot; i++) t[d[i].len]++;
for(int i=1; i<=tot; i++) t[i]+=t[i-1];
for(int i=1; i<=tot; i++) a[t[d[i].len]--]=i;
}
int main()
{
scanf("%lld%lld",&n,&k);
scanf("%s",s);
for(int i=n-1; i>=0; i--) add(s[i]-'a',(i+k-1<n),0);
scanf("%s",s);
for(int i=n-1; i>=0; i--) add(s[i]-'a',(i+k-1<n),1);
int lens=n<<1;
tsort();
int mins;
for(int i=tot; i; i--)
{
mins=min(d[a[i]].sum[0],d[a[i]].sum[1]);
ans+=mins*min(k,d[a[i]].len);
d[a[i]].sum[0]-=mins;
d[a[i]].sum[1]-=mins;
d[d[a[i]].fa].sum[0]+=d[a[i]].sum[0];
d[d[a[i]].fa].sum[1]+=d[a[i]].sum[1];
}
ans=k*(n-k+1)-ans;
printf("%lld",ans);
return 0;
}
感谢阅读,如果有问题或更好的建议,还请提出。