题目大意
给定一个长度为n的字符串,字符集大小为8。两个点i,j之间有权值为1的边当且仅当满足以下条件之一
1. |i-j|=1
2. si=sj
求图的直径和多少个有序点对之间的最短路长度等于直径
2≤n≤100000
分析
好题
设
dist(i,j)
表示点i到点j的最短距离,
Dist(i,c)
表示点i到字符c的最短距离。
易得
dist(i,j)=min(|i−j|,Dist(i,c)+Dist(j,c)+1)
由于字符集大小为8,观察上式,容易证明最短路长度不超过15。于是对于
|i−j|>15
的点对,我们发现
dist(i,j)=min(Dist(i,c)+Dist(j,c)+1)
首先我们可以枚举c,然后通过bfs得到数组
Dist[][c]
。
接下来设
D(p,q)
表示字符p到字符q的最短距离,易得对于任意i,有
D(si,c)≤Dist(i,c)≤D(si,c)+1
那么再设
Mask(i)
,这是个二进制状态,第c位表示
Dist(i,c)−D(si,c)
那么
Dist(i,c)+Dist(j,c)+1=Dist(i,c)+D(sj,c)+Mask(j)c+1
我们在计算
|i−j|>15
的答案时,枚举i,同时维护一个数组
f[c][mask]
,表示字符是c,Mask是mask的位置j有多少个。然后可以枚举c和mask并计算距离。
|i−j|≤15
的点对直接枚举即可
时间复杂度 O(n∗82∗28)
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#define min(a,b) ((a)<(b)?(a):(b))
#define max(a,b) ((a)>(b)?(a):(b))
using namespace std;
const int N=1e5+30,M=256,C=8,mx=15,INF=1e8;
typedef long long LL;
int n,Dist[N][C],D[C][C],Mask[N],f[C][M],ans,Data[N],tot;
LL sum;
char S[N];
void bfs(int c)
{
Data[tot=1]=n+c; Dist[n+c][c]=-1;
for (int i=1,j,x;i<=tot;i++)
{
x=Data[i];
if (x<n)
{
if (x>0 && Dist[x-1][c]>INF)
{
Dist[x-1][c]=Dist[x][c]+1; Data[++tot]=x-1;
if (Dist[S[x-1]+n][c]>INF)
{
Dist[S[x-1]+n][c]=Dist[x][c]+1; Data[++tot]=S[x-1]+n;
}
}
if (x<n-1 && Dist[x+1][c]>INF)
{
Dist[x+1][c]=Dist[x][c]+1; Data[++tot]=x+1;
if (Dist[S[x+1]+n][c]>INF)
{
Dist[S[x+1]+n][c]=Dist[x][c]+1; Data[++tot]=S[x+1]+n;
}
}
}else
{
for (j=0;j<n;j++) if (S[j]+n==x && Dist[j][c]>INF)
{
Dist[j][c]=Dist[x][c]+1; Data[++tot]=j;
}
}
}
}
int main()
{
scanf("%d%s",&n,S);
for (int i=0;i<n;i++) S[i]-='a';
memset(Dist,42,sizeof(Dist));
memset(D,42,sizeof(D));
for (int i=0;i<C;i++) bfs(i);
for (int i=0;i<n;i++) for (int j=0;j<C;j++) D[S[i]][j]=min(D[S[i]][j],Dist[i][j]);
for (int i=0;i<n;i++) for (int j=C-1;j>=0;j--) Mask[i]=(Mask[i]<<1)|Dist[i][j]-D[S[i]][j];
for (int i=mx;i<n;i++) f[S[i]][Mask[i]]++;
for (int i=0,dist;i<n;i++)
{
if (i+mx<n) f[S[i+mx]][Mask[i+mx]]--;
for (int c=0;c<C;c++)
{
for (int mask=M-1;mask>=0;mask--) if (f[c][mask]>0)
{
dist=mx;
for (int c_=0;c_<C;c_++) dist=min(dist,Dist[i][c_]+D[c][c_]+((mask&(1<<c_))>0)+1);
if (dist>ans) ans=dist,sum=0;
if (dist==ans) sum+=f[c][mask];
}
}
if (i>=mx) f[S[i-mx]][Mask[i-mx]]++;
for (int j=max(0,i-mx);j<=min(n-1,i+mx);j++)
{
dist=abs(i-j);
for (int c_=0;c_<C;c_++) dist=min(dist,Dist[i][c_]+Dist[j][c_]+1);
if (dist>ans) ans=dist,sum=0;
if (dist==ans) sum++;
}
}
printf("%d %lld\n",ans,sum>>1);
return 0;
}