J-xay loves Floyd
-
如果 d i , j = w i , j \text d_{i,j}=\text w_{i,j} di,j=wi,j,那么按照题意中的算法仍然能得到正确的结果。此时记 can i , j = 1 \text{can}_{i,j}=1 cani,j=1。
-
如果存在 v v v,使得① can i , v = 1 \text{can}_{i,v}=1 cani,v=1② can v , j = 1 \text{can}_{v,j}=1 canv,j=1③ v v v在 i i i到 j j j的任意一条最短路上,那么 can i , j = 1 \text{can}_{i,j}=1 cani,j=1。
直接这么算
c
a
n
[
i
]
[
j
]
can[i][j]
can[i][j]复杂度太高,我们注意到
c
a
n
[
i
]
[
∗
]
can[i][*]
can[i][∗],
c
a
n
[
∗
]
[
j
]
can[*][j]
can[∗][j]的运算本质上是集合求交,可以利用bitset维护。
将
c
a
n
[
i
]
[
∗
]
can[i][*]
can[i][∗]记为bitset<N> fr[i]
,
c
a
n
[
∗
]
[
j
]
can[*][j]
can[∗][j]记为bitset<N>to[j]
同时,枚举 s s s,则 s s s到 j j j的所有最短路经过的点集 pot j \text{pot}_j potj也可以通过bitset维护,具体做法是每次枚举一个 s s s,就重新把顶点按照到 s s s的最短路长度排序,从小到大计算 pot j \text{pot}_j potj。如果 d s , k + w k , j = d s , j \text d_{s,k}+\text w_{k,j}=\text d_{s,j} ds,k+wk,j=ds,j,则 pot j ∣ = pot k \text{pot}_j|=\text{pot}_k potj∣=potk
时间复杂度 O ( n m log m + n 2 w w ) O(nm\log m+\frac{n^2w}{w}) O(nmlogm+wn2w)
Code
#include<bits/stdc++.h>
using namespace std;
using ll=long long;
template <class T=int> T rd()
{
T res=0;T fg=1;
char ch=getchar();
while(!isdigit(ch)) {if(ch=='-') fg=-1;ch=getchar();}
while( isdigit(ch)) res=(res<<1)+(res<<3)+(ch^48),ch=getchar();
return res*fg;
}
const int N=2005,M=5005;
int h[N],e[M],ne[M],w[M],idx;
void add(int a,int b,int c){e[idx]=b,ne[idx]=h[a],w[idx]=c,h[a]=idx++;}
int d[N][N];
int n,m;
bool st[N];
bitset<N> pot[N],fr[N],to[N];
void dij(int s,int d[])
{
memset(st,0,sizeof st);
priority_queue<pair<int,int>,vector<pair<int,int>>,greater<pair<int,int>>> q;
q.push({d[s]=0,s});
while(q.size())
{
int u=q.top().second;q.pop();
if(st[u]) continue;
st[u]=1;
for(int i=h[u];i!=-1;i=ne[i])
{
int v=e[i];
if(d[v]>d[u]+w[i])
{
d[v]=d[u]+w[i];
q.push({d[v],v});
}
}
}
for(int i=h[s];i!=-1;i=ne[i])
{
int v=e[i];
if(w[i]==d[v]) fr[s][v]=to[v][s]=1;
}
}
int solve(int s)
{
static int id[N];
for(int i=1;i<=n;i++)
{
pot[i].reset();
pot[i].set(i);
}
for(int i=1;i<=n;i++) id[i]=i;
sort(id+1,id+1+n,[&](const int i,const int j){return d[s][i]<d[s][j];});
for(int i=1;i<=n;i++)
{
int u=id[i];
for(int i=h[u];i!=-1;i=ne[i])
{
int v=e[i];
if(d[s][u]+w[i]==d[s][v]) pot[v]|=pot[u];
}
}
for(int i=1;i<=n;i++)
if(d[s][i]==0x3f3f3f3f||(pot[i]&fr[s]&to[i]).count()) fr[s][i]=to[i][s]=1;
return fr[s].count();
}
int main()
{
n=rd(),m=rd();
memset(d,0x3f,sizeof d);
memset(h,0xff,sizeof h);
for(int i=1;i<=n;i++)
{
fr[i].reset();fr[i][i]=1;
to[i].reset();to[i][i]=1;
}
while(m--)
{
int u=rd(),v=rd(),c=rd();
add(u,v,c);
}
for(int i=1;i<=n;i++) dij(i,d[i]);
int ans=0;
for(int i=1;i<=n;i++) ans+=solve(i);
printf("%d\n",ans);
}