CodeForces 718E Matvey’s Birthday
题目大意
今天与 CF 的连接怎么这么稳定???
给定一个长度为 N N N的字符串 s s s,字符集为小写字母 a a a到 h h h,我们可以按照如下方式构造出一个无向图:
- 若 ∣ i − j ∣ ≤ 1 |i-j|\le 1 ∣i−j∣≤1,则在点 i i i和点 j j j之间连一条长度为 1 1 1的边;
- 若 s i = s j s_i=s_j si=sj,则在点 i i i和点 j j j之间连一条长度为 1 1 1的边。
若 d ( i , j ) d(i,j) d(i,j)为 i , j i,j i,j之间的最短路径,则定义这张图的直径为 max 1 ≤ i , j ≤ N , i ≠ j { d ( i , j ) } \max_{1\le i,j\le N,i\neq j}\{d(i,j)\} max1≤i,j≤N,i=j{d(i,j)}。
求这张图的直径和 d ( i , j ) d(i,j) d(i,j)等于直径的点对数量。
分析
结论: d ( i , j ) d(i,j) d(i,j)是不会超过 15 15 15的。
证明: 我们可以发现,在从 i i i到 j j j的最短路径上,一种字母一定不会出现超过两次。若出现了多于两次的字母,我们可以发现只需经过原来路径上的第一个和最后一个(因为它们直接有边相连)就可以使得路径变短。这样,由于最多只有 8 8 8种字母,则最长的长度为 2 × 8 − 1 = 15 2\times 8-1=15 2×8−1=15。
则我们考虑定义
f
(
i
,
j
)
f(i,j)
f(i,j)为从
i
i
i出发,到达某个颜色是
j
j
j的位置的最短距离,这个可以用 BFS 做出来。这部分 比较简单 就不详细解释了。
考虑如何用 f ( i , j ) f(i,j) f(i,j)来表示 d ( i , j ) d(i,j) d(i,j):
- 只经过第一种类型的边(即 ∣ i − j ∣ ≤ 1 |i-j|\le 1 ∣i−j∣≤1时连上的边),这样这个距离就是 ∣ i − j ∣ |i-j| ∣i−j∣;
- 经过第二种类型的边,我们可以考虑通过某个颜色为 c c c中转点,这样这个距离就是 f ( i , c ) + 1 + f ( j , c ) f(i,c)+1+f(j,c) f(i,c)+1+f(j,c)。
综上:
d ( i , j ) = min ( ∣ i − j ∣ , min 1 ≤ c ≤ 8 { f ( i , c ) + f ( j , c ) + 1 } ) d(i,j)=\min(|i-j|,\min_{1\le c\le 8}\{f(i,c)+f(j,c)+1\}) d(i,j)=min(∣i−j∣,1≤c≤8min{f(i,c)+f(j,c)+1})
这样一来我们就可以在 O ( N 2 ) O(N^2) O(N2)的时间内求出 d ( i , j ) d(i,j) d(i,j)。
但是 N N N有 1 0 5 10^5 105,我们不能够直接暴力。
考虑计算一个新的值 g ( i , j ) g(i,j) g(i,j)表示从某个颜色为 i i i的节点到某个颜色为 j j j的节点的最短距离,这个值可以在做 BFS 时顺带着算了。
对于 ∣ i − j ∣ ≤ 15 |i-j|\le 15 ∣i−j∣≤15的情况,我们直接采用暴力。
但对于 ∣ i − j ∣ > 15 |i-j|>15 ∣i−j∣>15时,我们可以发现, f ( i , c ) f(i,c) f(i,c)不是等于 g ( s i , c ) g(s_i,c) g(si,c)就是等于 g ( s i , c ) + 1 g(s_i,c)+1 g(si,c)+1,又由于 c c c只有最多 8 8 8情况,所以我们可以将这个压成一个集合 S S S,若 S S S的第 c c c位为 0 0 0则表示 f ( i , c ) = g ( s i , c ) f(i,c)=g(s_i,c) f(i,c)=g(si,c),反之就表示 f ( i , c ) = g ( s i , c ) + 1 f(i,c)=g(s_i,c)+1 f(i,c)=g(si,c)+1。
那么我们可以再用 O ( N 2 2 2 N ) O(N^22^{2N}) O(N222N)的复杂度来做一个预处理。设 h ( i , j , S 1 , S 2 ) h(i,j,S_1,S_2) h(i,j,S1,S2)表示一个颜色为 i i i、状态为 S 1 S_1 S1和另外一个颜色为 j j j、状态为 S 2 S_2 S2时的点的合并时的最小结果。
那么我们在统计答案时直接调用我们计算出来的 h h h即可。
总时间复杂度为 O ( 8 N + 8 3 × 2 2 N + 8 N × 2 N ) O(8N+8^3\times 2^{2N}+8N\times 2^N) O(8N+83×22N+8N×2N)。
参考代码
#include <queue>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
const int Maxn = 100000;
const int Maxk = 8;
int N;
char s[Maxn + 5];
int f[Maxn + 5][Maxk + 5];
int g[Maxk + 5][Maxk + 5];
vector<int> p[Maxk + 5];
void BFS(int col) {
static bool vis[Maxn + Maxk + 5];
memset(vis, false, sizeof vis);
queue<int> q;
for(int j = 0; j < (int)p[col].size(); j++)
q.push(p[col][j]), vis[p[col][j]] = true, f[p[col][j]][col] = 0;
g[col][col] = 0, vis[N + col] = true;
while(!q.empty()) {
int u = q.front();
q.pop();
if(u != 1 && !vis[u - 1]) {
vis[u - 1] = true, q.push(u - 1);
f[u - 1][col] = f[u][col] + 1;
}
if(u != N && !vis[u + 1]) {
vis[u + 1] = true, q.push(u + 1);
f[u + 1][col] = f[u][col] + 1;
}
if(vis[N + (int)s[u] - 'a' + 1]) continue;
vis[N + (int)s[u] - 'a' + 1] = true;
g[s[u] - 'a' + 1][col] = f[u][col];
for(int i = 0; i < (int)p[s[u] - 'a' + 1].size(); i++) {
int v = p[s[u] - 'a' + 1][i];
if(vis[v]) continue;
f[v][col] = f[u][col] + 1;
vis[v] = true, q.push(v);
}
}
}
int cnt[Maxk + 5][(1 << Maxk) + 5];
int dis[Maxk + 5][Maxk + 5][(1 << Maxk) + 5][(1 << Maxk) + 5];
int st[Maxn + 5];
int main() {
#ifdef LOACL
freopen("in.txt", "r", stdin);
freopen("out.txt", "w", stdout);
#endif
scanf("%d", &N);
scanf("%s", s + 1);
for(int i = 1; i <= N; i++)
p[s[i] - 'a' + 1].push_back(i);
memset(f, 0x3f, sizeof f);
memset(g, 0x3f, sizeof g);
for(int i = 1; i <= Maxk; i++)
BFS(i);
memset(dis, 0x3f, sizeof dis);
for(int col1 = 1; col1 <= 8; col1++)
for(int col2 = 1; col2 <= col1; col2++)
for(int s1 = 0; s1 < (1 << 8); s1++)
for(int s2 = 0; s2 < (1 << 8); s2++) {
for(int col3 = 1; col3 <= 8; col3++)
dis[col1][col2][s1][s2] = min(dis[col1][col2][s1][s2],
g[col1][col3] + g[col2][col3] + ((s1 >> (col3 - 1)) & 1)
+ ((s2 >> (col3 - 1)) & 1) + 1);
dis[col2][col1][s2][s1] = dis[col1][col2][s1][s2];
}
int ans = 0;
ll sum = 0;
for(int i = 1; i <= N; i++) {
for(int j = max(1, i - 15); j < i; j++) {
int tmp = i - j;
for(int col = 1; col <= 8; col++)
tmp = min(tmp, f[i][col] + f[j][col] + 1);
if(ans < tmp) ans = tmp, sum = 0;
if(ans == tmp) sum++;
}
for(int col = 1; col <= 8; col++)
st[i] |= (f[i][col] - g[s[i] - 'a' + 1][col]) << (col - 1);
for(int col = 1; col <= 8; col++)
for(int j = 0; j < (1 << 8); j++)
if(cnt[col][j]) {
int tmp = dis[col][s[i] - 'a' + 1][j][st[i]];
if(ans < tmp) ans = tmp, sum = 0;
if(ans == tmp) sum += cnt[col][j];
}
if(i > 15) cnt[s[i - 15] - 'a' + 1][st[i - 15]]++;
}
printf("%d %lld\n", ans, sum);
return 0;
}