Pre
之前做过这一题,方法比较麻烦。
Solution
基环树(每个连通块恰好有一个环)上做树形 \(DP\)。
更简便的做法是把环上的一条边断开:分别以这条边的两个端点为根做树形 \(DP\),并强制根不选;两次结果的较大值就是该连通块的答案。
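注意到 \(STD\) 代码里面有这样一句:
dp[v][1] = -1e9;
它在 DFS 走到指向根的那条边时执行(此时 \(v\) 就是根),使"选根"的状态一定劣于"不选根",从而强制当前的根不被选。两次 DFS 分别强制断开边的两个端点之一不选;这两个点本来就不能同时被选,所以两次结果取较大值即可覆盖所有合法方案。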
Code
Code(!std)
#include<cstdio>
#include<cstring>
#include<string>
#include<iostream>
#include<unistd.h>
#include<utility>
#include<map>
#include<limits.h>
#include<cmath>
#include<algorithm>
#include<queue>
#include<set>
#include<stack>
#include<vector>
#define rint register int
#define ll long long
#define I inline
#define ill I ll
#define iint I int
#define ivoid I void
#define ifloat I float
#define idouble I double
#define ibool bool
#define ipair I pair
#define xx first
#define yy second
using namespace std;
// Modulus used by qpow(); must be set to a positive value before calling it.
int mod;

// Fast modular exponentiation by repeated squaring.
// Returns m^n reduced into [0, mod). n <= 0 yields 1 % mod.
// Negative m is normalised first, so the result is never negative.
// Because mod is an int, every intermediate product stays below mod^2 < 2^62.
inline long long qpow(long long m, long long n)
{
    // Reduce the base up front: squaring an unreduced m overflows long long
    // as soon as |m| exceeds ~3e9.
    long long base = m % mod;
    if (base < 0) base += mod;
    long long result = 1 % mod;  // 1 % mod handles mod == 1
    while (n > 0) {
        if (n & 1) result = result * base % mod;
        base = base * base % mod;
        n >>= 1;
    }
    return result;
}
// Minimal integer reader over stdin, used as `in, x, y;` (reads x, then y).
// Non-digit characters before the number are skipped; a '-' among them makes
// the value negative. Decimal digits are then accumulated.
// If input ends before any digit, the target is set to 0 instead of looping.
struct _in{
    const _in& operator,(int& a) const {
        a = 0;
        int factor = 1;
        int k = getchar();  // int, not char: must be able to hold EOF
        while (k != EOF && (k < '0' || k > '9')) {
            if (k == '-') factor = -1;
            k = getchar();
        }
        while (k >= '0' && k <= '9') {
            a = a * 10 + (k - '0');
            k = getchar();
        }
        a *= factor;
        return *this;
    }
}in;
// Capacity shared by every per-node and per-edge array in this program.
const int maxn = 1005000;

// Adjacency list stored as a "chain forward star" (edges numbered from 1):
int h[maxn];   // h[u]: index of the most recently added edge leaving u (0 = none)
int fr[maxn];  // fr[e]: edge added before e from the same source (0 ends the list)
int to[maxn];  // to[e]: target node of edge e
int tot;       // number of edges added so far
inline void add(int u, int v);

int n;                 // number of nodes
int fa[maxn];          // fa[i]: node that i points to; the edge fa[i] -> i is stored
int dep[maxn];         // depth written by dp() (not read anywhere else in this listing)
// NOTE: under C++17 with `using namespace std`, an unqualified use of `data`
// is ambiguous with std::data; refer to this array as ::data.
long long data[maxn];  // data[i]: weight of node i
int sign;              // set by dp(): the node whose edge leads back to dp's root
long long ans = 0;     // running sum of per-component answers
bool ed[maxn];         // visited marks
long long f[maxn][3];  // f[v][0] / f[v][1]: best subtree weight with v skipped / chosen ([2] unused)
void dp(int pos, int root);
int main()
{
//freopen("data.in","r",stdin);
in,n;
for(rint i=1;i<=n;i++){
scanf("%lld",&data[i]);
in,fa[i];
add(fa[i],i);
}
for(rint i=1;i<=n;i++){
if(!ed[i]){
rint root=i;
ed[root]=1;
while(!ed[fa[root]]){
root=fa[root];
ed[root]=1;
}
dp(root,root);
ll tmax=f[root][0];
// for(rint i=1;i<=n;i++) printf("%d %d\n",f[i][0],f[i][1]);
if(root!=fa[sign]){
pair<ll,ll>now,pre;
now=make_pair(f[sign][0],f[sign][0]);
pre=make_pair(f[sign][0],f[sign][1]);
rint Now=fa[sign];
while(1)
{
pair<ll,ll> tmp=make_pair(f[Now][0],f[Now][1]);
f[Now][0]-=max(pre.xx,pre.yy);
f[Now][0]+=max(now.xx,now.yy);
f[Now][1]-=pre.xx;
f[Now][1]+=now.xx;
pre=tmp;
now=make_pair(f[Now][0],f[Now][1]);
if(Now==root) break;
Now=fa[Now];
}
}
ans+=max(f[root][1],max(tmax,f[root][0]));
}
}
printf("%lld\n",ans);
return 0;
}
void dp(rint pos,rint root)
{
f[pos][1]=data[pos];
dep[pos]=dep[fa[pos]]+1;
ed[pos]=1;
for(rint i=h[pos];i;i=fr[i]){
if(to[i]!=root){
dp(to[i],root);
f[pos][1]+=f[to[i]][0];
f[pos][0]+=max(f[to[i]][0],f[to[i]][1]);
}
else{
sign=pos;
}
}
}
ivoid add(rint u,rint v)
{
tot++;
fr[tot]=h[u];
to[tot]=v;
h[u]=tot;
}
Code(std)
#include <bits/stdc++.h>
#define LL long long
#define maxn 1000010
using namespace std;
// One adjacency-list entry: target node and the index of the next entry with
// the same source (0 terminates the list).
struct Edge{
    int to;
    int next;
};
Edge edge[maxn << 1];

int num;         // number of edges stored so far (entries start at index 1)
int head[maxn];  // head[u]: index of u's first edge (0 = none)
int a[maxn];     // a[i]: value of node i
int f[maxn];     // f[i]: node that i points to; the edge f[i] -> i is stored
int vis[maxn];   // visited marks
int n;           // number of nodes
LL dp[maxn][2];  // dp[u][0] / dp[u][1]: best subtree sum with u skipped / taken
LL ans;          // running sum of per-component answers
// Reads one decimal integer from stdin.
// Non-digit characters before it are skipped; a '-' among them makes the
// result negative.
// Returns 0 if input ends before any digit, instead of looping forever.
inline int read(){
    int s = 0, w = 1;
    // int, not char: EOF must be representable, and isdigit() is only defined
    // for EOF or values representable as unsigned char.
    int c = getchar();
    for (; c != EOF && !isdigit(c); c = getchar()) if (c == '-') w = -1;
    for (; c != EOF && isdigit(c); c = getchar()) s = s * 10 + (c - '0');
    return s * w;
}
void addedge(int x, int y){ edge[++num] = (Edge){y, head[x]}, head[x] = num; }
// Tree DP over the nodes reachable from u through the edge lists, never
// entering x. For every visited node v:
//   dp[v][0] = sum of max(dp[c][0], dp[c][1]) over children c   (v not taken)
//   dp[v][1] = a[v] + sum of dp[c][0] over children c           (v taken)
// Each visited node is marked in vis[].
// The edge leading back to x is skipped; instead it executes
// dp[x][1] = -1e9. With x == u, this guarantees dp[x][1] < dp[x][0], i.e.
// x is forced to be left out.
// An explicit stack replaces the original recursion, which nested once per
// node and could overflow the call stack on long chains (up to maxn nodes).
void dfs(int u, int x){
    static std::vector<int> order;    // preorder: each parent precedes its children
    static std::vector<int> pending;  // explicit DFS stack
    order.clear();
    pending.assign(1, u);
    while (!pending.empty()){
        const int cur = pending.back();
        pending.pop_back();
        vis[cur] = 1;
        dp[cur][0] = 0, dp[cur][1] = a[cur];
        order.push_back(cur);
        for (int i = head[cur]; i; i = edge[i].next)
            if (edge[i].to != x) pending.push_back(edge[i].to);
    }
    // Walking the preorder backwards finishes every child before its parent.
    for (int k = (int)order.size() - 1; k >= 0; --k){
        const int cur = order[k];
        for (int i = head[cur]; i; i = edge[i].next){
            int v = edge[i].to;
            if (v != x){
                dp[cur][0] += std::max(dp[v][0], dp[v][1]);
                dp[cur][1] += dp[v][0];
            } else dp[v][1] = -1e9;
        }
    }
}
void solve(int u){
while (!vis[u]) vis[u] = 1, u = f[u];
dfs(u, u);
LL sum = max(dp[u][0], dp[u][1]);
u = f[u];
dfs(u, u);
ans += max(sum, max(dp[u][0], dp[u][1]));
}
// Program entry.
// Reads n, then for every node its value a[i] and the node f[i] it points to
// (stored as the edge f[i] -> i). Solves each component once and prints the
// total of the per-component answers.
int main(){
    n = read();
    for (int i = 1; i <= n; ++i){
        a[i] = read();
        f[i] = read();
        addedge(f[i], i);
    }
    for (int i = 1; i <= n; ++i){
        if (vis[i]) continue;  // component already solved
        solve(i);
    }
    printf("%lld\n", ans);
    return 0;
}
Conclusion
有时候考虑一些简单的做法会有很大的好处。