题目描述
题解
先贴上jiry的题解,讲的还是不错的
http://jiry-2.blog.uoj.ac/blog/2242
然后说一下我觉得比较重要的地方
首先这道题怎么来考虑呢?
考虑将最短路为0和最短路非0来求
将两两最短路为0的点都缩成一个点,那么形成的这个新图就是一个有很多重边的图,并且点之间的最短路已知
要计算方案数,就是要分别计算缩成的点里有多少方案还有每条边都多少方案的乘积
边上有多少方案的做法正确性很显然
如何求缩成的点里有多少方案?
我们要再构造一个图,假设称为原图的辅助图
我们把原图最短路为0看成是辅助图上的一条边,非0看成是辅助图上没有边(所有边不计长度)
这样形成了许多个连通块
可以发现我们要保证每一个连通块必须连通,需要做的就是在每一个连通块里至少保留一棵生成树
据说是一个非常经典的容斥?
g的求法正确性很显然吧?
f的求法我也是理解了很久才明白…实际上,我们是在枚举1这个点所在辅助图中所在连通块的大小,然后将所有的点分成两部分,将两部分之间连边,又因为选点编号不同但是是等价的,所以还要乘上组合数
那么为什么要将0和非0分类来求?
可以从floyed中得到启发:最短路不就是选一些其它的边(走其它的路)来替代这条路么?如果d非0的话,可以走其它的路,这条路选另外一个长度;但是如果d=0不能这样做,因为这条路还有可能去组成其它的路
比较简单的理解就是距离为0的两点实际上是等价的
(⊙o⊙)…其实这道题说起来确实挺难理解的,我感觉我也没办法说得更明白了…大家海涵…
代码
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
using namespace std;
#define LL long long
#define Mod 998244353
#define N 405
int n;
LL k,d[N][N],dis[N][N],c[N][N],f[N],g[N],ans;
int fa[N],unit[N],cnt[N][N];
int find(int x)
{
if (x==fa[x]) return x;
fa[x]=find(fa[x]);
return fa[x];
}
LL fast_pow(LL a,int p)
{
LL ans=1LL;
for (;p;p>>=1,a=a*a%Mod)
if (p&1)
ans=ans*a%Mod;
return ans;
}
bool check(int i,int j)
{
for (int l=1;l<=n;++l)
if (l!=i&&l!=j)
if (dis[i][j]==dis[i][l]+dis[l][j]) return true;
return false;
}
int main()
{
ans=1LL;
scanf("%d%lld",&n,&k);
for (int i=1;i<=n;++i)
for (int j=1;j<=n;++j)
scanf("%lld",&d[i][j]);
//check no solution
for (int l=1;l<=n;++l)
{
if (d[l][l]!=0) {puts("0");return 0;}
for (int i=1;i<=n;++i)
{
if (d[l][i]!=d[i][l]) {puts("0");return 0;}
if (d[l][i]>k) {puts("0");return 0;}
for (int j=1;j<=n;++j)
if (d[i][j]>d[i][l]+d[l][j]) {puts("0");return 0;}
}
}
//f,g
for (int i=0;i<=n;++i) c[i][0]=1LL;
for (int i=1;i<=n;++i)
for (int j=1;j<=n;++j)
c[i][j]=(c[i-1][j-1]+c[i-1][j])%Mod;
for (int i=1;i<=n;++i) g[i]=fast_pow(k+1LL,(int)c[i][2])%Mod;
for (int i=1;i<=n;++i)
{
f[i]=g[i];
for (int j=1;j<=i-1;++j)
f[i]=(f[i]-f[j]*g[i-j]%Mod*c[i-1][j-1]%Mod*fast_pow(k,j*(i-j))%Mod)%Mod;
f[i]=(f[i]%Mod+Mod)%Mod;
}
//unit
for (int i=1;i<=n;++i) fa[i]=i;
for (int i=1;i<=n;++i)
for (int j=i+1;j<=n;++j)
if (!d[i][j])
fa[find(i)]=find(j);
for (int i=1;i<=n;++i) ++unit[find(i)];
for (int i=1;i<=n;++i)
if (unit[i]) ans=ans*f[unit[i]]%Mod;
//ununit
memset(dis,127,sizeof(dis));
for (int i=1;i<=n;++i)
for (int j=i+1;j<=n;++j)
{
int fi=find(i),fj=find(j);
if (fi>fj) swap(fi,fj);
++cnt[fi][fj];
dis[fi][fj]=dis[fj][fi]=min(dis[fi][fj],d[i][j]);
}
for (int i=1;i<=n;++i)
for (int j=i+1;j<=n;++j)
if (cnt[i][j])
{
if (check(i,j)) ans=ans*fast_pow(k-dis[i][j]+1,cnt[i][j])%Mod;
else ans=ans*(fast_pow(k-dis[i][j]+1,cnt[i][j])-fast_pow(k-dis[i][j],cnt[i][j]))%Mod;
}
ans=(ans%Mod+Mod)%Mod;
printf("%lld\n",ans);
}