题目描述
帅帅经常跟同学玩一个矩阵取数游戏:对于一个给定的 𝑛×𝑚n×m 的矩阵,矩阵中的每个元素 𝑎𝑖,𝑗ai,j 均为非负整数。游戏规则如下:
- 每次取数时须从每行各取走一个元素,共 𝑛n 个。经过 𝑚m 次后取完矩阵内所有元素;
- 每次取走的各个元素只能是该元素所在行的行首或行尾;
- 每次取数都有一个得分值,为每行取数的得分之和,每行取数的得分 = 被取走的元素值 ×2𝑖×2i,其中 𝑖i 表示第 𝑖i 次取数(从 11 开始编号);
- 游戏结束总得分为 𝑚m 次取数得分之和。
帅帅想请你帮忙写一个程序,对于任意矩阵,可以求出取数后的最大得分。
输入格式
输入文件包括 𝑛+1n+1 行:
第一行为两个用空格隔开的整数 𝑛n 和 𝑚m。
第 2∼𝑛+12∼n+1 行为 𝑛×𝑚n×m 矩阵,其中每行有 𝑚m 个用单个空格隔开的非负整数。
输出格式
输出文件仅包含 11 行,为一个整数,即输入矩阵取数后的最大得分。
输入输出样例
输入
2 3
1 2 3
3 4 2
输出
82
说明/提示
【数据范围】
对于 60%60% 的数据,满足 1≤𝑛,𝑚≤301≤n,m≤30,答案不超过 10161016。
对于 100%100% 的数据,满足 1≤𝑛,𝑚≤801≤n,m≤80,0≤𝑎𝑖,𝑗≤10000≤ai,j≤1000。
【题目来源】
NOIP 2007 提高第三题。
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cstdio>
#define A 1000000000000000
using namespace std;
struct bint{
long long s[10];
bint(long long num = 0){
*this=num;
}
bint operator = (long long num){
memset(s, 0, sizeof(s));
s[0]=num;
return *this;
}
bint operator + (const bint& b) const {
bint c;
unsigned long long g = 0;
for (int i=0; i<9 ; i++){
unsigned long long x = g;
x += (unsigned long long)s[i]+b.s[i];
c.s[i]=x%A;
g=x/A;
}
return c;
}
bint operator* (const bint& b) const {
bint c;
unsigned long long g=0;
for (int i=0; i<9; i++){
unsigned long long x=g;
for (int j=0; j<=i; j++){
int k=i-j;
x+=(unsigned long long)s[k]*b.s[j];
}
c.s[i]=x%A;
g=x/A;
}
return c;
}
bool operator < (const bint& b) const {
for (int i=9; i>=0; i--){
if (s[i]<b.s[i]) return 1;
if (s[i]>b.s[i]) return 0;
}
return 0;
}
void print(){
char buf[200];
for (int i=9; i>=0; i--){
sprintf(buf+(9-i)*15, "%015lld", s[i]);
}
bool flag=0;
for (int i=0; i<150; i++){
if (buf[i]>'0') flag=1;
if (flag) printf("%c", buf[i]);
}
if (!flag) printf("0");
}
};
long long a[100]; bint dp[100][100];
bint ans;
bint two[82];
inline void ini(){
two[0]=1;
for (int i=1; i<=80; i++){
two[i]=two[i-1]*2;
}
}
bint at[100][100];
bool used[100][100];
inline bint multi(int i, int p){
if (used[i][p]) return at[i][p];
used[i][p]=true;
return at[i][p]=(bint)a[i]*two[p];
}
int main(){
freopen("1005.cpp", "r", stdin);
ini();
int n, m;
cin>>n>>m;
ans=0;
for (int w=0; w<n; w++){
memset(dp, 0, sizeof(dp));
memset(used, false, sizeof(used));
for (int i=1; i<=m; i++) scanf("%d", a+i);
bint anst=0;
for (int t=0; t<m; t++){
for (int i=1; i+t<=m; i++){
int j=i+t;
int p=m-t;
bint s = dp[i+1][j]+multi(i,p);
bint t = dp[i][j-1]+multi(j,p);
if (s<t) dp[i][j]=t;
else dp[i][j]=s;
}
}
ans=ans+dp[1][m];
}
ans.print();
cout<<endl;
}