题目描述
帅帅经常跟同学玩一个矩阵取数游戏:对于一个给定的 n×m 的矩阵,矩阵中的每个元素 a i , j a_{i,j} ai,j均为非负整数。游戏规则如下:
1.每次取数时须从每行各取走一个元素,共 n 个。经过 m 次后取完矩阵内所有元素;
2.每次取走的各个元素只能是该元素所在行的行首或行尾;
3.每次取数都有一个得分值,为每行取数的得分之和,每行取数的得分 = 被取走的元素值
×
2
i
\times 2^i
×2i,其中 i 表示第 i 次取数(从 1 开始编号);
4.游戏结束总得分为 m 次取数得分之和。
帅帅想请你帮忙写一个程序,对于任意矩阵,可以求出取数后的最大得分。
输入格式
输入文件包括 n+1 行:
第一行为两个用空格隔开的整数 n 和 m。
第2∽n+1 行为n×m 矩阵,其中每行有 m 个用单个空格隔开的非负整数。
输出格式
输出文件仅包含1行,为一个整数,即输入矩阵取数后的最大得分。
解题思路
这题好像很难啊,我花了一个上午才做出来,首先一个重要的思想是把数组分割成小数组去遍历,递推公式是f[i,j]=max(a(i)+f[i+1,j],a(j)+f[i,j-1]),刚开始我想用递归和longlong去做,结果就报错TLE,复杂度太高了,没办法只能优化到非递归结构,选择用数组存储的形式,注意,这里我看到很多人选择用一个矩阵a去表示,a(i,j)=f(i,j),但是这样的话,f(j,i)是不存在的,就会把这个矩阵一半的空间浪费掉,这显然不符合我勤俭节约的作风,从matlab中将关于主对角线对称的矩阵降维到一维矩阵的函数得来的灵感,我选择将这个数组降到一维,这个数组被命名为d,d[(j+1)*j/2+i]表示f(i,j),这样的话,就能物尽其用将损耗降到最低,但是测试了一下,最后4个测试点通不过,发现是long long的精度放不下,于是我定义了一个big类,用两个long long组合在一起提高精度(你觉得我会承认我用vector和数组写的高精度被报tle了吗?)
Talking is cheap,show you my code.
代码
#include<iostream>
#define num 10000000000000000
using namespace std;
struct big {//定义一个高精度类
long long num1 = 0;
long long num2 = 0;
};
void insert(big & a, int t) {//将一个int数字赋给一个高精度类
a.num1 = t;
}
void print(big a) {//将高精度类打印输出
if (!a.num2) {
cout << a.num1;
}
else {
cout << a.num2 << a.num1;
}
cout << endl;
}
big add(big a, big b) {//将两个高精度类相加
big c;
long long t = a.num1 + b.num1;
c.num1 = t % num;
t /= num;
c.num2 = a.num2 + b.num2 + t;
return c;
}
big mult(big a, int n) {//将一个高精度类和一个不大于10的数相乘
big b;
long long t = a.num1;
t *= n;
b.num1 = t % num;
t /= num;
b.num2 = (a.num2)*n + t;
return b;
}
big max(big a, big b) {//将两个高精度进行比较
if (a.num2 > b.num2) {
return a;
}
if (b.num2 > a.num2) {
return b;
}
if (a.num1 > b.num1) {
return a;
}
if (b.num1 > a.num1) {
return b;
}
return a;
}
int n;
int m;
big f[81];
big ans;
big deep(int m) {//对每一行进行由深到浅遍历
big d[3300];
int l = 0;
while (l < m) {
int i = 0;
int j = i + l;
while (j < m) {
if (!l) {
d[(i*(i + 1)) / 2 + i] = mult(f[i], 2);
}
else {
d[(j*(j + 1)) / 2 + i] = max(add(mult(f[i], 2), mult(d[(j*(j + 1)) / 2 + i + 1], 2)), add(mult(f[j], 2), mult(d[(j*(j - 1)) / 2 + i], 2)));
}
i++;
j++;
}
l++;
}
return d[((m-1)*m)/2];
}
int main() {
cin >> n >> m;
for (int i = 0;i < n;i++) {
for (int j = 0;j < m;j++) {
int t;
cin >> t;
f[j].num1 = t;
f[j].num2 = 0;
}
ans = add(ans, deep(m));
}
print(ans);
return 0;
}