最短路+bitset+DP
个人觉得这题的思路非常高妙。
首先肯定是要建出最短路DAG,这个图上任意一条路径都对应一条原图的最短路。
如果一个点a在S到T的必经之路上,那就会有
S到a的方案数 * a到T的方案数 = S到T的方案数
这个东西显然是充要的,这是一个巧妙的转化。套用这个想法,这题要求选出两个点,那就只需
S到a的方案数 * a到T的方案数 + S到b的方案数 * b到T的方案数 = S到T的方案数
但注意它只是充分的。满足上式不一定就说明a和b合法,因为a可能能到达b,这就导致有的路径算重了,有的路径没算过。判掉即可。
这题卡内存,开1G才科学,不过随便写点玄学的东西也能卡过去。
#include<map>
#include<queue>
#include<cstdio>
#include<bitset>
#include<algorithm>
#define MOD 892857142857143
#define N 50005
using namespace std;
namespace runzhe2000
{
typedef long long ll;
map<ll,int> hash;
bitset<N> d[N];
vector<int> v[N];
const ll INF = 1ll << 60;
int n, m, S, T, last[N], ecnt, vis[N], r[N], timer;
ll dis[N], fS[N], fT[N];
bool cmp_dis(int a, int b){return dis[a] < dis[b];}
struct edge{int next, to, val;}e[N<<1];
void addedge(int a, int b, int c)
{
e[++ecnt] = (edge){last[a], b, c};
last[a] = ecnt;
}
struct item
{
ll dis; int id;
bool operator < (const item &that) const {return dis > that.dis;}
};
void dijk()
{
priority_queue<item> q; for(int i = 1; i <= n; i++) dis[i] = INF;
q.push((item){dis[S] = 0, S});
for(; !q.empty(); )
{
int x = q.top().id; q.pop();
if(vis[x]) continue; vis[x] = 1;
for(int i = last[x]; i; i = e[i].next)
{
int y = e[i].to;
if(dis[x] + e[i].val < dis[y])
q.push((item){dis[y] = dis[x] + e[i].val, y});
}
}
}
void dp1(ll *f)
{
f[S] = 1;
for(int p = 1; p <= n; p++)
{
int x = r[p];
for(int i = last[x]; i; i = e[i].next)
{
int y = e[i].to; if(dis[x] + e[i].val != dis[y]) continue;
(f[y] += f[x]) %= MOD;
}
}
}
void dp2(ll *f)
{
f[T] = 1;
for(int p = n; p >= 1; p--)
{
int x = r[p];
for(int i = last[x]; i; i = e[i].next)
{
int y = e[i].to; if(dis[x] + e[i].val != dis[y]) continue;
(f[x] += f[y]) %= MOD;
}
}
}
void dp3()
{
for(int p = n; p >= 1; p--)
{
int x = r[p]; d[x].set(x);
for(int i = last[x]; i; i = e[i].next)
{
int y = e[i].to; if(dis[x] + e[i].val != dis[y]) continue;
if(!fT[y]) continue;
d[x] |= d[y];
}
}
}
ll mul(ll a, ll b)
{
ll r = 0;
for(; b; b>>=1)
{
if(b&1)(r+=a)%=MOD;
(a+=a)%=MOD;
}
return r;
}
void main()
{
scanf("%d%d%d%d",&n,&m,&S,&T);
for(int i = 1; i <= m; i++)
{
int a, b, c; scanf("%d%d%d",&a,&b,&c);
addedge(a, b, c);
addedge(b, a, c);
}
dijk();
if(dis[T] == INF)
{
printf("%lld\n",1ll*n*(n-1)/2);
return;
}
for(int i = 1; i <= n; i++) r[i] = i; sort(r+1, r+1+n, cmp_dis);
dp1(fS);
dp2(fT);
dp3();
ll ans = 0;
for(int i = 1; i <= n; i++)
{
ll f = mul(fS[i], fT[i]), k = (fS[T] - f + MOD) % MOD;
int tmp = hash[k];
if(tmp)
{
for(int j = v[tmp].size()-1; ~j; j--)
{
int x = v[tmp][j];
if(d[x][i] || d[i][x]) continue;
ans++;
}
}
if(!(tmp = hash[f])) tmp = hash[f] = ++timer;
v[tmp].push_back(i);
}
printf("%lld\n",ans);
}
}
int main()
{
runzhe2000::main();
}