题意: n ∗ m n∗m n∗m的矩阵,在 ( i , j ) ( i , j ) (i,j) 处有 a [ i ] [ j ] a[i][j] a[i][j]个钻石,且可以让所有钻石的单价上涨 b [ i ] [ j ] b[i][j] b[i][j],问从 ( 1 , 1 ) (1,1) (1,1)出发,每一次向下或向右移动到达 ( n , n ) (n,n) (n,n)后,所走过得位置的钻石最多可以卖多少钱?
解法一:设 d p [ i ] [ j ] [ k ] dp[i][j][k] dp[i][j][k]为走到 ( i , j ) (i,j) (i,j)这个点且持有的宝石数为 k k k的时,可以卖多少钱, d p [ i ] [ j ] [ k ] dp[i][j][k] dp[i][j][k]可以由 d p [ i − 1 ] [ j ] [ k ] dp[i - 1][j][k] dp[i−1][j][k]和 d p [ i ] [ j − 1 ] [ k ] dp[i][j-1][k] dp[i][j−1][k]转移而来,由于结果由两个变量:钻石的数量和钻石的单价控制,所以我们无法确定当前位置上的最优解,但是我们是可以确定当前位置上的非最优解的:数量少且单价低的一定不是最优解!所以我们确定位置( i , j ) i,j) i,j)上的可能最优解时,可以先将位置 ( i − 1 , j ) (i - 1,j) (i−1,j)和位置 ( i , j − 1 ) (i, j - 1) (i,j−1)上的可能最优解按数量进行升序排序,将单价低且数量少(对于当前的整个序列来说)的从这个序列中剔除,剩下的可能最优解与位置 ( i , j ) (i,j) (i,j)上的 a [ i ] [ j ] a[i][j] a[i][j]和 b [ i ] [ j ] b[i][j] b[i][j]构成位置 ( i , j ) (i,j) (i,j)上的最优解。
AC代码:
#include<bits/stdc++.h>
using namespace std;
const int N = 105;
typedef pair<long long, long long> PII;
typedef vector<PII> V;
typedef long long ll;
vector<PII> p[N][N];
PII tem[N * N * N]; //临时序列
int topz = 0;
ll a[N][N], b[N][N];
void check(PII x) {
while(topz && tem[topz].second <= x.second) topz --; //序列中的所有的first都要小于x的first,因为是first按升序排的,
//这句话也保证了序列中的second从大到小
if(topz == 0 || tem[topz].first < x.first) tem[++ topz] = x; //当x的first == 序列中的某个可能最优解的first的同时,
//x的second一定小于这个可能最优解的second,也就是说是一个非最优解
}
void merge(const V &x, const V &y, V &z) {
int sx = x.size(), sy = y.size();
int topx = 0, topy = 0;
topz = 0;
while(topx < sx && topy < sy)
check((x[topx].first < y[topy].first) ? x[topx ++] : y[topy ++]); //转移过来的可能最优解也是按first升序排列的
while(topx < sx) check(x[topx ++]);
while(topy < sy) check(y[topy ++]);
for(int i = 1; i <= topz; ++ i) z.push_back(tem[i]); //将筛完之后的可能最优解按宝石数量升序放入下一个状态中
}
void solve() {
int n;
scanf("%d", &n);
for(int i = 1; i <= n; ++ i)
for(int j = 1; j <= n; ++ j)
scanf("%lld", &a[i][j]), p[i][j].clear();;
for(int i = 1; i <= n; ++ i)
for(int j = 1; j <= n; ++ j)
scanf("%lld", &b[i][j]);
p[1][1].push_back(make_pair(a[1][1], b[1][1]));
for(int i = 1; i <= n; ++ i)
for(int j = 1; j <= n; ++ j) {
if(i == 1 && j == 1) continue;
else if(i == 1)
p[i][j].push_back(p[i][j - 1][0]); //第一行和第一列的转移都之由一种情况
else if(j == 1)
p[i][j].push_back(p[i - 1][j][0]);
else
merge(p[i - 1][j], p[i][j - 1], p[i][j]);
for(int k = 0; k < p[i][j].size(); ++ k) p[i][j][k].first += a[i][j], p[i][j][k].second += b[i][j];
}
ll MAX = 0;
for(int i = 0; i < p[n][n].size(); ++ i) {
MAX = max(MAX, p[n][n][i].first * p[n][n][i].second); //从所有没有被提出的可能最优解当中选出真正的最优解
}
printf("%lld\n", MAX);
}
int main() {
int t;
scanf("%d", &t);
while(t --) {
solve();
}
return 0;
}
解法二:对于每一个点,我们存下能到达它的点的所有的状态,并且这样一直推下去,每次给自己的状态排个序, a ∗ b a * b a∗b大的排前面,只取前若干个(这里取 100 100 100个),答案很大概率在这其中。
#include<bits/stdc++.h>
using namespace std;
const int N = 105;
typedef pair<long long, long long> PII;
typedef vector<PII> V;
typedef long long ll;
vector<PII> p[N][N];
int a[N][N], b[N][N];
void work(const V &x, const V &y, V &z) {
int sx = x.size(), sy = y.size();
int topx = 0, topy = 0;
while((topx < sx || topy < sy) && z.size() <= 100) {
if(topx < sx && topy < sy)
z.push_back((x[topx].first * x[topx].second > y[topy].first * y[topy].second) ? x[topx ++] : y[topy ++]);
else if(topx < sx) z.push_back(x[topx ++]);
else if(topy < sy) z.push_back(y[topy ++]);
}
}
void solve() {
int n;
scanf("%d", &n);
for(int i = 1; i <= n; ++ i)
for(int j = 1; j <= n; ++ j)
scanf("%d", &a[i][j]);
for(int i = 1; i <= n; ++ i)
for(int j = 1; j <= n; ++ j)
scanf("%d", &b[i][j]);
p[1][1].push_back(make_pair(a[1][1], b[1][1]));
for(int i = 1; i <= n; ++ i)
for(int j = 1; j <= n; ++ j) {
if(i == 1 && j == 1) continue;
if(i == 1) {
p[i][j].push_back(p[i][j - 1][0]);
}
else if(j == 1) {
p[i][j].push_back(p[i - 1][j][0]);
}
else {
work(p[i - 1][j], p[i][j - 1], p[i][j]);
}
for(int k = 0; k < p[i][j].size(); ++ k) p[i][j][k].first += a[i][j], p[i][j][k].second += b[i][j];
sort(p[i][j].begin(), p[i][j].end(), [&](PII a, PII b){return a.first * a.second > b.first * b.second;});
}
ll MAX = -1;
for(int i = 0; i < p[n][n].size(); ++ i) {
MAX = max(MAX, p[n][n][i].first * p[n][n][i].second);
}
printf("%lld\n", MAX);
for(int i = 1; i <= n; ++ i)
for(int j = 1; j <= n; ++ j) {
p[i][j].clear();
}
}
int main() {
int t;
scanf("%d", &t);
while(t --) {
solve();
}
return 0;
}