Description
给出一个长度为 m 的序列 A, 请你求出有多少种 1…n 的排列, 满足 A 是它的一个 LIS.
Input
第一行两个整数 n,m.
接下来一行 m 个整数, 表示 A.
Output
一行一个整数表示答案.
Sample Input
5 3
1 3 4
Sample Output
11
Data Constraint
对于前 30% 的数据, n ≤ 9;
对于前 60% 的数据, n ≤ 12;
对于 100% 的数据, 1 ≤ m ≤ n ≤ 15.
题解
数据范围很小,
如果只是暴力枚举,无论怎么剪枝,应该也只能有60分。
考虑在一般情况下如何求最长上升子序列,
设
fi
f
i
长度为i的最长上升子序列的结尾最小是什么。
f数组一定是递增的,也就是说只要知道f数组中包含了哪一些数,就可以将f数组还原出来。
设
fs1,s2
f
s
1
,
s
2
表示,选了数的状态是s1,在f数组里面的状态s2,
但是这样的状态数太多了。
考虑到s2时包含在s1里面的,那么这两个状态就可合并。
用一个三进制状态来表示:
0:没有选。
1:选了,并且在f中。
2:选了,但不在f中。
按照题目中a数组的顺序来插入,
最后就对全部最长上升序列长度为m的求和。
code
#include <queue>
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <string.h>
#include <cmath>
#include <math.h>
#include <time.h>
#define ll long long
#define N 100003
#define M 103
#define db double
#define P putchar
#define G getchar
#define inf 998244353
using namespace std;
char ch;
void read(int &n)
{
n=0;
ch=G();
while((ch<'0' || ch>'9') && ch!='-')ch=G();
ll w=1;
if(ch=='-')w=-1,ch=G();
while('0'<=ch && ch<='9')n=(n<<3)+(n<<1)+ch-'0',ch=G();
n*=w;
}
int max(int a,int b){return a>b?a:b;}
int min(int a,int b){return a<b?a:b;}
ll abs(ll x){return x<0?-x:x;}
ll sqr(ll x){return x*x;}
void write(ll x){if(x>9) write(x/10);P(x%10+'0');}
ll f[14348987],ans;
int n,m,a[20],ss,t[20],gg[20],g[20],z[20],bz[20],x;
bool ina[20],pd;
void get(int s)
{
memset(bz,0,sizeof(bz));
memset(g,0,sizeof(g));
int x;
for(int i=n;i;i--)
{
x=s/z[i-1];
if(x)bz[i]=1;
if(x==2)g[++g[0]]=i;
s=s%z[i-1];
}
}
int work()
{
int s=0,all[20];
memset(all,0,sizeof(all));
for(int i=1;i<=n;i++)
if(bz[i])all[i]++;
for(int i=1;i<=g[0];i++)
all[g[i]]++;
for(int i=1;i<=n;i++)
s+=all[i]*z[i-1];
return s;
}
void dg(int x,int y,int s)
{
if(y==m)
{
ans+=f[ss*2-s];
return;
}
if(x>n)return;
dg(x+1,y+1,s+z[x-1]);
dg(x+1,y,s);
}
int main()
{
z[0]=1;
for(int i=1;i<20;i++)
z[i]=z[i-1]*3;
read(n);read(m);
for(int i=1;i<=m;i++)
read(a[i]),ina[a[i]]=1;
memset(f,0,sizeof(f));
f[0]=1;
for(int s=0;s<z[n];s++)
{
if(!f[s])continue;
for(int i=1;i<=n;i++)
{
if(s/z[i-1]%3)continue;
if(ina[i])
{
pd=1;
for(int j=1;a[j]<i;j++)
pd=pd&&(s/z[a[j]-1]%3);
if(!pd)continue;
}
pd=1;
for(int k=i+1;k<=n;k++)
if(s/z[k-1]%3==1)
{
f[s+z[k-1]+z[i-1]]+=f[s];
pd=0;break;
}
if(pd)f[s+z[i-1]]+=f[s];
}
}
ss=0;
for(int i=0;i<n;i++)
ss+=z[i];
dg(1,0,0);
write(ans);
return 0;
}