题目链接:
给一个n个点,n条边的有向图,求任意不同两点间最短路之和(不连通的算作-1)。(1<=n<=500000)
解题思路:
容易发现这是一个由基环内向树组成的森林,所以其中的强联通分量一定是环,且每棵树只有一个,那么我们可以先用tarjan求出环单独计算环的贡献,再缩点后处理树上的边。
讲真,这题的取模太恶心了……考试时硬生生WA了一半,以后要引以为戒。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<algorithm>
#include<cmath>
#include<vector>
#include<queue>
#define ll long long
using namespace std;
int getint()
{
int i=0,f=1;char c;
for(c=getchar();(c<'0'||c>'9')&&c!='-';c=getchar());
if(c=='-')f=-1,c=getchar();
for(;c>='0'&&c<='9';c=getchar())i=(i<<3)+(i<<1)+c-'0';
return i*f;
}
const int N=500005,p=1e9+7;
int n;
int go[N],len[N];
int tot,first[N],nxt[N],to[N],w[N];
int idx,top,dfn[N],low[N],stk[N];
int num,id[N],size[N];
ll sum[N],val[N],ans;
vector<int>f[N];
bool exist[N];
inline void add(int x,int y,int z)
{
nxt[++tot]=first[x],first[x]=tot,to[tot]=y,w[tot]=z;
}
inline void tarjan(int u)
{
dfn[u]=low[u]=++idx;
stk[++top]=u,exist[u]=true;
int v=go[u];
if(!dfn[v])
{
tarjan(v);
low[u]=min(low[u],low[v]);
}
else
if(exist[v])low[u]=min(low[u],dfn[v]);
if(low[u]==dfn[u])
{
++num;
v=stk[top];
while(v!=u)
{
id[v]=num,size[num]++;
f[num].push_back(v);
top--,exist[v]=false,v=stk[top];
}
id[v]=num,size[num]++;
f[num].push_back(v);
top--,exist[v]=false;
}
}
inline void calc(int i)
{
reverse(f[i].begin(),f[i].end());
for(int j=0;j<size[i];j++)
val[f[i][0]]=(val[f[i][0]]+1ll*(size[i]-j)*len[f[i][j]]%p)%p;
for(int j=1;j<size[i];j++)
val[f[i][j]]=(val[f[i][j-1]]+sum[i]-1ll*size[i]%p*len[f[i][j-1]]%p)%p;
for(int j=0;j<size[i];j++)
val[f[i][j]]=(val[f[i][j]]-sum[i])%p;
}
inline void dfs(int u,int cnt)
{
exist[u]=true;
ans=((ans-1ll*(n-cnt)*size[u])%p+p)%p;
for(int e=first[u];e;e=nxt[e])
{
int v=to[e];
if(!exist[v])dfs(v,cnt+size[v]);
ans=(ans+1ll*size[v]*cnt%p*w[e]%p)%p;
size[u]+=size[v];
}
}
int main()
{
//freopen("road.in","r",stdin);
//freopen("road.out","w",stdout);
int i,y,z;
n=getint();
for(i=1;i<=n;i++)
go[i]=getint(),len[i]=getint();
for(i=1;i<=n;i++)
if(!dfn[i])tarjan(i);
for(i=1;i<=n;i++)
if(id[go[i]]==id[i])sum[id[i]]=(sum[id[i]]+len[i])%p;
else add(id[go[i]],id[i],len[i]);
for(i=1;i<=num;i++)
{
if(size[i]==1)continue;
ans=(ans+1ll*size[i]*(size[i]-1)/2%p*sum[i]%p)%p;
calc(i);
}
memset(exist,0,sizeof(exist));
for(i=1;i<=num;i++)
if(!exist[i])dfs(i,size[i]);
for(i=1;i<=n;i++)
if(id[go[i]]!=id[i])ans=(ans+1ll*size[id[i]]*val[go[i]]%p)%p;
cout<<(ans%p+p)%p;
return 0;
}