题意
给定一个长度为n的字符串,字符集大小为8。两个点i,j之间有权值为1的边需要满足以下条件至少一个
1. |i-j|=1
2. si=sj
求图的直径和多少个有序点对之间的最短路长度等于直径
N≤100000
分析
首先想到对于每个小写字母建一个虚点,然后再跑
有一个显然的性质,直径不超过15
然后有一个,
f[i][j]=min(|i−j|,dist[i][c]+dist[j][c]+1)
f
[
i
]
[
j
]
=
m
i
n
(
|
i
−
j
|
,
d
i
s
t
[
i
]
[
c
]
+
d
i
s
t
[
j
]
[
c
]
+
1
)
f[i][j] 表示i到j的最短距离
dist[i][c] 表示i到字母c的最短距离
我们可以预处理出dist[i][c],考虑怎么快速f[i][j]
再发现一点性质
g[s[i]][c]≤dist[i][c]≤g[s[i]][c]+1
g
[
s
[
i
]
]
[
c
]
≤
d
i
s
t
[
i
]
[
c
]
≤
g
[
s
[
i
]
]
[
c
]
+
1
g[i][j] 表示两个小写字母的最短距离
上式很容易得到,于是我们发现
对于一个i,你找前面的j,j的状态都可以根据每个颜色c,dist[j][c]来确定
对于一个颜色,dist[j][c]取0或者1,然后就可以有2^8个状态
这样就可以把前面的j给缩成256个状态
15个以内的暴力算
总的时间复杂度
O(28∗82∗n)
O
(
2
8
∗
8
2
∗
n
)
玄学跑过
代码
#include <bits/stdc++.h>
#define ll long long
#define pii pair<ll,ll>
#define bin(i) (1ll<<(i))
#define pb push_back
#define MP make_pair
using namespace std;
const ll N = 100010;
inline ll read()
{
ll p=0; ll f=1; char ch=getchar();
while(ch<'0' || ch>'9'){if(ch=='-') f=-1; ch=getchar();}
while(ch>='0' && ch<='9'){p=p*10+ch-'0'; ch=getchar();}
return p*f;
}
char s[N]; ll n; ll f[N][9]; queue<ll>q; bool vis[9];
void bfs(ll c)
{
memset(vis,0,sizeof(vis));
for(ll i=1;i<=n;i++) if(s[i] == c+'a') q.push(i),f[i][c] = 0; vis[c] = 1;
while(!q.empty())
{
ll x = q.front();
if(!vis[s[x]-'a'])
{
for(ll i=1;i<=n;i++) if(s[x] == s[i] && f[i][c] > n) f[i][c] = min(f[i][c] , f[x][c]+1),q.push(i);
vis[s[x]-'a']=1;
}
if(x-1>=1 && f[x-1][c] > f[x][c]){f[x-1][c] = min(f[x-1][c] , f[x][c]+1); q.push(x-1);}
if(x+1<=n && f[x+1][c] > f[x][c]){f[x+1][c] = min(f[x+1][c] , f[x][c]+1); q.push(x+1);}
q.pop();
}
// for(int i=1;i<=n;i++) printf("%lld ",f[i][c]); printf("\n");
}
ll dis[9][9],mask[N],cnt[9][260];
ll trans[9][260];
int main()
{
n = read(); scanf("%s",s+1);
memset(f,63,sizeof(f)); for(ll i=0;i<8;i++) bfs(i);
memset(dis,63,sizeof(dis));
for(ll i=1;i<=n;i++) for(ll j=0;j<8;j++) dis[s[i]-'a'][j] = min(dis[s[i]-'a'][j],f[i][j]);
for(ll i=1;i<=n;i++) for(ll j=0;j<8;j++) if(dis[s[i]-'a'][j] != f[i][j]) mask[i] |= bin(j);
ll ans = 0; ll maxx = 0;
memset(cnt,0,sizeof(cnt));
for(ll i=1;i<=n;i++)
{
ll t = i-15;
for(ll j=max(t,1ll);j<i;j++)
{
ll ss = abs(i-j);
for(ll k=0;k<8;k++) ss = min(ss , f[i][k] + f[j][k] + 1ll);
//printf("%lld %lld %lld\n",i,j,ss);
if(ss > maxx){maxx=ss; ans=1;}
else if(ss == maxx) ans++;
}
if(t>1) cnt[s[t-1]-'a'][mask[t-1]] ++;
for(ll j=0;j<8;j++)
for(ll k=0;k<256;k++) if(cnt[j][k])
{
ll ss = LLONG_MAX;
for(ll l=0;l<8;l++) ss=min(ss,f[i][l] + dis[j][l] + ((k&bin(l))>>l) + 1ll);
//printf("%lld %lld %lld\n",i,j,ss);
if(ss > maxx){maxx=ss; ans=cnt[j][k];}
else if(ss == maxx) ans+=cnt[j][k];
}
}
return printf("%lld %lld\n",maxx,ans),0;
}