题意
给一个长度为 n 的序列 p1, p2, …, pn 和 m 个二元组 (a1, b1),(a2, b2),…,(am, bm). 排列数列 p,使得 ∑ i = 1 m ∣ p a i − p b i ∣ \sum_{i = 1}^m |p_{a_i} - p_{b_i}| ∑i=1m∣pai−pbi∣ 最小。求最小值。
分析
按序放置元素,状态表示位置是否为空。Bitcnt(status)表示此时该选择的元素。
e
l
e
=
p
[
b
i
t
c
n
t
(
p
r
e
)
]
\\ ele = p[bitcnt(pre)]
ele=p[bitcnt(pre)]
若我们以当前状态中所包含的二元组的绝对值作为状态取值的话,状态内部元素的排序方式对状态转移会有影响。
绝对值
∣
p
a
i
−
p
b
i
∣
|p_{a_i} - p_{b_i}|
∣pai−pbi∣可以理解成,较大值贡献正权,较小值贡献负权。我们可以将状态取值设为:当前状态中所包含元素贡献值的最小值。
对于已经放置的元素,他们都比当前元素小,所以,当前元素在相关的绝对值中贡献正权;对于未放置的元素,他们比当前元素大,当前元素在绝对值中贡献负权。状态转移方程为:
d
p
[
s
t
a
]
=
d
p
[
p
r
e
]
+
e
l
e
∗
(
c
n
t
2
−
c
n
t
1
)
=
d
p
[
p
r
e
]
+
e
l
e
∗
(
2
∗
c
n
t
1
−
t
o
t
)
\begin{aligned} \\ dp[sta] &= dp[pre]+ele*(cnt2-cnt1) \\ &= dp[pre]+ele*(2*cnt1-tot) \end{aligned}
dp[sta]=dp[pre]+ele∗(cnt2−cnt1)=dp[pre]+ele∗(2∗cnt1−tot)
其中,pre是上一状态;pos是当前元素放置的位置;tot表示与这个位置有关位置数量,cnt1表示与这个位置有关的已被占据位置的数量,cnt2表示与这个位置有关的未被占据位置的数量。
代码
// 状态压缩, 巧妙利用了绝对值|x1-x2| (x1>x2)
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
const int MAX_N = 20;
const long long INF_ = 0x3f3f3f3f3f3f3f3f;
int p[MAX_N+5];
long long dp[(1 << MAX_N)];
int cnt_bit[1<<MAX_N];
int bitCnt(int x)
{
int res = 0;
while (x > 0) {
res++;
x = x & (x-1);
}
return res;
}
int main()
{
for (int j = 0; j < (1 << MAX_N); j++) cnt_bit[j] = bitCnt(j);
int n, m;
while (scanf("%d%d", &n, &m) != EOF)
{
vector<int> conn[MAX_N+1];
for (int i = 0; i < n; i++) scanf("%d", &p[i]);
sort(p, p+n);
for (int i = 0; i < m; i++)
{
int a, b;
scanf("%d%d", &a, &b);
conn[a-1].push_back(b-1);
conn[b-1].push_back(a-1);
}
// 这里使用memset会超时
// memset(dp, INF_, sizeof(dp));
for (int i = 1; i < (1 << n); i++) dp[i] = INF_;
dp[0] = 0;
// 状态
for (int sta = 1; sta < (1<<n); sta++)
{
// dp[sta] = INF_;
for (int pos = 0; pos < n; pos++)
{
if (((1 << pos) & sta) > 0)
{
int pre_sta = (1<<pos) ^ sta;
int cnt = 0;
for (int val:conn[pos]) if (((1 << val) & pre_sta) > 0) cnt++;
long long next_val = dp[pre_sta]+(2*cnt-conn[pos].size())*p[cnt_bit[pre_sta]];
dp[sta] = min(dp[sta], next_val);
}
}
}
printf("%lld\n", dp[(1<<n)-1]);
}
}