题意:
给你四个序列:a,p,b,c;长度均为n。
你可以任意两两组合序列b与序列c中的数,得到新的序列 d=b+c 对应项相加。
新的序列中随机与a序列比较,一共有n!种顺序,如果$d_i > a_i$ 则获得$p_i$的权值
问:在b与c的最优组合方案下,权值和的最大期望*n是多少?
首先考虑对于一个
d
i
=
b
x
+
c
y
d_i=b_x+c_y
di=bx+cy 如果其满足
d
i
>
a
j
d_i>a_j
di>aj ,那么他对答案贡献为
(
n
−
1
)
!
n
!
∗
p
j
\frac{(n-1)!}{n!}*p_j
n!(n−1)!∗pj
这个问题可以这样考虑:
对于一个确定的
d
i
>
a
j
d_i>a_j
di>aj我们保持a序列顺序不变,排列d序列一共n!种方案,如果恰好要使这个
d
i
>
a
j
d_i>a_j
di>aj,我们保持
d
i
d_i
di与
a
j
a_j
aj相对顺序不变,一共(n-1)!种方案,而满足条件的方案对于答案的贡献为
p
j
p_j
pj。
然后由于答案是最大期望*n的结果,所以这样一个点对于答案的贡献为
p
j
p_j
pj
知道上面这个性质后,问题其实就转化为一个二分图最大权匹配,我们设b序列代表的点
u
∈
U
u \in U
u∈U,设c序列代表的点
v
∈
V
v \in V
v∈V,u和v匹配获得的权值为序列a中小于
b
i
+
c
j
b_i+c_j
bi+cj的
p
i
p_i
pi的和。这样建图跑KM就可以求出结果。
建图复杂度为
O
(
n
2
l
o
g
n
)
O(n^2logn)
O(n2logn)使用前缀和优化,不优化的复杂度可能为O(n^3) 二分图最大权匹配
O
(
n
3
)
O(n^3)
O(n3)
另外有一个坑点是有些KM板子会被卡。
代码:
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=407,inf=0x3f3f3f3f;
const ll INF=0x3f3f3f3f3f3f3f3f;
ll a[N],b[N],c[N],p[N],psum[N];
map<ll,ll> val;
vector<ll> aa;
ll ex[N],ey[N],slack[N];
ll m[N][N];
int v1[N],v2[N],match[N],pre[N];
int n;
void find(int u)
{
int x, y = 0,yy;
ll del;
memset(pre, 0, sizeof(pre));
for(int i = 1;i <= n;i ++) slack[i] = inf;
match[y] = u;
while(1)
{
x = match[y], del = inf, v2[y] = 1;
for(int i = 1;i <= n;i ++)
{
if(v2[i]) continue;
if(slack[i] > ex[x] + ey[i] - m[x][i])
{
slack[i] = ex[x] + ey[i] - m[x][i];
pre[i] = y;
}
if(slack[i] < del) del = slack[i], yy = i;
}
for(int i = 0;i <= n;i ++)
{
if(v2[i]) ex[match[i]] -= del, ey[i] += del;
else slack[i] -= del;
}
y = yy;
if(match[y] == -1) break;
}
while(y){match[y] = match[pre[y]]; y = pre[y];}
}
ll KM()
{
memset(match, -1, sizeof(match));
memset(ex, 0, sizeof(ex));
memset(ey, 0, sizeof(ey));
for(int i = 1;i <= n;i ++)
{
memset(v2, 0, sizeof(v2));
find(i);
}
ll res = 0;
for(int i = 1;i <= n;i ++)
if(match[i] != -1) res += m[match[i]][i];
return res;
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++)
{
scanf("%lld",&a[i]);
aa.push_back(a[i]);
}
sort(aa.begin(),aa.end());
aa.erase(unique(aa.begin(),aa.end()),aa.end());
for(int i=1;i<=n;i++)
{
scanf("%lld",&p[i]);
val[a[i]]+=p[i];
}
for(int i=0;i<aa.size();i++) psum[i]=psum[i-1]+val[aa[i]];
for(int i=1;i<=n;i++) scanf("%lld",&b[i]);
for(int i=1;i<=n;i++) scanf("%lld",&c[i]);
for(int i=1;i<=n;i++)
{
for(int j=1;j<=n;j++)
{
ll x=b[i]+c[j];
int pos=lower_bound(aa.begin(),aa.end(),x)-aa.begin()-1;
m[i][j]=psum[pos];
}
}
printf("%lld\n",KM());
}