题意:从N个item中uniformly and randomly取出一个,可以有K次放回重新取的机会。如果使用最优策略,取出的item value的期望值是多少。
如果K=0,期望值就是所有item value的平均值。
因为每次放回item之后,系统状态不变,所以取到的item expected value仍然是平均值。对于K=1,最优策略是,如果第一次取到item value大于平均值,就停止,否则放回继续取。
对于K>1,我本来以为最优策略也是以average为threshold决定是否放回,但后来sample test一直过不了o(╯□╰)o。后来发现,这应该是个概率dp问题,dp[k]表示剩余K次放回机会最优策略中的expected value,那么如果当前选择的item value比dp[k]大,就停止,否则就放回继续选。
递推关系是 dp[k]=sum_{i=1}^N max(V[i],dp[k-1])/N。
如果先将item按照value排好序,那么给定threshold dp[k-1]可以通过二分找出分界点pos(最大的value<threshold的item),预处理是计算前缀和,就可以O(1)计算value大于threshold的item之和。递推关系为dp[k]=pos/N*dp[k-1]+(N-pos)/N*(presum[n]-presum[pos])/(N-pos)。
#include<iostream>
#include<stdio.h>
#include<cstdio>
#include<string>
#include<cmath>
#include<stdlib.h>
#include<algorithm>
#include<string.h>
#include<cstring>
#include<vector>
#include<queue>
#include<map>
#include<set>
#include<windows.h>
using namespace std;
//Kickstart 2018 Round A Problem B
const int maxn=50010;
int T;
long long V[maxn];
int N;
double ans;
int K;
double average;
double dp[maxn];
int binary(double target)//return the smallest index whose V[i] is larger than target
{
int left=0;
int right=N-1;
int mid=(left+right)/2;
while(left<=right)
{
mid=(left+right)/2;
if(V[mid]<=target)
{
left=mid+1;
}
else
{
right=mid-1;
}
}
if(V[mid]<=target)//for the special case that all values are the same, e.g. 1,1,1
{
return mid+1;
}
return mid;
}
int main()
{
// freopen("input.txt","r",stdin);
freopen("B-large-practice.in","r",stdin);
freopen("B.txt","w",stdout);
cin>>T;
for(int ca=1;ca<=T;ca++)
{
memset(V,0,sizeof(V));
ans=0;
average=0;
cin>>N>>K;
for(int i=0;i<N;i++)
{
cin>>V[i];
}
long long sum=0;
for(int i=0;i<N;i++)
{
sum+=V[i];
}
average=1.0*sum/N;
//cout<<average<<endl;
if(K==0)
{
ans=average;
}
else if(K==1)
{
for(int i=0;i<N;i++)
{
if(V[i]>=average)
{
ans+=1.0*V[i]/N;
}
else
{
ans+=1.0*average/N;
}
}
}
else
{
dp[0]=average;
sort(V,V+N);
long long presum[maxn];
presum[0]=V[0];
for(int i=1;i<N;i++)
{
presum[i]=presum[i-1]+V[i];
}
for(int k=1;k<=K;k++)
{
double prev=dp[k-1];
int pos=binary(prev)-1;
//cout<<"pos "<<binary(prev)<<" target "<<prev<<endl;
dp[k]=(pos+1)*1.0/N*dp[k-1]+1.0*(presum[N-1]-presum[pos])/N;
}
ans=dp[K];
// long long larger_sum=0;
// int num_lower=0;
// int num_larger=0;
// for(int i=0;i<N;i++)
// {
// if(V[i]>average)
// {
// num_larger++;
// larger_sum+=V[i];
// }
// else
// {
// num_lower++;
// }
// }
// ans+=pow(1.0*num_lower/N,K)*average;
// //cout<<ans<<endl;
// if(num_larger>0)
// {
// ans+=(1.0-pow(1.0*num_lower/N,K))/(1.0-1.0*num_lower/N)*1.0/N*larger_sum;
// //cout<<(1.0-pow(1.0*num_lower/N,K-1))<<" "<<(1.0-1.0*num_lower/N)<<endl;
// }
}
printf("Case #%d: %.8f\n", ca, ans);
}
return 0;
}