题目
参考自zhou888的代码
Solution
首先有一个很明显的贪心策略:强化牌选得越多越好,当然,要攻击牌打得出去才行
当
i
<
k
i<k
i<k时,强化牌打
i
i
i张,攻击牌打
m
−
i
m-i
m−i张
当
i
≥
k
i≥k
i≥k时,强化牌打
k
−
1
k-1
k−1张,攻击牌打
m
−
k
+
1
m-k+1
m−k+1张
两者结合,就是强化牌打
m
i
n
(
i
,
k
−
1
)
min(i,k-1)
min(i,k−1)张,攻击牌打
m
−
m
i
n
(
i
,
k
−
1
)
m-min(i,k-1)
m−min(i,k−1)张
另一个很显然的就是:牌的顺序与最终答案无关,所以我们排序后,不会对最终答案造成影响
具体排序的正确性留坑
g
i
,
j
g_{i,j}
gi,j表示前
i
i
i张强化牌,选
j
j
j张,所能取得的最优倍率之和
当
j
≤
k
−
1
j≤k-1
j≤k−1时,
g
i
,
j
=
a
i
∑
j
≤
m
i
n
(
i
,
m
)
g
i
−
1
,
j
−
1
g_{i,j}=a_i\sum_{j≤min(i,m)}g_{i-1,j-1}
gi,j=ai∑j≤min(i,m)gi−1,j−1
当
j
>
k
−
1
j>k-1
j>k−1时,
g
i
,
j
=
∑
j
≤
m
i
n
(
i
,
m
)
g
i
−
1
,
j
−
1
g_{i,j}=\sum_{j≤min(i,m)}g_{i-1,j-1}
gi,j=∑j≤min(i,m)gi−1,j−1
f
i
,
j
f_{i,j}
fi,j表示前
i
i
i张攻击牌,选
j
j
j张,所能取得的最优攻击之和
当
m
−
j
<
k
−
1
(
j
>
m
−
k
+
1
)
m-j<k-1(j>m-k+1)
m−j<k−1(j>m−k+1)时,
f
i
,
j
=
∑
j
≤
m
i
n
(
i
,
m
)
(
j
−
1
i
−
1
)
a
i
+
f
i
−
1
,
j
−
1
f_{i,j}=\sum_{j≤min(i,m)}(^{i-1}_{j-1})a_i+f_{i-1,j-1}
fi,j=∑j≤min(i,m)(j−1i−1)ai+fi−1,j−1
当
m
−
j
≥
k
−
1
(
j
≤
m
−
k
+
1
)
m-j≥k-1(j≤m-k+1)
m−j≥k−1(j≤m−k+1)时,
f
i
,
j
=
∑
j
≤
m
i
n
(
i
,
m
)
(
j
−
1
i
−
1
)
a
i
f_{i,j}=\sum_{j≤min(i,m)}(^{i-1}_{j-1})a_i
fi,j=∑j≤min(i,m)(j−1i−1)ai
(
j
−
1
i
−
1
)
(^{i-1}_{j-1})
(j−1i−1)的意义是
a
i
a_i
ai可以更新所有当前已经计算过的答案
Code
#include<bits/stdc++.h>
using namespace std;
const int N=3001,M=998244353;
int i,j,k,n,m,a[N],f[N],g[N],c[N][N],T,ans;
inline char gc(){
static char buf[100000],*p1=buf,*p2=buf;
return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline int rd(){
int x=0,fl=1;char ch=gc();
for (;ch<48||ch>57;ch=gc())if(ch=='-')fl=-1;
for (;48<=ch&&ch<=57;ch=gc())x=(x<<3)+(x<<1)+(ch^48);
return x*fl;
}
int main(){
for (i=0;i<N;i++)
for (j=1,c[i][0]=1;j<=i;j++) c[i][j]=(c[i-1][j]+c[i-1][j-1])%M;
for (T=rd();T--;){
n=rd(),m=rd(),k=rd();
for (i=1;i<=n;i++) a[i]=rd();
sort(a+1,a+n+1,greater<int>());
memset(g,0,m+1<<2);
memset(f,0,m+1<<2);
g[0]=1;
for (i=1;i<=n;i++)
for (j=min(i,m);j;j--)
if (j<=k-1) g[j]=(g[j]+1ll*g[j-1]*a[i])%M;
else (g[j]+=g[j-1])%=M;
for (i=1;i<=n;i++) a[i]=rd();
sort(a+1,a+n+1);
for (i=1;i<=n;i++)
for (j=min(i,m);j;j--) f[j]=(f[j]+1ll*c[i-1][j-1]*a[i]+(j>m-k+1)*f[j-1])%M;
ans=0;
for (i=0;i<m;i++) ans=(ans+1ll*g[i]*f[m-i])%M;
printf("%d\n",ans);
}
}