前几天发的博客,后来想想当时方法有误,现已更正。
题意:给你一个n*n的矩阵,从左上到右下走k次,每走一次可以得到路径上的值,并且该值清0(就是只能得到1次),问k次后最大能多少。
思路:最大费用最大流,对于网络流问题,建图可以说是最关键的步骤。我大体画一下我的建图吧。(图中同一点拆点后少边,没画)
首先矩阵的点我是拆成俩点的,中间那路径上好记录它的值及这里的费用,然后其他如题意。
当然这题取过后要清0,我是这样处理的,如图1到10建两条边,一条容量为1,cost为该点的值,一条容量为k(源点或汇点控制好容量,这里其实大于等于k-1就可以),cost为0,当然要分别建反边。
吐槽:再也不相信题面了,数据范围应该远大于题面的范围,而且数组开小不是RE是TLE。WQNMLGB。一开始TLE,看了discuss有人说这,改大就过了。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
#define maxn 1<<28
using namespace std;
int fst[60000],next[600000],node[600000],cost[600000],c[600000];
int a[1000][1000],f[600000],pre[60000],lu[60000];
int ednum;
int n,k;
int d[60000];
bool inq[60000];
void init()
{
ednum=-1;
memset(fst,-1,sizeof(fst));
for(int i=1; i<=n; i++)
{
for(int j=1; j<=n; j++)
{
scanf("%d",&a[i][j]);
next[++ednum]=fst[(i-1)*n+j];
fst[(i-1)*n+j]=ednum;
node[ednum]=n*n+(i-1)*n+j;
cost[ednum]=a[i][j];
c[ednum]=1;
next[++ednum]=fst[n*n+(i-1)*n+j];
fst[n*n+(i-1)*n+j]=ednum;
node[ednum]=(i-1)*n+j;
cost[ednum]=-a[i][j];
c[ednum]=0;
next[++ednum]=fst[(i-1)*n+j];
fst[(i-1)*n+j]=ednum;
node[ednum]=n*n+(i-1)*n+j;
cost[ednum]=0;
c[ednum]=k;
next[++ednum]=fst[n*n+(i-1)*n+j];
fst[n*n+(i-1)*n+j]=ednum;
node[ednum]=(i-1)*n+j;
cost[ednum]=-maxn;
c[ednum]=0;
if(i+1<=n)
{
next[++ednum]=fst[n*n+(i-1)*n+j];
fst[n*n+(i-1)*n+j]=ednum;
node[ednum]=i*n+j;
cost[ednum]=0;
c[ednum]=k;
next[++ednum]=fst[i*n+j];
fst[i*n+j]=ednum;
node[ednum]=n*n+(i-1)*n+j;
cost[ednum]=0;
c[ednum]=0;
}
if(j+1<=n)
{
next[++ednum]=fst[n*n+(i-1)*n+j];
fst[n*n+(i-1)*n+j]=ednum;
node[ednum]=(i-1)*n+j+1;
cost[ednum]=0;
c[ednum]=k;
next[++ednum]=fst[(i-1)*n+j+1];
fst[(i-1)*n+j+1]=ednum;
node[ednum]=n*n+(i-1)*n+j;
cost[ednum]=0;
c[ednum]=0;
}
}
}
next[++ednum]=fst[0];
fst[0]=ednum;
node[ednum]=1;
cost[ednum]=0;
c[ednum]=k;
ednum++;
next[++ednum]=fst[2*n*n];
fst[2*n*n]=ednum;
node[ednum]=2*n*n+1;
cost[ednum]=0;
c[ednum]=k;
ednum++;
}
int solve(int s,int t)
{
int ans=0;
memset(f,0,sizeof(f));
while(1)
{
queue<int>q;
memset(d,-1,sizeof(d));
memset(inq,0,sizeof(inq));
d[0]=0;
q.push(0);
inq[0]=1;
while(!q.empty())
{
int u=q.front();
q.pop();
inq[u]=0;
for(int i=fst[u];i!=-1;i=next[i])
{
int v=node[i];
if(c[i]>f[i]&&d[v]<d[u]+cost[i])
{
pre[v]=u;
lu[v]=i;
d[v]=d[u]+cost[i];
if(!inq[v])
{
q.push(v);
inq[v]=1;
}
}
}
}
if(d[t]==-1)break;
for(int i=t;i!=s;i=pre[i])
{
int v=lu[i];
f[v]++;
f[v^1]--;
}
ans+=d[t];
}
return ans;
}
int main()
{
scanf("%d%d",&n,&k);
init();
cout<<solve(0,2*n*n+1)<<endl;
return 0;
}