题面
一颗星星可以抽象成 K 维空间中的一个整点。称若干星星构成的集合 S 是奇妙的,当且仅当存在 K 维空间中的整点 P(P 处可以有星星也可以没有),P 与 S 中的每颗星星至少有一维坐标相同。
有一个长度为 N 的星星序列 A,请你求出所有奇妙子段的长度之和。
N ≤ 100000 , K ≤ 5 N \leq 100000,K \leq 5 N≤100000,K≤5
题解
邓老师说的太详细了,我不好意思说别的了
说一些实现上的细节:
- d p dp dp 的第二维只需要 5 ! = 120 5!=120 5!=120 的大小,把这些状态离散化。
- 记录每个 d p [ i ] [ S ] dp[i][S] dp[i][S] 表示的方案中,最后一个匹配 J J J 号锦囊的位置 p o s [ i ] [ S ] [ J ] pos[i][S][J] pos[i][S][J] 。为了空间, d p dp dp 和 p o s pos pos 都可以开滚动。
- 枚举前驱状态时,要把 S S S 的第一位 J J J 依次放到 S ′ S' S′ 的第 1 , 2 , . . . , K 1,2,...,K 1,2,...,K 位上,判定 a [ p o s [ i − 1 ] [ S ′ ] [ J ] ] [ J ] a\big[pos[i-1][S'][J]\big]\big[J\big] a[pos[i−1][S′][J]][J] 与 a [ i ] [ J ] a[i][J] a[i][J] 是否相等,一旦判定成功,就从该处转移, d p [ i ] [ S ] = d p [ i − 1 ] [ S ′ ] + 1 dp[i][S]=dp[i-1][S']+1 dp[i][S]=dp[i−1][S′]+1 ,再把 p o s pos pos 挪过来, p o s [ i ] [ S ] [ J ] = i pos[i][S][J]=i pos[i][S][J]=i ,然后结束枚举。由于这里是复杂度瓶颈( O ( n 2 K K ) O(n2^KK) O(n2KK)),这里的所有 S ′ S' S′ 可以预处理出来,减少除法和取模的使用。
- 若都无法判定成功,那么就从 S ′ = S [ 1 : ] + J S'=S[1:]+J S′=S[1:]+J ( S S S 的第一个元素 J J J 放到最后)转移过来,由于后方与 a [ i ] [ J ] a[i][J] a[i][J] 无法匹配,或者直接到底了(令到底的 p o s pos pos 为 0),所以 d p [ i ] [ S ] = i − p o s [ i − 1 ] [ S ′ ] [ J ] dp[i][S]=i-pos[i-1][S'][J] dp[i][S]=i−pos[i−1][S′][J] 。
CODE
#include<map>
#include<set>
#include<cmath>
#include<queue>
#include<stack>
#include<random>
#include<vector>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
#define MAXN 100005
#define LL long long
#define ULL unsigned long long
#define ENDL putchar('\n')
#define DB double
#define lowbit(x) (-(x) & (x))
#define FI first
#define SE second
int xchar() {
static const int maxn = 1000000;
static char b[maxn];
static int pos = 0,len = 0;
if(pos == len) pos = 0,len = fread(b,1,maxn,stdin);
if(pos == len) return -1;
return b[pos ++];
}
//#define getchar() xchar()
LL read() {
LL f = 1,x = 0;int s = getchar();
while(s < '0' || s > '9') {if(s<0)return -1;if(s=='-')f=-f;s = getchar();}
while(s >= '0' && s <= '9') {x = (x<<1) + (x<<3) + (s^48);s = getchar();}
return f*x;
}
void putpos(LL x) {if(!x)return ;putpos(x/10);putchar((x%10)^48);}
void putnum(LL x) {
if(!x) {putchar('0');return ;}
if(x<0) putchar('-'),x = -x;
return putpos(x);
}
void AIput(LL x,int c) {putnum(x);putchar(c);}
int n,m,s,o,k;
int mp[MAXN],nm[MAXN],cnt;
bool f[10];
void dfs0(int x,int S) {
if(x > m) {
mp[S] = ++ cnt; nm[cnt] = S;
return ;
}
for(int i = m-1;i >= 0;i --) {
if(!f[i]) {
f[i] = 1;
dfs0(x+1,S*m+i);
f[i] = 0;
}
}return ;
}
int a[MAXN][5];
map<int,int> ma[5];
int pe[MAXN][5];
int dp[2][125];
int ps[2][125][5];
int nx[125][5];
int main() {
n = read(); m = read();
dfs0(1,0);
for(int i = 1;i <= n;i ++) {
for(int j = 0;j < m;j ++) {
a[i][j] = read();
pe[i][j] = ma[j][a[i][j]];
ma[j][a[i][j]] = i;
}
}
for(int i = 1;i <= cnt;i ++) {
int P = nm[i],J = nm[i]%m;
for(int j = 0,pw = 1;j < m;j ++,pw *= m) {
nx[i][j] = P;
int y = (P/pw/m)%m;
P += (J-y)*pw*m + (y-J)*pw;
}
}
LL ans = 0;
int Pw = 1;
for(int i = 1;i < m;i ++) Pw *= m;
for(int i = 1;i <= n;i ++) {
memset(dp[i&1],0,sizeof(dp[i&1]));
memset(ps[i&1],0,sizeof(ps[i&1]));
int kk = (i&1)^1;
int mx = 0;
for(int S = 1;S <= cnt;S ++) {
int x = nm[S],J = x%m,pr = mp[x/m+J*Pw];
dp[i&1][S] = min(dp[kk][pr]+1,i-ps[kk][pr][J]);
memcpy(ps[i&1][S],ps[kk][pr],sizeof(ps[kk][pr]));
ps[i&1][S][J] = i;
if(pe[i][J]) {
int P = x;
for(int j = 0;j < m;j ++) {
P = nx[S][j];
if(a[i][J] == a[ps[kk][mp[P]][J]][J]) {
dp[i&1][S] = dp[kk][mp[P]] + 1;
memcpy(ps[i&1][S],ps[kk][mp[P]],sizeof(ps[kk][mp[P]]));
ps[i&1][S][J] = i; break;
}
}
}
mx = max(mx,dp[i&1][S]);
}
ans += mx;
}
AIput(ans,'\n');
return 0;
}