题目描述:
题目分析:
看到
b
b
b的长度为5,可以感觉到这题就是在锻炼强大合理的分类讨论能力。
首先看
b
i
≤
2
b_i\le2
bi≤2的部分分,即只有两种数字。
枚举这两种数为
x
,
y
x,y
x,y,数量分别为
c
n
t
x
,
c
n
t
y
cnt_x,cnt_y
cntx,cnty,可以记状态
f
[
i
]
[
5
]
f[i][5]
f[i][5],在
O
(
(
c
n
t
x
+
c
n
t
y
)
∗
5
)
O((cnt_x+cnt_y)*5)
O((cntx+cnty)∗5)的时间内做一个背包求出答案,每种数最多算
n
n
n次,总复杂度
O
(
n
2
)
O(n^2)
O(n2)
如果 b b b中有两个数字出现次数大于一次(大于一次的数最多只有两个),剩下的那个数可以在背包的时候直接算。
如果只有一个数字出现次数大于一次,枚举这种数为 x x x,进行 O ( n ) O(n) O(n)的dp算出不考虑其它数字是否相同的答案,然后减去多算的(即有三种数和两种数的情况,枚举转化为上面的问题,实现可以看代码)。
如果所有数字都不相同,直接将每个数字的出现次数拿来dp即可。
乍一看这道题这种等价问题需要知道前面的状态,但是通过枚举将其转化为了dp的转移条件,妙哉。
分类讨论时通过容斥化归问题,妙哉。
Code:
#include<bits/stdc++.h>
#define maxn 3005
#define LL long long
using namespace std;
int n,a[maxn],b[6],c[6];
vector<int>G[maxn];
LL solve1(){
int c[maxn]={0};
for(int i=1;i<=n;i++) c[a[i]]++;
LL f[6]={1};
for(int i=1;i<=n;i++) if(c[i]) for(int j=5;j>=1;j--) f[j]+=f[j-1]*c[i];
return f[5];
}
LL solve2(int x,int y){
LL ret=0;
for(int X=1;X<=n;X++) if(G[X].size()) for(int Y=1;Y<=n;Y++) if(X!=Y&&G[Y].size()){
LL f[6]={1}; int i=0,j=0,pre=0;
while(i<G[X].size()||j<G[Y].size()){
if(j==G[Y].size()||(i<G[X].size()&&G[X][i]<G[Y][j])){
for(int k=1;k<=5;k++) if(b[k]!=x&&b[k]!=y) f[k]+=f[k-1]*(G[X][i]-pre-1);
for(int k=5;k>=1;k--) if(b[k]==x) f[k]+=f[k-1];
pre=G[X][i++];
}
else{
for(int k=1;k<=5;k++) if(b[k]!=x&&b[k]!=y) f[k]+=f[k-1]*(G[Y][j]-pre-1);
for(int k=5;k>=1;k--) if(b[k]==y) f[k]+=f[k-1];
pre=G[Y][j++];
}
}
if(b[5]!=x&&b[5]!=y) f[5]+=f[4]*(n-pre);
ret+=f[5];
}
return ret;
}
LL solve3(int x){
LL ret=0;
for(int i=1;i<=n;i++) if(G[i].size()){
LL f[6]={1};
for(int j=1;j<=n;j++)
if(a[j]==i) {for(int k=5;k>=1;k--) if(b[k]==x) f[k]+=f[k-1];}
else {for(int k=5;k>=1;k--) if(b[k]!=x) f[k]+=f[k-1];}
ret+=f[5];
}
for(int i=1;i<=5;i++) if(b[i]!=x) for(int j=i+1;j<=5;j++) if(b[j]!=x){
int t1=b[j]; b[j]=b[i];
ret-=solve2(x,b[i]);
for(int k=j+1;k<=5;k++) if(b[k]!=x){
int t2=b[k]; b[k]=b[i];
ret-=solve2(x,b[i]);
b[k]=t2;
}
b[j]=t1;
}
return ret;
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++) scanf("%d",&a[i]),G[a[i]].push_back(i);
for(int i=1;i<=5;i++) scanf("%d",&b[i]),c[b[i]]++;
int cnt=0,x=0,y;
for(int i=1;i<=5;i++)
if(c[i]>=2) cnt++,!x?(x=i):(y=i);
//cnt<=2
printf("%lld\n",!cnt?solve1():cnt>1?solve2(x,y):solve3(x));
}