题意
对于串 S1 和 S2，求一种将 S1 分割为 3 个非空子串 A、B、C 的方案，满足 A+B+C=S1，且将 A、B、C 重新排列顺序后拼接能得到 S2。
解题思路
字符串 HASH
对 S1 串的两个分割位置 $i$、$j$（$1 \le i < j < n$，$n$ 为 S1 的长度）进行枚举，将 S1 分割为 $[1, i]$、$[i+1, j]$、$[j+1, n]$ 三段。对三段的全部 6 种排列，先判断拼接后的 hash 值是否等于 S2 的 hash 值，这样可以排除绝大多数不可能的情况；若 hash 值相等，再逐字符判断重排后的串与 S2 是否完全相同。
代码
#include<bits/stdc++.h>
using namespace std;
// Array capacity. Strings are stored 1-indexed (index 0 unused), so each
// buffer holds at most N - 2 characters plus the terminating NUL.
const int N = 5000 + 10;
// Base of the polynomial rolling hash; unsigned arithmetic wraps on overflow.
const int seed = 29;
// S1, 1-indexed: cut into pieces A|B|C that are then reordered.
char s[N];
// S2, 1-indexed: the target a reordering of A, B, C must reproduce.
char t[N];
// C[k]: position-weighted hash of the suffix s[k..n]; C[n+1] stays 0.
unsigned int C[N];
// seedVal[k] = seed^k (wrapping); only seedVal[0] is set here, the rest in solve().
unsigned int seedVal[N] = {1};
// Verifies that t[1..n] is exactly s[al..ar] + s[bl..br] + s[cl..cr]
// (1-indexed, inclusive ranges; an empty range has r == l - 1).
// On a match it prints "YES" followed by the three pieces of s in their
// original left-to-right order, one per line, and returns true.
// On a mismatch it prints nothing and returns false.
bool jug(int al, int ar, int bl, int br, int cl, int cr)
{
    const int la = ar - al + 1;
    const int lb = br - bl + 1;
    const int lc = cr - cl + 1;
    // The hash already matched; confirm byte-for-byte to rule out a collision.
    if (std::memcmp(t + 1, s + al, (size_t)la) != 0 ||
        std::memcmp(t + 1 + la, s + bl, (size_t)lb) != 0 ||
        std::memcmp(t + 1 + la + lb, s + cl, (size_t)lc) != 0)
        return false;

    // Sort the pieces by their starting index so they print in S1 order.
    std::pair<int, int> seg[3] = {
        std::make_pair(al, ar), std::make_pair(bl, br), std::make_pair(cl, cr)
    };
    std::sort(seg, seg + 3);
    printf("YES\n");
    for (int k = 0; k < 3; k++)
        printf("%.*s\n", seg[k].second - seg[k].first + 1, s + seg[k].first);
    // The original fell off the end of this bool function (undefined behavior).
    return true;
}
void solve()
{
int n = strlen(s+1);
unsigned int templateHash = 0;
for(int i=1;i<=n;i++)
(templateHash *= seed) += t[i]-'a';
for(int i=1;i<=n;i++)
seedVal[i] = seedVal[i-1] * seed;
for(int i=n;i;i--)
C[i] = seedVal[n-i] * (s[i]-'a') + C[i+1];
// guess the left division
unsigned int A = 0, B, hashval;
for(int i=1;i<=n;i++)
{
(A *= seed) += s[i] - 'a';
B = 0;
for(int j=i+1;j<=n;j++)
{
(B *= seed) += s[j]-'a';
// A|B|C model
if(A * seedVal[n-i] + B * seedVal[n-j] + C[j+1] == templateHash && jug(1, i, i+1, j, j+1, n))
return;
if(A * seedVal[n-i] + C[j+1] * seedVal[j-i] + B == templateHash && jug(1, i, j+1, n, i+1, j))
return;
if(B * seedVal[n-j+i] + A * seedVal[n-j] + C[j+1] == templateHash && jug(i+1, j, 1, i, j+1, n))
return;
if(B * seedVal[n-j+i] + C[j+1] * seedVal[i] + A == templateHash && jug(i+1, j, j+1, n, 1, i))
return;
if(C[j+1] * seedVal[j] + A * seedVal[j-i] + B == templateHash && jug(j+1, n, 1, i, i+1, j))
return;
if(C[j+1] * seedVal[j] + B * seedVal[i] + A == templateHash && jug(j+1, n, i+1, j, 1, i))
return;
}
}
printf("NO\n");
}
// Entry point. Reads S1 into s[1..] and S2 into t[1..], then folds ASCII
// upper-case letters to lower case in both strings. Prints "NO" immediately
// when S1 is too short to cut into three non-empty pieces; otherwise solve()
// decides and prints the answer.
int main()
{
    // Bound each read so it cannot overrun s/t. Index 0 is unused, leaving
    // sizeof - 2 characters plus the terminating NUL. The width is derived
    // from the buffers themselves so it tracks their size automatically.
    char fmt[32];
    snprintf(fmt, sizeof fmt, " %%%ds %%%ds", (int)sizeof(s) - 2, (int)sizeof(t) - 2);
    if (scanf(fmt, s + 1, t + 1) != 2)
    {
        // Missing input: there is no S2 to match, so the answer is NO.
        printf("NO\n");
        return 0;
    }

    // Fold 'A'..'Z' to 'a'..'z' over each string's own length.
    // NOTE(review): presumably the comparison is case-insensitive; confirm
    // against the problem statement.
    for (char* p = s + 1; *p; ++p)
        if (*p >= 'A' && *p <= 'Z') *p = *p - 'A' + 'a';
    for (char* p = t + 1; *p; ++p)
        if (*p >= 'A' && *p <= 'Z') *p = *p - 'A' + 'a';

    if (strlen(s + 1) < 3) printf("NO\n");
    else solve();
    return 0;
}