区间dp。
首先,由题意,每行取数独立,因此可以分别分析每一行的情况。
其次,基础的区间dp。
原问题:从序列第1个数到第m个数取出后的答案
子问题:从序列第i个数到第j个数取出后的答案
记dp[i][j]表示从序列中第i个数到第j个数取出后的答案,得到转移方程:
dp[i][j]=max(dp[i+1][j]+ar[i]* ,dp[i][j-1]+ar[j]* ),其中len为区间的长度。
简单说明一下为什么这里是而不是 。当区间长度为1时,意味着我们只取一个数,但是它是数列中最后取走的数,因为区间dp是从短到长考虑的,直接考虑完整的区间就不是dp了;当考虑区间长度为2时,新添加的数是倒数第二个被取走的,以此类推。
最后把n次dp的结果加起来输出即可。
Tips:最后输出有的会爆long long,采用__int128。
代码:
#include <iostream> #include <queue> #include <set> #include <map> #include <vector> #include <stack> #include <cmath> #include <algorithm> #include <cstdio> #include <cctype> #include <functional> #include <string> #include <cstring> #include <sstream> #include <deque> #define fir first #define sec second using namespace std; typedef __int128 ll; typedef pair<int,int> P; typedef pair<P,int> Q; const int inf1=2e9+9; const ll inf2=8e18+9; const ll mol=1e9+7; const int maxn=1e2+9; const ll maxx=1e12+9; int n,m; int ar[maxn]; ll dp[maxn][maxn]; ll Pow(ll x,ll t) { ll res=1; while(t) { if(t&1) res*=x; x*=x; t>>=1; } return res; } ll cal() { for(int len=1;len<=m;len++) { for(int i=1,j=i+len-1;j<=m;i++,j++) dp[i][j]=max(dp[i+1][j]+(ll)ar[i]*Pow(2,m-len+1),dp[i][j-1]+(ll)ar[j]*Pow(2,m-len+1)); } return dp[1][m]; } void print(ll x) { if(x==0) return; print(x/10); char num=x%10+'0'; putchar(num); } int main() { ll ans=0; cin>>n>>m; for(int i=0;i<n;i++) { for(int j=1;j<=m;j++) cin>>ar[j]; ans+=cal(); } if(ans==0) putchar('0'); else print(ans); putchar('\n'); }