小记: dp算法一般有明确的状态转移方程,然而当这种方程太过复杂,一个dp[i][j]依赖于很多个dp[i-1][k]的时候,就会影响dp算法的复杂度,也就是说计算每一个dp[i][j]需要的不只O(1)的时间,而是O(n)或者更高的时间。这时候就需要一个桥梁用于沟通上一次的dp数组和这一次的dp数组,让dp[i][j]不依赖于dp[i-1][k]而是依赖于这个通过预处理得到的桥梁,来实现在O(1)的时间内算出当前dp值。(大多数时候前缀和就是一个很合适的桥梁)
题意:
给你n行3列的矩阵,告诉你每列和是多少(c0,c1,c2),每行和是多少(num0,num1,num2…),问你能构造出多少种这样的矩阵,(以上所有数字小于等于125,答案对1017取模)
输入
3 //有三行
1 2 3 //c0,c1,c2
2 3 4 //n0,n1,n2…
输出
0
思路:
这个是dp的题,定义dp[x][a][b][c]是在x行第一列剩下a可以用,第二列剩下b可以用,第三列剩下c的时候有多少种可能的情况。那么最后的答案就是dp[n-1][c0][c1][c2]
那么dp[x][a][b][c] j就等于所有的dp[x-1][i][j][k]相加,当a>=i&&b>=j&&c>=k&&
(a+b+c-i-j-k) == num[x]的时候,
所以代码应该这样写
ll dp[128][128][128][128];
for (int i = 0; i < n; i++)
for (int j = 0; j <= c0 ; j++)
for (int k = 0; k <= c1; k++)
for (int l = 0; l <= c2 ; k++) {
dp[i][j][k][l] = 0;
for (int o = j - num[i]; o <= j; o++)
for (int p = k - num[i]; p <= k; p++)
for (int q = l - num[i]; q <= l; q++) {
if (j + k + l - o - p - q != num[i])
continue;
dp[i][j][k][l] += dp[i - 1][o][p][q];
}
}
不过很显然,当j和k确定的时候,l就确定了,所以可以给dp降一个维度,
代码变成这样
ll dp[128][128][128];
for (int i = 0; i < n; i++)
for (int j = 0; j <= c0 && j <= num[i]; j++)
for (int k = 0; k <= c1 && j + k <= num[i]; k++) {
dp[i][j][k] = 0;
for (int o = j - num[i]; o <= j; o++)
for (int p = k - num[i]; p <= k && j + k - o - p <= num[i]; p++) {//最终差值要等于num[i],由于有l,所以只需要小于等于
dp[i][j][k] += dp[i - 1][o][p];
}
}
而且很明显dp数组永远只用了两个128*128的大小,所以做个轮动数组,变成这样
int so() {
ll dp[2][128][128];
ll(*last)[128] = dp[1], (*now)[128] = dp[0];
for (int i = 0; i < n; i++, swap(last, now))
for (int j = 0; j <= c0 && j <= num[i]; j++)
for (int k = 0; k <= c1 && j + k <= num[i]; k++) {
now[j][k] = 0;
for (int o = j - num[i]; o <= j; o++)
for (int p = k - num[i]; p <= k && j + k - o - p <= num[i]; p++) {//最终差值要等于num[i],由于有l,所以只需要小于等于
now[j][k] += last[o][p];
}
}
}
这样时间复杂度是O(n5)明显无法通过。
但是再观察,now[i][j]的值依赖于last的值,如果你把last数组看成一个矩阵,那么对于里面的两层循环,now[j][k]实际上是last矩阵上的一个三角形区域的值的和(至于为什么是三角形区域,希望读者自己在草稿本上画一下就明白了)
那么我们就可以通过一些预处理的方式,提前算出关于这个三角形一些需要的数据,就可以在O(1)的时间算出now[j][k]。
如果我们提前输出last数组的前缀和,那么我们就可以在O(n)的时间内算出now[j][k],像这样
图中每一个横条表示一个前缀和。这样整体的复杂度就将到了O(n4),总的循环次数就是2*108,很明显就是不一定能通过的,我们需要进一步降低复杂度。
观察一下,如果要求的和是一个四边形,那么now[j][k]和now[j-1][k]是有明确的关系的
那么now[j][k]等于now[j-1][k]加上红色部分,减去蓝色部分,那么三角形也有类似的关系
为什么是now[j-1][k+1]呢,因为只有这两个三角形才有公共的斜边,这样才能避免去计算斜边上的前缀和。这样now[j][k]等于now[j-1][k+1]加上红色部分,减去蓝色部分
这里有两个特色情况,j=0或者k=num[i]时,now[j-1][k+1]不存在,所以只能用O(n)个前缀和来算,但这种情况只有O(n)个,其他的情况可以在O(1)的时间内算出来,所以这时候总的时间复杂度是O(n3),循环次数约是2*106,这时候就能通过。
代码写出来是这个样子
//++i比i++更快,应该用++i的
#include <iostream>
#include <cstdio>
#include <iomanip>
#include <string>
#include <cstdlib>
#include <cstring>
#include <queue>
#include <set>
#include <vector>
#include <map>
#include <algorithm>
#include <cmath>
#include <stack>
#define INF 0x3f3f3f3f
#define LINF 0x3f3f3f3f3f3f3f3f
#define ll long long
#define ull unsigned long long
#define uint unsigned int
using namespace std;
int n, c0,c1,c2, num[128];
ll dp[2][128][128], bdp[2][128][128];//bridge_dp
ll mod = (ll)(1e17);
ll get_sum(int cr, int itr, int beg, int end) {
if (cr == 0) {
beg = max(beg, 0);
return beg ? (bdp[0][itr][end] - bdp[0][itr][beg - 1] + mod) % mod : bdp[0][itr][end];
}
beg = max(beg, 0);
return beg ? (bdp[1][end][itr] - bdp[1][beg - 1][itr] + mod) % mod : bdp[1][end][itr];
}
ll solve() {
ll (*last)[128] = dp[0], (*now)[128] = dp[1];
for (int j = 0; j <= c0; j++)
for (int k = 0; k <= c1; k++)
last[j][k] = (j + k <= num[0]);
for (int i = 1; i < n; i++, swap(last, now)) {
for (int j = 0; j <= c0; j++) {
bdp[0][j][0] = last[j][0];
for (int k = 1; k <= c1; k++)
bdp[0][j][k] = (last[j][k] + bdp[0][j][k - 1]) % mod;
}
for (int k = 0; k <= c1; k++) {
bdp[1][0][k] = last[0][k];
for (int j = 1; j <= c0; j++)
bdp[1][j][k] = (last[j][k] + bdp[1][j - 1][k]) % mod;
}
for (int j = 0; j <= c0; j++)
for (int k = 0; k <= c1; k++)
if (j - 1 >= 0 && k + 1 <= c1) {
//如果可以划分三角形用三角形去算
now[j][k] = now[j - 1][k + 1] + get_sum(0, j, k - num[i], k) - get_sum(1, k + 1, j - num[i] - 1, j - 1) + mod;
now[j][k] %= mod;
}
else {
//如果不能划分三角形,就用前缀和去算
now[j][k] = 0;
for (int t = 0; t <= j && t <= num[i]; t++) {
now[j][k] += get_sum(0, j - t, k - num[i] + t, k);
now[j][k] %= mod;
}
}
}
return last[c0][c1];
}
int main() {
while (scanf("%d%d%d%d", &n,&c0,&c1,&c2) != EOF) {
for (int i = 0; i < n; ++i)
scanf("%d", num + i);
int s = 0;
for (int i = 0; i < n; i++)
s += num[i];
if (c0 + c1 + c2 != s) {
printf("0\n");
continue;
}
printf("%lld\n", solve());
}
return 0;
}