建出最短路图
—(以下复制自官方题解)
定义
F(X)=
从
S
到
F(A)+F(B)=F(T)
A
和
对于条件
1
,我们可以使用数据结构进行优化(使用std::map即可),而对于条件
时间复杂度: O(nlogn+nmw) ,其中 w <script id="MathJax-Element-484" type="math/tex">w</script> 是位压的字长。
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <queue>
#include <bitset>
#include <map>
#include <vector>
using namespace std;
typedef long long ll;
const int N=50010;
int n,m,cnt,G[N];
struct edge{
int t,nx,w;
}E[N<<1];
int S,T;
inline void addedge(int x,int y,int z){
E[++cnt].t=y; E[cnt].nx=G[x]; E[cnt].w=z; G[x]=cnt;
E[++cnt].t=x; E[cnt].nx=G[y]; E[cnt].w=z; G[y]=cnt;
}
queue<int> Q;
int vis[N],iT[N]; ll dis[N],rdis[N];
inline void spfa(ll *dis,int S){
for(int i=1;i<=n;i++)
dis[i]=1LL<<60,vis[i]=0;
Q.push(S); vis[S]=1; dis[S]=0;
while(!Q.empty()){
int x=Q.front(); Q.pop(); vis[x]=0;
for(int i=G[x];i;i=E[i].nx)
if(dis[E[i].t]>dis[x]+E[i].w){
dis[E[i].t]=dis[x]+E[i].w;
if(!vis[E[i].t])
vis[E[i].t]=1,Q.push(E[i].t);
}
}
}
namespace Grp{
int G[N],rG[N],du[N],rdu[N],cnt;
struct edge{
int s,t,nx;
}E[N<<2];
inline void addedge(int x,int y){
E[++cnt].t=y; E[cnt].s=x; E[cnt].nx=G[x]; G[x]=cnt; du[y]++;
E[++cnt].t=x; E[cnt].s=y; E[cnt].nx=rG[y]; rG[y]=cnt; rdu[x]++;
}
ll f[N],g[N];
queue<int> Q;
bitset<N> lnk[N],rlnk[N];
map<ll,bitset<N> > M;
inline ll solve(){
f[S]=1; g[T]=1;
Q.push(S);
while(!Q.empty()){
int x=Q.front(); Q.pop();
for(int i=G[x];i;i=E[i].nx){
f[E[i].t]+=f[x];
if(!--du[E[i].t]) Q.push(E[i].t);
}
}
Q.push(T);
while(!Q.empty()){
int x=Q.front(); Q.pop();
for(int i=rG[x];i;i=E[i].nx){
g[E[i].t]+=g[x];
if(!--rdu[E[i].t]) Q.push(E[i].t);
}
}
for(int i=1;i<=n;i++) lnk[i].set(i);
for(int i=1;i<n;i++)
for(int j=1;j<=cnt;j+=2)
lnk[E[j].s]|=lnk[E[j].t];
for(int i=1;i<n;i++)
for(int j=2;j<=cnt;j+=2)
rlnk[E[j].s]|=rlnk[E[j].t];
for(int i=1;i<=n;i++)
M[f[i]*g[i]].set(i);
ll ret=0;
for(int i=1;i<=n;i++){
if(!M.count(f[T]-f[i]*g[i])) continue;
ret+=(M[f[T]-f[i]*g[i]]&~lnk[i]&~rlnk[i]).count();
}
return ret>>1;
}
}
int main(){
freopen("6252.in","r",stdin);
freopen("6252.out","w",stdout);
scanf("%d%d%d%d",&n,&m,&S,&T);
for(int i=1,x,y,z;i<=m;i++)
scanf("%d%d%d",&x,&y,&z),addedge(x,y,z);
spfa(dis,S); spfa(rdis,T);
for(int i=1;i<=n;i++)
if(dis[i]+rdis[i]==dis[T]) iT[i]=1;
for(int x=1;x<=n;x++){
if(!iT[x]) continue;
for(int i=G[x];i;i=E[i].nx)
if(dis[E[i].t]==dis[x]+E[i].w && iT[E[i].t]) Grp::addedge(x,E[i].t);
}
printf("%lld\n",Grp::solve());
return 0;
}