题意:给一张 n n n 点 m m m 边的简单无向图,求有多少个三元组 ( s , c , f ) (s,c,f) (s,c,f) ,满足存在一条从 s s s 到 f f f 经过 c c c 的简单路径。
n ≤ 1 0 5 , m ≤ 2 × 1 0 5 n\leq 10^5,m\leq 2\times 10^5 n≤105,m≤2×105
首先这个 “经过 c c c 的简单路径” ,即 c c c 取所有 s s s 到 f f f 的简单路径的交集,就是能到达的所有点双的并集,是圆方树的标志。具体讲解可以参考 PR的博客。
问题转换成了:求所有点对路径上的点双的并集的大小 − 2 -2 −2 (起始点) 之和。
建出圆方树,方点权值为其度数 (即点双的大小),圆点权值为 − 1 -1 −1 (点双边界的割点处被统计了两次,需要减掉;起始点本来就要减掉)
这样统计所有圆点路径上的权值之和,就可以做到 O ( n 2 ) O(n^2) O(n2)
考虑每个结点的贡献,计算子树大小(方点不算大小)瞎算一下就可以 O ( n ) O(n) O(n)
因为我的板子比较奇怪,需要把边去重,是 O ( n log n ) O(n\log n) O(nlogn) 的
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#include <vector>
#include <algorithm>
#define MAXN 100005
#define MAXM 400005
using namespace std;
inline int read()
{
int ans=0;
char c=getchar();
while (!isdigit(c)) c=getchar();
while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar();
return ans;
}
struct edge{int u,v;}e[MAXM];
int head[MAXN],nxt[MAXM],cnt=1;
inline void addnode(int u,int v)
{
e[++cnt]=(edge){u,v};
nxt[cnt]=head[u];
head[u]=cnt;
}
int n,m;
int dfn[MAXN],low[MAXN],tim;
int stk[MAXM],tp,vis[MAXM],bcc[MAXM],vcnt;
vector<int> rtt[MAXM];
void tarjan(int u)
{
dfn[u]=low[u]=++tim;
for (int i=head[u];i;i=nxt[i])
{
if (!vis[i>>1]&&!bcc[i>>1]) vis[(stk[++tp]=i)>>1]=1;
if (!dfn[e[i].v])
{
tarjan(e[i].v);
low[u]=min(low[u],low[e[i].v]);
if (dfn[u]==low[e[i].v])
{
rtt[u].push_back(++vcnt);
rtt[vcnt].push_back(u);
while (vis[i>>1])
{
int t=stk[tp--];
vis[t>>1]=0;
rtt[bcc[t>>1]=vcnt].push_back(e[t].v);
// rtt[e[t].v].push_back(vcnt);
}
}
}
else low[u]=min(low[u],dfn[e[i].v]);
}
}
int val[MAXM],siz[MAXM];
typedef long long ll;
ll ans;
void dfs(int u,int f,int tot)
{
siz[u]=(u<=n);
vis[u]=1;
for (int i=0;i<(int)rtt[u].size();i++)
{
int v=rtt[u][i];
if (v!=f)
{
dfs(v,u,tot);
ans+=(ll)siz[u]*siz[v]*val[u];
siz[u]+=siz[v];
}
}
ans+=(ll)siz[u]*(tot-siz[u])*val[u];
}
int main()
{
n=read(),m=read();
for (int i=1;i<=m;i++)
{
int u,v;
u=read(),v=read();
addnode(u,v),addnode(v,u);
}
int las=0;
vcnt=n;
for (int i=1;i<=n;i++) if (!dfn[i]) tarjan(i),siz[i]=tim-las,las=tim;
for (int i=1;i<=vcnt;i++)
{
sort(rtt[i].begin(),rtt[i].end());
rtt[i].erase(unique(rtt[i].begin(),rtt[i].end()),rtt[i].end());
}
for (int i=1;i<=n;i++) val[i]=-1;
for (int i=n+1;i<=vcnt;i++) val[i]=(int)rtt[i].size();
for (int i=1;i<=n;i++) if (!vis[i]) dfs(i,0,siz[i]);
printf("%lld\n",2*ans);
return 0;
}