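// Official editorial approach: for t = 0..k, precompute r[t] and c[t], the best totals obtainable with t row-only or t column-only operations; then enumerate i from 0 to k and take the maximum of r[i] + c[k-i] - i*(k-i)*p.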
#include <iostream>
#include <sstream>
#include <algorithm>
#include <vector>
#include <queue>
#include <stack>
#include <map>
#include <set>
#include <bitset>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cmath>
#include <climits>
#define maxn 1000005
#define eps 1e-6
#define mod 10007
#define INF 99999999
// Lowest set bit of x. Both uses of the argument are parenthesized: the old
// form (x&(-x)) expanded lowbit(a + b) to (a + b&(-a + b)), which is wrong.
#define lowbit(x) ((x)&(-(x)))
//#define lson o<<1, L, mid
//#define rson o<<1 | 1, mid+1, R
typedef long long LL;
using namespace std;
// Heap entry wrapping one row or column sum.
// priority_queue<node> orders by operator< below, so the entry with the
// largest x is on top (a max-heap on x).
struct node
{
    LL x;

    bool operator < (const node &rhs) const {
        return x < rhs.x;
    }
};
node tmp;
// Max-heaps over the current row sums (q1) and column sums (q2).
priority_queue<node> q1;
priority_queue<node> q2;
// Row sums R[1..n] and column sums C[1..m]; sized for n, m up to 1004.
LL R[1005];
LL C[1005];
// r[t] / c[t]: greedy total collected by t row / column operations.
LL r[maxn];
LL c[maxn];
int main(void)
{
int n, m, k, p;
int i, j;
LL ans, a;
scanf("%d%d%d%d", &n, &m, &k, &p);
for(i = 1; i <= n; i++)
for(j = 1; j <= m; j++) {
scanf("%I64d", &a);
R[i] += a;
C[j] += a;
}
for(i = 1; i <= n; i++) {
tmp.x = R[i];
q1.push(tmp);
}
for(i = 1; i <= m; i++) {
tmp.x = C[i];
q2.push(tmp);
}
for(i = 1; i <= k; i++) {
tmp = q1.top();
q1.pop();
r[i] = r[i-1] + tmp.x;
tmp.x -= m*p;
q1.push(tmp);
}
for(i = 1; i <= k; i++) {
tmp = q2.top();
q2.pop();
c[i] = c[i-1] + tmp.x;
tmp.x -= n*p;
q2.push(tmp);
}
ans = max(r[k], c[k]);
for(i = 1; i <= k; i++)
ans = max(ans, r[i] + c[k-i] - 1ll*(k-i)*i*p);
printf("%I64d\n", ans);
return 0;
}