小 A 是社团里的工具人,有一天他的朋友给了他一个 n 个点,m 条边的正权连通无向图,要他计算所有点两两之间的最短路。
作为一个工具人,小 A 熟练掌握着 floyd 算法,设 w[i][j] 为原图中 (i,j) 之间的权值最小的边的权值,若没有边则 w[i][j]=无穷大。特别地,若 i=j,则 w[i][j]=0。
Floyd 的 C++ 实现如下:
for(int k=1;k<=p;k++)
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
w[i][j]=min(w[i][j],w[i][k]+w[k][j]);
当 p=n 时,该代码就是我们所熟知的 floyd,然而小 A 为了让代码跑的更快点,所以想减少 p 的值。
令 Di,j 为最小的非负整数 x,满足当 p=x 时,点 i 与点 j 之间的最短路被正确计算了。
现在你需要求 ∑i=1n∑j=1nDi,j,虽然答案不会很大,但为了显得本题像个计数题,你还是需要将答案对 998244353 取模后输出。
Input
第一行一个正整数 T(T≤30) 表示数据组数
对于每组数据:
第一行两个正整数 n,m(1≤n≤1000,m≤2000),表示点数和边数。
保证最多只有 5 组数据满足 max(n,m)>200
接下来 mmm 行,每行三个正整数 u,v,w 描述一条边权为 w 的边 (u,v),其中 1≤w≤10^9
Output
输出 TTT 行,第 iii 行一个非负整数表示第 iii 组数据的答案
Sample Input
1 4 4 1 2 1 2 3 1 3 4 1 4 1 1
Sample Output
6
思路:点太多,暴力floyd应该会TLE,所以改成dijkstra,求最短路上最大松弛点的最小值
#include<bits/stdc++.h>
#pragma GCC optimize(3)
using namespace std;
typedef long long ll;
inline ll max(ll a,ll b){
return a>b?a:b;
}
inline ll min(ll a,ll b){
return a<b?a:b;
}
const ll INF=1e16;
const int mod=998244353;
const int N=2e3+5;
int n,m;
struct Edge{
int v,next;
ll w;
}e[N<<1];
int head[N],tot=0;
void init(int n){
tot=0;
for(int i=1;i<=n;i++) head[i]=-1;
}
void add(int u,int v,ll w){
e[tot].v=v;
e[tot].w=w;
e[tot].next=head[u];
head[u]=tot++;
}
struct node{
int u;
ll w;
bool operator<(const node other)const {
if(w!=other.w) return w>other.w;
else return u>other.u;
}
};
bool vis[N];
ll dis[N];
int mp[N][N];
void Dj(int s){
priority_queue<node> q;
for(int i=1;i<=n;i++){
vis[i]=false;
dis[i]=INF;
}
dis[s]=0;
q.push(node{s,dis[s]});
while(!q.empty()){
node p=q.top();
q.pop();
int u=p.u;
if(vis[u]) continue;
vis[u]=true;
for(int i=head[u];~i;i=e[i].next){
int &v=e[i].v;
ll &w=e[i].w;
if(dis[v]>dis[u]+w){
dis[v]=dis[u]+w;
q.push(node{v,dis[v]});
if(dis[v]==w) mp[s][v]=0;
else mp[s][v]=max(mp[s][u],u);
}
else if(dis[v]==dis[u]+w){
q.push(node{v,dis[v]});
mp[s][v]=min(mp[s][v],max(mp[s][u],u));
}
}
}
}
ll cal(ll a){
if(a<mod) return a;
return a-mod;
}
int main(){
int T;
scanf("%d",&T);
while(T--){
scanf("%d%d",&n,&m);
init(n);
for(int i=1;i<=m;i++){
int u,v;
ll w;
scanf("%d%d%I64d",&u,&v,&w);
add(u,v,w);
add(v,u,w);
}
for(int i=1;i<=n;i++){
for(int j=1;j<=n;j++){
mp[i][j]=0;
}
}
for(int i=1;i<=n;i++) Dj(i);
ll ans=0;
for(int i=1;i<=n;i++){
for(int j=1;j<=n;j++){
ans=cal(ans+mp[i][j]);
}
}
printf("%I64d\n",ans);
}
return 0;
}