4566: [Haoi2016]找相同字符
Time Limit: 20 Sec Memory Limit: 256 MBDescription
给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两
个子串中有一个位置不同。
Input
两行,两个字符串s1,s2,长度分别为n1,n2。1 <=n1, n2<= 200000,字符串中只有小写字母
Output
输出一个整数表示答案
Sample Input
aabb
bbaa
bbaa
Sample Output
10
首先最简单的方法是把两个字符串接在一起,中间插入间隔符,求出height数组;
答案是 所有属于A串的后缀 和 属于B串的后缀 的 LCP 求和
如果用 height 数组 + st表预处理 是 n^2的,这是不可接受的
所以我们可以尝试用单调栈来解决。
分别处理 A串 在 B串前 和 A串在B串后的情况。
处理到 排名为 i的串的时候,如果栈中的height > height[i], 就弹出
这样在栈顶和 当前串 之间 的串的height 必然 是 比 当前串 大的
所以这些串的贡献就等于 当前串的 height
而对于比栈顶串更靠前的串,因为栈顶串的height 是比当前串的height 小的,所以他们的贡献是在栈顶串之前就决定的,可以直接加上
用sum[i]求出区间内的有贡献串的数量就可以了。
1 #include <iostream> 2 #include <cstdio> 3 #include <cstring> 4 #include <algorithm> 5 #define LL long long 6 7 using namespace std; 8 9 const int MAXN = 4e5 + 10; 10 int n1, n2, n; 11 int m; 12 int sum[MAXN]; 13 char s[MAXN * 3]; 14 char s1[MAXN], s2[MAXN]; 15 int h[MAXN]; 16 LL ans = 0; 17 18 int SA[MAXN], ra[MAXN], cur[MAXN], tp[MAXN], c[MAXN]; 19 20 struct s { 21 int id; 22 int sum; 23 } sta[MAXN]; 24 25 inline LL read() 26 { 27 LL x = 0, w = 1; char ch = 0; 28 while(ch < '0' || ch > '9') { 29 if(ch == '-') { 30 w = -1; 31 } 32 ch = getchar(); 33 } 34 while(ch >= '0' && ch <= '9') { 35 x = x * 10 + ch - '0'; 36 ch = getchar(); 37 } 38 return x * w; 39 } 40 41 void solve(int x) 42 { 43 for(int i = 1; i <= x; i++) { 44 c[i] = 0; 45 } 46 for(int i = 1; i <= n; i++) { 47 c[ra[tp[i]]]++; 48 } 49 for(int i = 1; i <= x; i++) { 50 c[i] += c[i - 1]; 51 } 52 for(int i = n; i >= 1; i--) { 53 SA[c[ra[tp[i]]]--] = tp[i]; 54 } 55 } 56 57 void copy() 58 { 59 for(int i = 1; i <= n; i++) { 60 cur[i] = ra[i]; 61 } 62 } 63 64 void suffix() 65 { 66 for(int i = 1; i <= n; i++) { 67 ra[i] = char(s[i]); 68 tp[i] = i; 69 } 70 solve(m = 128); 71 for(int w = 1, p = 0; p < n; m = p, w += w) { 72 p = 0; 73 for(int j = n - w + 1; j <= n; j++) { 74 tp[++p] = j; 75 } 76 for(int i = 1; i <= n; i++) { 77 if(SA[i] > w) { 78 tp[++p] = SA[i] - w; 79 } 80 } 81 solve(m); 82 copy(); 83 ra[SA[1]] = p = 1; 84 for(int i = 2; i <= n; i++) { 85 if(cur[SA[i]] == cur[SA[i - 1]] && cur[SA[i] + w] == cur[SA[i - 1] + w]) { 86 ra[SA[i]] = p; 87 } else { 88 ra[SA[i]] = ++p; 89 } 90 } 91 } 92 int k = 0; 93 for(int i = 1; i <= n; i++) { 94 if(k) { 95 k--; 96 } 97 int j = SA[ra[i] - 1]; 98 while(s[i + k] == s[j + k]) { 99 k++; 100 } 101 h[ra[i]] = k; 102 } 103 } 104 105 void cal() 106 { 107 int top = 1; 108 sum[0] = 0; 109 for(int i = 1; i <= n; i++) { 110 sum[i] = sum[i - 1]; 111 if(SA[i] > n1 + 1) { 112 sum[i]++; 113 } 114 } 115 for(int i = 1; i <= n; i++) { 116 while(top > 1 && h[sta[top - 1].id] > h[i]) { 117 top--; 118 } 119 sta[top].sum = sta[top - 1].sum + (sum[i - 1] - sum[sta[top - 1].id - 1]) * h[i]; 120 sta[top++].id = i; 121 if(SA[i] <= n1) { 122 ans += sta[top - 1].sum; 123 } 124 } 125 sum[0] = 0; 126 top = 1; 127 for(int i = 1; i <= n; i++) { 128 sum[i] = sum[i - 1]; 129 if(SA[i] <= n1) { 130 sum[i]++; 131 } 132 } 133 for(int i = 1; i <= n; i++) { 134 while(top > 1 && h[sta[top - 1].id] > h[i]) { 135 top--; 136 } 137 sta[top].sum = sta[top - 1].sum + (sum[i - 1] - sum[sta[top - 1].id - 1]) * h[i]; 138 sta[top++].id = i; 139 if(SA[i] > n1 + 1) { 140 ans += sta[top - 1].sum; 141 } 142 } 143 } 144 int main() 145 { 146 scanf("%s", s1 + 1); 147 scanf("%s", s2 + 1); 148 n1 = strlen(s1 + 1), n2 = strlen(s2 + 1); 149 for(int i = 1; i <= n1; i++) { 150 s[i] = s1[i]; 151 } 152 s[n1 + 1] = char('z' + 1); 153 for(int i = 1; i <= n2; i++) { 154 s[i + n1 + 1] = s2[i]; 155 } 156 n = strlen(s + 1); 157 suffix(); 158 cal(); 159 printf("%lld\n", ans); 160 return 0; 161 } 162 163 /* 164 165 aabb 166 bbaa 167 168 */