题意
给一个长度为n的序列,且每个元素在[1,8]之间,
要求找出一个子序列满足
1.子序列的相同元素要连续
2.子序列每个元素出现的次数之差不超过1
求出满足要求的最长的子序列长度
题解:
首先
假设存在一个子序列满足[1,8]中有的数字出现次数是len,这样的数字有b个,剩下(8-b)个数字的出现次数是len - 1,那么这个子序列的长度是
d2 = len * b + (8 - b)*(len - 1)
可以想到二分len,然后求算出最长的len中最大的b,算出总长度。
但是二分前我们要判断下len是不是符合单调性
如果另一个子序列有a个数字长度是 len + 1,那么序列长度就是
d1 = (len + 1)* a + (8-a)*len
d1 - d2 = a + 8 - b,a>0&&b<=8 所以满足d1 - d2 > 0,所以len + 1更优,满足单调性。
接下来二分len
对每个len 都要计算满足len时最大的b,算出总长度。
所以可以dp一下
dp[i][j] 表示 从1 到 i 的序列中状态为 j 的最大的b(就是上面的例子所举的b),每个状态的二进制表示的数字中为1的表示当前状态选了哪一种数字,例如
j 为5 ==00000101,说明当前选了1和3这个数字
假如我现在要在 dp[i][j]这种时候选k这个数字
那么我可以选
1.len - 1 个 k
2.len 个 k
对第二种2.
a. nextState = j | (1<<(k-1)),表示在 j 的状态下选了k
b. nextPos = 从 i 开始计算的第len个k在原序列的位置
c. dp[nextPos][nextState] = max(…,dp[i][j] + 1);//+1表示又选了一种数字
对第一种1.
基本同第二种,b. len 改成 len - 1。c.dp[i][j]不加1
这样大概的dp就可以写出来了
/*cf 743E*/
#include <set>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <cstdio>
#include <cmath>
#include <queue>
#include <stack>
#include <vector>
#define mp make_pair
#define pb push_back
#define X first
#define Y second
using namespace std;
typedef long long LL;
const int maxn = 1005;
int n, a[maxn], cur[10], dp[maxn][256 + 5];
vector<int>in[maxn];
int ok(int len) {
memset(cur, 0, sizeof cur);
memset(dp, 128, sizeof dp);
dp[1][0] = 0;
for(int i = 1; i <= n; i++) {
for(int j = 0; j < (1 << 8); j++) {
if(dp[i][j] < 0)
continue;
for(int k = 1; k <= 8; k++) {
if(j & (1 << (k - 1))) continue;//k这个数字已经有了,不选
int nextState = j | (1 << (k - 1));
int pos = cur[k] + len - 1;//添加 len - 1 个 k
if(pos > in[k].size()) continue;
int nextPos = in[k][pos - 1];//pos - 1只是因为in是从0开始的
//所以pos >= 1,也就是 len >= 2
dp[nextPos][nextState] = max(dp[nextPos][nextState], dp[i][j]);
pos++;//添加 len 个 k
if(pos > in[k].size()) continue;
nextPos = in[k][pos - 1];
dp[nextPos][nextState] = max(dp[nextPos][nextState], dp[i][j] + 1);
}
}
cur[a[i]]++;
}
int ans = -1;
for(int i = 1; i <= n; i++)
ans = max(dp[i][255], ans);
if(ans <= 0) return 0;
return ans * len + (8 - ans) * (len - 1);
}
void init() {
scanf("%d", &n);
for(int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
in[a[i]].pb(i);
}
int l = 2, r = n + 1;
while(l < r) {
//cout << l << " " << r << endl;
int mid = (r + l) >> 1;
if(ok(mid)) l = mid + 1;
else r = mid;
}
l--;// r == n+1&& l == r必然说明 ok(l)不是最大=。=
if(l == 1) {
int ans = 0;
for(int i=1;i<=8;i++)
if(in[i].size()>0) ans++;
cout<<ans<<endl;
} else cout << ok(l) << endl;
}
int main() {
#ifdef LOCAL
freopen("in.txt", "r", stdin);
#endif // LOCAL
init();
return 0;
}