经典题目改版,直接求完最短路上最大流就行了。
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <queue>
using namespace std;
typedef long long sint;
sint INF=(1LL<<42);
#define maxn 222222
int getint()
{
char c;int res;
while(c=getchar(),c<'0'||c>'9');
res=c-'0';
while(c=getchar(),c>='0'&&c<='9')
res=res*10+c-'0';
return res;
}
int st,ed,en;
struct node
{
int v,next;
sint c;
}e[maxn];
int dis[maxn],first[maxn];
int n,m;
void add(int a,int b,sint c)
{
en++;
e[en].v=b;
e[en].c=c;
e[en].next=first[a];
first[a]=en;
en++;
e[en].v=a;
e[en].c=0;
e[en].next=first[b];
first[b]=en;
}
bool bfs()
{
queue<int>q;
memset(dis,-1,sizeof(dis));
q.push(st);
dis[st]=0;
while(!q.empty())
{
int u=q.front();
q.pop();
for(int i=first[u];i!=-1;i=e[i].next)
{
int v=e[i].v;
if(dis[v]!=-1||e[i].c==0) continue;
dis[v]=dis[u]+1;
if(v==ed)
{
return true;
}
q.push(v);
}
}
return false;
}
sint dfs(int now,sint mx)
{
if(now==ed||mx==0) return mx;
sint flow=0,tmp;
for(int i=first[now];i!=-1;i=e[i].next)
{
int v=e[i].v;
if(dis[v]!=dis[now]+1||e[i].c==0) continue;
tmp=dfs(v,min(mx,e[i].c));
if(tmp)
{
mx-=tmp;
e[i].c-=tmp;
e[i^1].c+=tmp;
flow+=tmp;
if(!mx) break;
}
}
if(!flow) dis[now]=-1;
return flow;
}
sint dinic()
{
sint maxflow=0,tmp;
while(bfs())
{
while(tmp=dfs(st,INF)) maxflow+=tmp;
}
return maxflow;
}
void init()
{
en=-1;
memset(first,-1,sizeof(first));
st=1;
ed=n*2;
}
sint val2[maxn],dis2[maxn];
int cnt,first2[maxn],to2[maxn],next2[maxn];
bool vis[maxn];
void build(int a,int b,sint c)
{
cnt++;
to2[cnt]=b;val2[cnt]=c;
next2[cnt]=first2[a];
first2[a]=cnt;
}
void spfa()
{
queue<int>q;
for(int i=0;i<=n;i++) dis2[i]=INF;
q.push(1);dis2[1]=0;
while(!q.empty())
{
int u=q.front();
q.pop();
for(int i=first2[u];i;i=next2[i])
{
int v=to2[i];
if(dis2[v]>dis2[u]+val2[i])
{
dis2[v]=dis2[u]+val2[i];
if(!vis[v])
{
vis[v]=1;
q.push(v);
}
}
}
vis[u]=0;
}
}
void getg()
{
for(int i=1;i<=n;i++)
{
for(int j=first2[i];j;j=next2[j])
{
int v=to2[j];
if(dis2[i]+val2[j]==dis2[v])
{
add(i+n,v,INF);
}
}
}
}
int main()
{
scanf("%d%d",&n,&m);
init();
int a,b,c;
for(int i=1;i<=m;i++)
{
a=getint();
b=getint();
c=getint();
build(a,b,c);
build(b,a,c);
}
spfa();
for(int i=1;i<=n;i++)
{
a=getint();
if(i!=1&&i!=n)
{
add(i,i+n,a);
}
}
add(1,1+n,INF);
add(n,n*2,INF);
getg();
printf("%lld",dinic());
return 0;
}