首先枚举每个L的拐点是哪一行,n^3
然后一行dp过去,每个L有三种状态:还没放,放下在持续中(即L的横还没结束),已经放完了,用3^3表示这个状态
dp[m][3^3],然后转移枚举2^3,看每个L的状态是否要变成下一个状态,除了0->1这种状态转变,其他状态转变方案数都是乘以1
0->1就是看这个位置向上最多能延伸多少(即L的竖线,要看地图,还有看其他的L)
唔,这方法复杂度有点高:n^3*m*27*8,不过最大的数据600ms左右也跑出来了。。。。。(不像正解啊T T
#include <vector>
#include <list>
#include <map>
#include <set>
#include <deque>
#include <stack>
#include <bitset>
#include <algorithm>
#include <functional>
#include <numeric>
#include <utility>
#include <sstream>
#include <iostream>
#include <iomanip>
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <cstring>
using namespace std;
class ThreeLLogo {
public:
long long countWays(vector<string> );
};
long long dp[35][27];
int up[33][33];
long long ThreeLLogo::countWays(vector<string> s) {
int i, j, k;
int a, b, c;
int n, m;
int t1, t2, t3, tt1, tt2, tt3;
n = s.size();
m = s[0].size();
long long ans = 0;
for (i = 0; i < n; ++i) {
for (j = 0; j < m; ++j) {
if (s[i][j] == '#')
up[i][j] = -1;
else if (i == 0)
up[i][j] = 0;
else
up[i][j] = up[i - 1][j] + 1;
}
}
for (a = 1; a < n; ++a) {
for (b = a; b < n; ++b) {
for (c = b; c < n; ++c) {
memset(dp, 0, sizeof(dp));
dp[0][0] = 1;
for (i = 0; i < m; ++i) {
for (j = 0; j < 27; ++j) {
if (dp[i][j] == 0)
continue;
t1 = j % 3;
t2 = j / 3 % 3;
t3 = j / 9 % 3;
if (t1 == 1 && s[a][i] == '#')
continue;
if (t2 == 1 && s[b][i] == '#')
continue;
if (t3 == 1 && s[c][i] == '#')
continue;
if (a == b && t1 < t2)
continue;
if (a == c && t1 < t3)
continue;
if (b == c && t2 < t3)
continue;
for (k = 0; k < 8; ++k) {
long long cost = 1;
tt1 = t1 + k % 2;
tt2 = t2 + k / 2 % 2;
tt3 = t3 + k / 4;
if (tt1 == 1 && s[a][i] == '#')
continue;
if (tt2 == 1 && s[b][i] == '#')
continue;
if (tt3 == 1 && s[c][i] == '#')
continue;
if (tt1 > 2 || tt2 > 2 || tt3 > 2)
continue;
if (t1 == 0 && tt1 == 1) {
cost *= up[a][i];
}
if (t2 == 0 && tt2 == 1) {
long long t;
t = up[b][i];
if (t1 == 1 || tt1 == 1)
t = min(t, (long long) b - a - 1);
if (t < 0)
t = 0;
cost *= t;
}
if (t3 == 0 && tt3 == 1) {
long long t;
t = up[c][i];
if (t1 == 1 || tt1 == 1)
t = min(t, (long long) c - a - 1);
if (t2 == 1 || tt2 == 1)
t = min(t, (long long) c - b - 1);
if (t < 0)
t = 0;
cost *= t;
}
int newj = tt1 + tt2 * 3 + tt3 * 9;
dp[i + 1][newj] += dp[i][j] * cost;
}
}
}
ans += dp[m][26];
}
}
}
return ans;
}