A. 【例题1】树上求和
Link
我们用
0
/
1
0/1
0/1表示一个点的选与不选
那么
f
[
i
]
[
1
]
=
f
[
i
]
[
1
]
+
f
[
t
o
]
[
0
]
f[i][1]=f[i][1]+f[to][0]
f[i][1]=f[i][1]+f[to][0],
f
[
i
]
[
0
]
=
f
[
i
]
[
0
]
+
m
a
x
(
f
[
t
o
]
[
1
]
,
f
[
t
o
]
[
0
]
)
f[i][0]=f[i][0]+max(f[to][1],f[to][0])
f[i][0]=f[i][0]+max(f[to][1],f[to][0])
动态方程就出来了
这一板块昨晚,就就发现一个树形dp的共同点:通过dfs对目标状态进行累加
Code
#include <iostream>
#include <cstdio>
using namespace std;
struct DT{
int y, next;
}a[6100];
int n, l, k, num, ans;
int s[6100], head[6100], t[6100], f[6100][2];
void dfs(int x) {
f[x][1] = s[x];
for(int i = head[x]; i; i = a[i].next) {
dfs(a[i].y);
f[x][0] = f[x][0] + max(f[a[i].y][0], f[a[i].y][1]);
f[x][1] = f[x][1] + f[a[i].y][0];
}
}
int main() {
scanf("%d", &n);
for(int i = 1; i <= n; i++)
scanf("%d", &s[i]);
for(int i = 1; i < n; i++) {
scanf("%d %d", &l, &k);
t[l]++;
a[++num] = (DT){l, head[k]};
head[k] = num;
}
for(int i = 1; i <= n; i++) {
if(!t[i]) dfs(i);
ans = max(ans, max(f[i][0], f[i][1]));
}
printf("%d", ans);
}
B. 【例题2】结点覆盖
Link
对于一个点,一共有三种选择:
- 选择自己
- 选择父亲
- 选择儿子
那么答案就是 m a x ( d p [ 1 ] [ 1 ] , d p [ 1 ] [ 2 ] ) max(dp[1][1],dp[1][2]) max(dp[1][1],dp[1][2])
考虑转移: 1. 当一点选择自己时自己的儿子不论怎么选都可以 so, d p [ i ] [ 1 ] + = m i n ( d p [ t o ] [ 0 ] , d p [ t o ] [ 1 ] , d p [ t o ] [ 2 ] ) dp[i][1]+=min(dp[to][0],dp[to][1],dp[to][2]) dp[i][1]+=min(dp[to][0],dp[to][1],dp[to][2])
2.当选择自己父亲时自己就不需要选了,所以 d p [ i ] [ 0 ] + = m i n ( d p [ t o ] [ 1 ] , d p [ t o ] [ 2 ] ) dp[i][0]+=min(dp[to][1],dp[to][2]) dp[i][0]+=min(dp[to][1],dp[to][2])
3.当选儿子时 d p [ i ] [ 2 ] = m i n ( d p [ t o ] [ 1 ] , d p [ t o ] [ 2 ] ) dp[i][2]=min(dp[to][1],dp[to][2]) dp[i][2]=min(dp[to][1],dp[to][2]),然而这是有纰漏的,因为我们不能保证选择的答案中包含 d p [ t o ] [ 1 ] dp[to][1] dp[to][1],于是乎我们就需要找到选则的元素与 d p [ t o ] [ 1 ] dp[to][1] dp[to][1]的最小值,以保证被儿子覆盖
Code
#include<bits/stdc++.h>
#define re register
#define inl inline
#define int long long
using namespace std;
int read(){
int sum=0,f=1;char c=getchar();
while(!isdigit(c)){if(c=='-') f=-1;c=getchar();}
while((isdigit(c))){sum=(sum<<3)+(sum<<1)+(c^48);c=getchar();}
return f*sum;
}
const int N=3010;
struct node{
int to,nxt;
}e[N];
int cnt,head[N],a[N],n;
int dp[N][4];
inl void add(int u,int v){
e[++cnt].to=v;
e[cnt].nxt=head[u];
head[u]=cnt;
}
inl void dfs(int u,int fa){
dp[u][1]=a[u];
int minn=1e18;
for(re int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==fa) continue;
dfs(v,u);
dp[u][0]+=min(dp[v][1],dp[v][2]);
dp[u][1]+=min(dp[v][0],min(dp[v][1],dp[v][2]));
dp[u][2]+=min(dp[v][1],dp[v][2]);
minn=min(minn,dp[v][1]-min(dp[v][1],dp[v][2]));
}
dp[u][2]+=minn;
}
signed main(){
n=read();
for(re int i=1;i<=n;i++){
int u=read();a[u]=read();
int m=read();
for(re int j=1;j<=m;j++){
int v=read();
add(u,v),add(v,u);
}
}
dfs(1,0);
printf("%lld\n",min(dp[1][1],dp[1][2]));
return 0;
}
【例题3】最长距离
Link
问题很简单,就是问一个有根树每一个点再树上可以延申的最长距离
思路: 对于一个点,与其相距最远的点不是在其子树中,就是在其子树之外,于是可以先dfs一遍来在子树中找一个点离每一个根到最远距离
而子树外的情况就有两种:1,在自己祖先的其他子孙中找 2,在其祖先的祖先中找
对于第一种情况比较好想,因为所有点都已经处理完一遍了,我们直接用这个点到祖先的距离+另一个儿子最长延申距离即可
而第二种情况,就直接比较他祖先的外子树最远距离+点到祖先的距离
Code
#include<bits/stdc++.h>
#define re register
#define inl inline
#define int long long
using namespace std;
int read(){
int sum=0,f=1;char c=getchar();
while(!isdigit(c)){if(c=='-') f=-1;c=getchar();}
while((isdigit(c))){sum=(sum<<3)+(sum<<1)+(c^48);c=getchar();}
return f*sum;
}
const int N=1e5+10;
int f[N],g[N];
struct node{
int to,nxt,w;
}e[N<<1];
int cnt,head[N],n;
inl void add(int u,int v,int w){
e[++cnt].to=v;
e[cnt].nxt=head[u];
e[cnt].w=w;
head[u]=cnt;
}
inl void dfs1(int u){
for(re int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
dfs1(v);
f[u]=max(f[u],e[i].w+f[v]);
}
}
inl void dfs2(int u,int fa){
int ww=0;
for(re int i=head[fa];i;i=e[i].nxt){
if(e[i].to==u) ww=e[i].w;
}
for(re int i=head[fa];i;i=e[i].nxt){
if(e[i].to==u) continue;
g[u]=max(g[u],ww+f[e[i].to]+e[i].w);
}
g[u]=max(g[u],g[fa]+ww);
for(re int i=head[u];i;i=e[i].nxt) dfs2(e[i].to,u);
}
signed main(){
while(scanf("%lld",&n)!=EOF){
memset(head,0,sizeof(head));
memset(f,0,sizeof(f));
memset(g,0,sizeof(g));
cnt=0;
for(re int i=2;i<=n;i++){
int v=read(),w=read();
add(v,i,w);
}
dfs1(1);
dfs2(1,0);
for(re int i=1;i<=n;i++){
printf("%lld\n",max(f[i],g[i]));
}
}
return 0;
}
D. 【例题4】选课方案
最基本的树上dp
分析,我们约定用
d
p
[
i
]
[
j
]
dp[i][j]
dp[i][j]来表示第
i
i
i门课与子树中
j
j
j门课(包括他自己)的最大学分
因为整个图不只一颗树,所以决定用超级源——答案就为
d
p
[
0
]
[
m
]
dp[0][m]
dp[0][m](m需要加一)
**dp过程:**对于一个点处理一条子树分支时,我们需要处理
d
p
[
i
]
[
2
−
m
]
dp[i][2-m]
dp[i][2−m]的值,但是子树中所有的点只能算一遍,所以问题就变为了一个01背包问题,需要处理的背包种类有
j
−
1
j-1
j−1种(不可以全处理完,因为必须包含根节点)
可得dp方程:
d
p
[
i
]
[
j
]
=
m
a
x
(
d
p
[
i
]
[
j
]
,
d
p
[
i
]
[
j
−
k
]
+
d
p
[
v
]
[
k
]
)
dp[i][j]=max(dp[i][j],dp[i][j-k]+dp[v][k])
dp[i][j]=max(dp[i][j],dp[i][j−k]+dp[v][k])
Code
#include<bits/stdc++.h>
#define re register
#define inl inline
#define int long long
using namespace std;
int read(){
int sum=0,f=1;char c=getchar();
while(!isdigit(c)){if(c=='-') f=-1;c=getchar();}
while((isdigit(c))){sum=(sum<<3)+(sum<<1)+(c^48);c=getchar();}
return f*sum;
}
const int N=310;
struct node{
int to,nxt;
}e[N<<1];
int cnt,head[N],n,m,a[N],dp[N][N];
bool vis[N];
inl void add(int u,int v){
e[++cnt].to=v;
e[cnt].nxt=head[u];
head[u]=cnt;
}
inl void solve(int u){
for(re int i=head[u];i;i=e[i].nxt){
solve(e[i].to);
}
for(re int i=head[u];i;i=e[i].nxt){
for(re int j=m;j;j--){
for(re int k=0;k<j;k++){
dp[u][j]=max(dp[u][j],dp[u][j-k]+dp[e[i].to][k]);
}
}
}
}
signed main(){
n=read(),m=read();m++;
for(re int i=1;i<=n;i++){
int u=read();dp[i][1]=read();
add(u,i);
}
solve(0);
printf("%d",dp[0][m]);
}
E. 1.路径求和
题目里拿到的树是一颗无根树,那么出入度为1的点就可以是叶节点
题目里要求所有有向路径的权值和,不如转化为所有边被算的次数;那么我们拿到处理一条边时,次数等于一个端点为根的子树种叶节点个数*另一个端点为根的子树包含点再反过来加一遍就可以了
Code
#include<bits/stdc++.h>
#define re register
#define inl inline
#define int long long
using namespace std;
int read(){
int sum=0,f=1;char c=getchar();
while(!isdigit(c)){if(c=='-') f=-1;c=getchar();}
while(isdigit(c)){sum=(sum<<3)+(sum<<1)+(c^48);c=getchar();}
return f*sum;
}
const int N=1e5+10;
struct node{
int to,nxt,w;
}e[N<<1];
int cnt,head[N],n,m;
int sum[N],siz[N],ans,du[N];
inl void add(int u,int v,int w){
e[++cnt].to=v;
e[cnt].nxt=head[u];
e[cnt].w=w;
head[u]=cnt;
}
inl void dfs1(int u,int fa){
siz[u]=1;
if(du[u]==1) sum[u]=1;
for(re int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==fa) continue;
dfs1(v,u);
sum[u]+=sum[v];siz[u]+=siz[v];
}
}
inl void dfs2(int u,int fa){
for(re int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==fa) continue;
dfs2(v,u);
ans=ans+e[i].w*siz[v]*(sum[1]-sum[v])+e[i].w*sum[v]*(siz[1]-siz[v]);
}
}
signed main(){
n=read(),m=read();
for(re int i=1;i<=m;i++){
int w=read(),u=read(),v=read();
add(u,v,w),add(v,u,w);du[u]++,du[v]++;
}
dfs1(1,0);
dfs2(1,0);
printf("%lld\n",ans);
return 0;
}
F. 2.树上移动
感觉比较好想(因为我都可以想出来):对于问题1,这个人一定会走回头路的,但是最后一条边只用走一遍即可,于是就可以求这个点延申下去的最大路即可
而对于问题二:两个人都要有走一遍的路,而这两条路就是整个图中的最长链,所以我们要维护一个点延申的最大值与非严格此最大值即可
Code
#include<bits/stdc++.h>
#define re register
#define inl inline
#define int long long
using namespace std;
int read(){
int sum=0,f=1;char c=getchar();
while(!isdigit(c)){if(c=='-') f=-1;c=getchar();}
while(isdigit(c)){sum=(sum<<3)+(sum<<1)+(c^48);c=getchar();}
return f*sum;
}
const int N=1e5+10;
struct node{
int to,w,nxt;
}e[N<<1];
int cnt,head[N],n,s,tot,len1[N],len2[N];
inl void add(int u,int v,int w){
e[++cnt].to=v;e[cnt].w=w;e[cnt].nxt=head[u];head[u]=cnt;
}
inl void dfs(int u,int fa){
for(re int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==fa) continue;
dfs(v,u);
if(len1[v]+e[i].w>=len1[u]){
len2[u]=len1[u],len1[u]=len1[v]+e[i].w;
}
else if(len1[v]+e[i].w>len2[u]) len2[u]=len1[v]+e[i].w;
}
}
signed main(){
n=read(),s=read();
for(re int i=1;i<n;i++){
int u=read(),v=read(),w=read();
add(u,v,w),add(v,u,w);tot+=w*2;
}
dfs(s,0);printf("%lld\n",tot-len1[s]);
int maxn=0;
for(re int i=1;i<=n;i++) maxn=max(maxn,len1[i]+len2[i]);
printf("%lld",tot-maxn);
return 0;
}
G. 3.块的计数
刚开始我的思路和E差不多,就是对于每一个最大值点进行计算即可,但是我原先那样样算,会多情况,会重算,就挂了很久
正解是根据正难则反原理既然我们不能直接算连通块数目,那么我们就可以计算以一个点为根的总联通块数与不包含最大的连通块数
那么答案就为
∑
s
[
i
]
−
g
[
i
]
\sum s[i]-g[i]
∑s[i]−g[i],而对于
s
[
i
]
s[i]
s[i]的转移
s
[
i
]
=
∏
s
[
v
]
+
1
s[i]=\prod s[v]+1
s[i]=∏s[v]+1,而对于g[i],如果不是最大值,初值赋为1,反之为0
注:一个点也算连通块
Code
#include<bits/stdc++.h>
#define re register
#define inl inline
#define int long long
using namespace std;
int read(){
int sum=0,f=1;char c=getchar();
while(!isdigit(c)){if(c=='-') f=-1;c=getchar();}
while(isdigit(c)){sum=(sum<<3)+(sum<<1)+(c^48);c=getchar();}
return sum*f;
}
const int mod=998244353;
const int N=1e5+10;
struct node{
int to,nxt;
}e[N<<1];
int f[N],g[N],flag,pos,cnt,head[N],a[N],maxn=-1e18,c[N],tot,ans,n;
inl void add(int u,int v){
e[++cnt].to=v;e[cnt].nxt=head[u];head[u]=cnt;
}
inl void dfs(int u,int fa){
f[u]=1;
g[u]=(maxn!=a[u]);
for(re int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==fa) continue;
dfs(v,u);
f[u]=f[u]*(f[v]+1)%mod;
g[u]=g[u]*(g[v]+1)%mod;
}
ans=(ans+f[u]-g[u]+mod)%mod;
}
signed main(){
n=read();
for(re int i=1;i<=n;i++){
a[i]=read();if(a[i]>maxn) maxn=a[i];
}
for(re int i=1;i<=n;i++){
if(a[i]==maxn){
c[++tot]=i;
}
}
for(re int i=1;i<n;i++){
int u=read(),v=read();
add(u,v),add(v,u);
}
dfs(1,0);
printf("%lld",ans);
}
H. 4.树的合并
其实合并后就两种情况:1,直径为原先树的最大直径 2,为相连两个点再树中延申的最长距离+1
求直径:我这里用的是了两次dfs的方法,
而两次dfs后再就找到了树的直径的两个端点后,再双端寻找最大值,就可以解决单个点在树种延申的最大距离
最后计算和可以将一棵树的点距排序,在二分处理即可
Code
#include<bits/stdc++.h>
#define re register
#define int long long
using namespace std;
int read(){
int sum=0,f=1;char c=getchar();
while(!isdigit(c)){if(c=='-') f=-1;c=getchar();}
while(isdigit(c)){sum=(sum<<3)+(sum<<1)+(c^48);c=getchar();}
return sum*f;
}
const int N=1e5+10;
struct node{
int to,nxt;
}e[N<<1],r[N<<1];
int la,ra,lb,rb;
int n,m,cnt1,cnt2,head1[N],head2[N],dis1[N],dis2[N],res=-1,x,dis11[N],dis22[N];
inline void add(int u,int v){
e[++cnt1].nxt=head1[u];e[cnt1].to=v;head1[u]=cnt1;
}
inline void add2(int u,int v){
r[++cnt2].nxt=head2[u];r[cnt2].to=v;head2[u]=cnt2;
}
inline void dfs1(int u,int fa){
for(re int i=head1[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==fa) continue;
dis1[v]=dis1[u]+1;
if(res==-1||dis1[v]>dis1[res]) res=v;
dfs1(v,u);
}
}
inline void dfs11(int u,int fa){
for(re int i=head1[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==fa) continue;
dis11[v]=dis11[u]+1;
dfs11(v,u);
}
}
inline void dfs2(int u,int fa){
for(re int i=head2[u];i;i=r[i].nxt){
int v=r[i].to;
if(v==fa) continue;
dis2[v]=dis2[u]+1;
if(res==-1||dis2[v]>dis2[res]) res=v;
dfs2(v,u);
}
}
inline void dfs22(int u,int fa){
for(re int i=head2[u];i;i=r[i].nxt){
int v=r[i].to;
if(v==fa) continue;
dis22[v]=dis22[u]+1;
dfs22(v,u);
}
}
int ans;
inline void work(){
dfs1(1,0);
la=res,res=-1;memset(dis1,0,sizeof(dis1));
dfs1(la,0);ra=res;res=-1;
dfs11(ra,0),res=-1;
for(re int i=1;i<=n;i++) dis1[i]=max(dis1[i],dis11[i]);
x=dis1[la];
dfs2(1,0);
lb=res;res=-1;memset(dis2,0,sizeof(dis2));
dfs2(lb,0);rb=res;res=-1;
dfs22(rb,0);
for(re int i=1;i<=m;i++) dis2[i]=max(dis2[i],dis22[i]);
x=max(x,dis2[lb]);
}
bool cmp(int x,int y){
return x>y;
}
int sum[N];
inline void work2(){
sort(dis1+1,dis1+1+n,cmp);
for(re int i=1;i<=n;i++) sum[i]=sum[i-1]+dis1[i];
for(re int i=1;i<=m;i++){
int l=0,r=n;
while(l<r){
int mid=(l+r+1)>>1;
if(dis1[mid]+dis2[i]+1>=x) l=mid;
else r=mid-1;
}
ans+=sum[l]+dis2[i]*l+l+x*(n-l);
}
}
signed main(){
n=read(),m=read();
for(re int i=1;i<n;i++){
int u=read(),v=read();
add(u,v),add(v,u);
}
for(re int i=1;i<m;i++){
int u=read(),v=read();add2(u,v),add2(v,u);
}
work();
work2();
printf("%lld",ans);
return 0;
}
I. 5.权值统计
很显然的树形dp,但是对于u来处理有两种情况:1.u为端点2.u在路径内
而u为端点的情况就是我们需要dp的
d
p
[
u
]
=
(
(
∑
d
p
[
v
]
)
+
1
)
∗
a
[
u
]
dp[u]=((\sum dp[v]) +1)*a[u]
dp[u]=((∑dp[v])+1)∗a[u]
如果u在路径内,答案就是
s
o
n
1
∗
s
o
n
2
∗
a
[
u
]
son1*son2*a[u]
son1∗son2∗a[u]而我们要算总和,因为
(
a
1
+
a
2
+
a
3
+
.
.
.
.
.
.
.
+
a
n
)
2
=
∑
i
n
a
i
2
+
2
∑
∑
a
i
∗
a
j
(a_1+a_2+a_3+.......+a_n)^2=\sum_i^na_i^2+2\sum\sum a_i*a_j
(a1+a2+a3+.......+an)2=∑inai2+2∑∑ai∗aj
因为我们有了
s
u
m
1
sum_1
sum1于是直接平方
−
s
u
m
i
-sum_i
−sumi再除以2,最后直接计算即可
Code
#include<bits/stdc++.h>
#define re register
#define int long long
using namespace std;
int read(){
int sum=0,f=1;char c=getchar();
while(!isdigit(c)){if(c=='-') f=-1;c=getchar();}
while(isdigit(c)){sum=(sum<<3)+(sum<<1)+(c^48);c=getchar();}
return sum*f;
}
const int N=1e5+10;
const int mod=10086;
struct node{
int to,nxt;
}e[N<<1];
int cnt,head[N],a[N],n;
int ans,dp[N];
inline void add(int u,int v){
e[++cnt].to=v;e[cnt].nxt=head[u];head[u]=cnt;
}
inline void dfs(int u,int fa){
int sum1=0,sum2=0,sum3=0;
for(re int i=head[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==fa) continue;
dfs(v,u);
sum1+=dp[v];
sum2+=dp[v]*dp[v];
}
dp[u]=(sum1+1)*a[u]%mod;
sum3=(sum1*sum1-sum2)/2%mod;
ans=(ans+dp[u]+sum3*a[u]+mod)%mod;
}
signed main(){
n=read();
for(re int i=1;i<=n;i++) a[i]=read();
for(re int i=1;i<n;i++){
int u=read(),v=read();
add(u,v),add(v,u);
}
dfs(1,0);
printf("%lld",ans);
return 0;
}