一、题目
二、解法
网上的题解貌似学得不是很清楚(反正蒟蒻作者是很难看懂的),我就来补一发详细题解吧。
0x01 暴力dp
期望的题好像暴力都很难写 。
设
f
[
i
]
[
j
]
[
k
]
f[i][j][k]
f[i][j][k]表示前
i
i
i个人,有
j
j
j个人喜欢第
k
k
k件衣服的概率,则有转移:
f
[
i
]
[
j
]
[
k
]
=
f
[
i
−
1
]
[
j
−
1
]
[
k
]
∗
p
[
i
]
[
k
]
+
f
[
i
−
1
]
[
j
]
[
k
]
∗
(
1
−
p
[
i
]
[
k
]
)
f[i][j][k]=f[i-1][j-1][k]*p[i][k]+f[i-1][j][k]*(1-p[i][k])
f[i][j][k]=f[i−1][j−1][k]∗p[i][k]+f[i−1][j][k]∗(1−p[i][k])
p
[
i
]
[
k
]
p[i][k]
p[i][k]表示第
i
i
i个人喜欢第
k
k
k件衣服的概率。
设
g
[
i
]
[
k
]
g[i][k]
g[i][k]表示准备了
i
i
i件第
k
k
k种衣服,送出衣服的期望,那么:
g
[
i
]
[
k
]
=
∑
j
=
0
i
j
∗
f
[
n
]
[
j
]
[
k
]
+
i
∗
∑
j
=
i
+
1
n
f
[
n
]
[
j
]
[
k
]
g[i][k]=\sum_{j=0}^{i} j*f[n][j][k]+i*\sum_{j=i+1}^{n} f[n][j][k]
g[i][k]=∑j=0ij∗f[n][j][k]+i∗∑j=i+1nf[n][j][k]
然后对
g
g
g数组进行背包即可,时间复杂度
O
(
n
2
m
)
O(n^{2}m)
O(n2m)。
0x02 优化
首先有一个式子:
∑
i
=
0
n
f
[
n
]
[
i
]
[
k
]
=
1
\sum_{i=0}^{n} f[n][i][k]=1
∑i=0nf[n][i][k]=1(似乎很显然,但后面的变化要用到)
观察
g
g
g数组,对它做差:
g
[
i
+
1
]
[
k
]
−
g
[
i
]
[
k
]
=
∑
j
=
i
+
1
n
f
[
n
]
[
j
]
[
k
]
=
1
−
∑
j
=
1
i
f
[
n
]
[
j
]
[
k
]
g[i+1][k]-g[i][k]=\sum_{j=i+1}^{n} f[n][j][k]=1-\sum_{j=1}^{i}f[n][j][k]
g[i+1][k]−g[i][k]=∑j=i+1nf[n][j][k]=1−∑j=1if[n][j][k]
发现
g
g
g数组是单调递增的,且随着
i
i
i的增大,他们的差单调变小。
这就给了我们贪心的依据,我们只需要选差最大的
g
g
g,把它加入答案即可。
但是我们仍需对
f
f
f数组进行优化,由于期望的线性性,不同的物品是互不相关的(看
d
p
dp
dp式也看得出来),有一些不优的物品就不需要把它的
d
p
dp
dp全部处理出来(用多少算多少)。
于是我们重新定义
d
p
dp
dp,设
f
[
i
]
[
j
]
f[i][j]
f[i][j]为对于第
i
i
i件衣服,前
j
j
j人有
c
n
t
[
i
]
cnt[i]
cnt[i]个人喜欢的概率(转移用滚动数组),
g
[
i
]
g[i]
g[i]为准备了
c
n
t
[
i
]
cnt[i]
cnt[i]件第
i
i
i种衣服,送出的期望,转移是注意用前缀和优化一下。
时间复杂度
O
(
n
2
+
n
m
)
O(n^{2}+nm)
O(n2+nm)
#include <cstdio>
#include <cstring>
const int MAXN = 3005;
int read()
{
int num=0,flag=1;char c;
while((c=getchar())<'0'||c>'9')if(c=='-')flag=-1;
while(c>='0'&&c<='9')num=(num<<3)+(num<<1)+(c^48),c=getchar();
return num*flag;
}
int n,m,cnt[MAXN];
double ans,p[MAXN][305],dp[305][MAXN];
double delt[MAXN],sv[MAXN],siv[MAXN],val[MAXN];
void updata(int c)
{
double g[MAXN]={};
memcpy(g,dp[c],sizeof g);
dp[c][0]=0;
for(int i=1;i<=n;i++)
dp[c][i]=dp[c][i-1]*(1-p[i][c])+g[i-1]*p[i][c];
cnt[c]++;
siv[c]+=cnt[c]*dp[c][n];sv[c]+=dp[c][n];
double e=siv[c]+cnt[c]*(1-sv[c]);
delt[c]=e-val[c];
val[c]=e;
}
int main()
{
n=read();m=read();
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
{
scanf("%lf",&p[i][j]);
p[i][j]/=1000.0;
}
for(int i=1;i<=m;i++)
{
dp[i][0]=1.0;
for(int j=1;j<=n;j++)
dp[i][j]=dp[i][j-1]*(1-p[j][i]);
sv[i]=dp[i][n];
updata(i);
}
for(int i=1;i<=n;i++)
{
int Max=0;
for(int j=1;j<=m;j++)
if(delt[Max]<delt[j])
Max=j;
if(!Max) break;
ans+=delt[Max];
updata(Max);
}
printf("%.12lf",ans);
}