题目大意: 有n种商品, 价格为p[i], 购买时可以使用两种优惠, 一是商品降价d[i], 另一个是送你一件商品f[i]。 求使得每种商品都至少得到1件的最小花费。 (n≤105,d[i]≤q[i]≤109,f[i]≤n) ( n ≤ 10 5 , d [ i ] ≤ q [ i ] ≤ 10 9 , f [ i ] ≤ n )
题目思路: 对于第二种优惠的赠送关系可以建图, 连边f[i]指向i, n个点, n条边,每个点入度均为1, 构成一个基环树森林。 考虑普通树的情况, 令dp[i][0]表示子树i的最小费用, dp[i][1]表示以第二种优惠购买商品i(即拥有免费获得其父亲点f[i]的权利)的情况下子树i的最小费用。 按普通的树形dp来做, 有转移
在考虑有环的情况, 先不管同环上点之间的影响, 先把各个点的挂子树的值求出来, 最后一次考虑每个环。 可以从环上任一个点断开为链为cir[1…m], 用g[i][0],g[i][1]类似表示环上第i个点的答案, 考虑链上点的转移有
最后枚举第一个点的状态来设定其初始值
如果第一个点不是由最后一个点赠送而来则
g[1][0]=dp[cir[1]][0],g[1][1]=dp[cir[1]][1]
g
[
1
]
[
0
]
=
d
p
[
c
i
r
[
1
]
]
[
0
]
,
g
[
1
]
[
1
]
=
d
p
[
c
i
r
[
1
]
]
[
1
]
, 取g[m][0]更新答案
如果第一个点是由最后一个点赠送而来则
g[1][0]=∑v∈son[cir[1]]dp[v][0],g[1][1]=dp[cir[1]][1]
g
[
1
]
[
0
]
=
∑
v
∈
s
o
n
[
c
i
r
[
1
]
]
d
p
[
v
]
[
0
]
,
g
[
1
]
[
1
]
=
d
p
[
c
i
r
[
1
]
]
[
1
]
, 取g[m][1]更新答案
Code:
#include <map>
#include <set>
#include <map>
#include <cmath>
#include <queue>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
#define ll long long
const int N = (int)1e5 + 10;
const ll inf = 1LL << 60;
int n;
int cnt, lst[N], nxt[N], to[N], fa[N];
int tim;
int vis[N], incir[N];
vector<int > cir[N]; int m;
int p[N], d[N];
ll ans, f[N][2], g[N][2], sum[N]; bool root[N];
void add(int u, int v){
nxt[++ cnt] = lst[u]; lst[u] = cnt; to[cnt] = v;
}
void dfs(int u){
vis[u] = tim;
for (int j = lst[u]; j; j = nxt[j]){
int v = to[j];
if (vis[v] == tim){
m ++;
int p = u;
for (; p != v; p = fa[p]){
cir[m].push_back(p);
incir[p] = 1;
}
cir[m].push_back(p);
incir[p] = 1;
continue;
}
fa[v] = u;
if (!vis[v]) dfs(v);
}
}
void dp(int u){
vis[u] = 1;
for (int j = lst[u]; j; j = nxt[j]){
int v = to[j];
if (!vis[v]) dp(v);
if (incir[v]) continue;
sum[u] += f[v][0];
}
f[u][1] = p[u] + sum[u];
f[u][0] = p[u] + sum[u] - d[u];
for (int j = lst[u]; j; j = nxt[j]){
int v = to[j];
if (incir[v]) continue;
f[u][0] = min(f[u][0], f[v][1] + sum[u] - f[v][0]);
}
}
int main(){
scanf("%d", &n);
for (int i = 1; i <= n; i ++) scanf("%d", p + i);
for (int i = 1; i <= n; i ++) scanf("%d", d + i);
for (int i = 1, x; i <= n; i ++) scanf("%d", &x), add(x, i);
for (int i = 1; i <= n; i ++)
if (!vis[i]) {++ tim; dfs(i);}
memset(vis, 0, sizeof(vis));
for (int i = 1; i <= n; i ++)
if (!vis[i]) dp(i);
for (int i = 1; i <= m; i ++){
ll ret = inf;
int sz = cir[i].size();
g[0][0] = f[cir[i][0]][0], g[0][1] = f[cir[i][0]][1];
for (int j = 1; j < sz; j ++){
g[j][1] = f[cir[i][j]][1] + g[j - 1][0];
g[j][0] = min(f[cir[i][j]][0] + g[j - 1][0], g[j - 1][1] + sum[cir[i][j]]);
}
ret = min(ret, g[sz - 1][0]);
g[0][0] = sum[cir[i][0]]; g[0][1] = f[cir[i][0]][1];
for (int j = 1; j < sz; j ++){
g[j][1] = f[cir[i][j]][1] + g[j - 1][0];
g[j][0] = min(f[cir[i][j]][0] + g[j - 1][0], g[j - 1][1] + sum[cir[i][j]]);
}
ret = min(ret, g[sz - 1][1]);
ans += ret;
}
printf("%lld\n", ans);
return 0;
}