主要是在写换根dp时写得脑阔疼
一.求树上一个点到其他节点的最大距离
转换一下就等价于求以哪一个节点为根时,这棵树的所有节点深度之和最大
首先以节点1为根跑一遍dfs1,目的是更新f[N]数组,N为总节点数
其中f[x]表示以1为根时,节点x的子树的所有节点深度之和
设dp[x]为以x为根的树的所有节点深度之和
由于我们一开始是以1为根,那显然dp[1] = f[1]
我们进行第二遍dfs2,目的是更新dp数组
假设当前节点为x,且dp[x]已经处理完了
对于x的一个孩子v,经过推导有:
dp[v] = dp[x] - (f[v] + size[v]) + (size[x] - size[v]) + f[v];
用语言来讲,就是
以v为根的所有节点深度之和,等于原来以1为根时v的子树深度之和,加上以v父亲为根时,除v所在子树以外的子树的深度之和(根不一样,显然需要更新一下)
看上去是有点绕,建议画个图看看能加深理解
Code:
#include<iostream>
#include<cstdio>
using namespace std;
const int N = 1000010 , M = 2000010;
int h[N] , to[M] , nxt[M] , tot = 1;
long long n , f[N] , dep[N] , size[N];//以x为根的子树深度之和
/*
size[x]在第一遍dfs中表示以1为根时,x的子树上的总节点数
第二遍dfs每次转移完后,size[x]表示以x为根的树上的总节点数(其实是个定值n,懒得改了)
*/
long long ans , dp[N] , k = 1;
inline void add(int a , int b)
{
to[++tot] = b ; nxt[tot] = h[a] ; h[a] = tot;
}
void dfs1(int x , int lst)
{
size[x] = 1;
for(int i = h[x] , v ; v = to[i] , i ; i = nxt[i])
{
if(v == lst) continue;
dep[v] = dep[x] + 1;
dfs1(v , x);
f[x] += f[v] + size[v];
size[x] += size[v];
}
// f[x]++;//自身
}
void dfs2(int x , int lst)
{
if(x == 1)
ans = f[x] , dp[x] = f[x];
else
if(ans < dp[x])
{
ans = dp[x] ; k = x;
}
for(int i = h[x] , v ; v = to[i] , i ; i = nxt[i])
{
if(v == lst) continue;
dp[v] = dp[x] - (f[v] + size[v]) + (size[x] - size[v]) + f[v];
size[v] = size[x];
dfs2(v , x);
}
}
int main()
{
cin >> n;
for(int i = 1 , u , v ; i < n ; i++)
{
cin >> u >> v;
add(u ,v) ; add(v , u);
}
dep[1] = 1;
dfs1(1 , 1);
dfs2(1 , 1);
cout << k;
return 0;
}
至于更新的方法,这里就不再赘述了 (其实就是懒)
二.不会总结…
咳咳,其实跟第一题差不多,就是多了个点权
那我们f数组跟dp数组也更新一下就好了:
现在设f[x]为以1为根,x的子树上所有节点到x的总不方便值
dp[x]为以x为根时所有节点到根的总不方便值
其他照旧
更新dp数组的方法也改一下,于是:
dp[v] = dp[x] - (f[v] + sum[v] * w[i]) + (sum[x] - sum[v]) * w[i] + f[v];
完整Code:
#include<iostream>
#include<cstdio>
using namespace std;
typedef long long ll;
const int N = 1000010 , M = 2000010 ;
const ll inf = 1e9 + 1;
ll to[M] , h[N] , nxt[N] , w[N] , tot = 1;
ll n , C[N];
ll sum[N] , size[N] , f[N] , dp[N] , ans = inf;
inline void add(int a , int b , int c)
{
to[++tot] = b ; w[tot] = c ; nxt[tot] = h[a] ; h[a] = tot;
}
void dfs1(int x , int lst)
{
sum[x] = C[x];
for(int i = h[x] , v ; v = to[i] , i ; i = nxt[i])
{
if(v == lst) continue;
dfs1(v , x);
sum[x] += sum[v];
f[x] += f[v] + sum[v] * w[i];
}
}
void dfs2(int x , int lst)
{
if(x == 1)
{
ans = f[x] ; dp[x] = f[x];
}
for(int i = h[x] , v ; v = to[i] , i ; i = nxt[i])
{
if(v == lst) continue;
dp[v] = dp[x] - (f[v] + sum[v] * w[i]) + (sum[x] - sum[v]) * w[i] + f[v];
sum[v] = sum[x];
ans = min(ans , dp[v]);
dfs2(v , x);
}
}
int main()
{
scanf("%lld" , &n);
for(int i = 1 ; i <= n ; i++)
scanf("%lld" , &C[i]);
for(ll i = 1 , u , v , c ; i < n ; i++)
{
scanf("%lld%lld%lld" , &u , &v , &c);
add(u , v , c) ; add(v , u , c);
}
dfs1(1 , 1);
dfs2(1 , 1);
cout << ans;
return 0;
}
三.对于每个节点求出距离它不超过 k 的所有节点权值和
还是蛮好想的
设f[x][j]为以1为根时,在x字数上与节点x相隔j条边的总点权,特别地令f[x][0]为它自身点权
dp[x][j]则表示以x为根时,与x相隔j条边的总点权,特别地也令dp[x][0]为自身点权
状态设好了,再想想转移方程
除了对j=0情况的特判外
显然有:
dp[v][j] = dp[x][j - 1] + f[v][j]
但是,这是错的
原因是 dp[x][j - 1] 中包含了 f[v][j - 2] 的点权(Q:为啥是j - 2?A:建议画图)
所以最终我们还要减掉
最终版:
dp[v][j] = dp[x][j - 1] + f[v][j] - f[v][j - 2]
于是j == 0 跟 j == 1要特地拎出来特判
Code:
#include<iostream>
#include<cstdio>
using namespace std;
const int N = 100010 , K = 21 , M = 200010;
int h[N] , to[M] , nxt[M] , tot = 1;
int n , k ;
int f[N][K] , w[N] , dp[N][K] , ans;
inline void add(int a , int b)
{
to[++tot] = b; nxt[tot] = h[a] ; h[a] = tot;
}
void dfs1(int x , int lst)
{
f[x][0] = w[x];
for(int i = h[x] , v ; v = to[i] , i ; i = nxt[i])
{
if(v == lst) continue;
dfs1(v , x);
for(int j = 1 ; j <= k ; j++)
f[x][j] += f[v][j - 1];
}
}
void dfs2(int x , int lst)
{
if(x == 1)
{
for(int j = 0 ; j <= k ; j++)
dp[x][j] = f[x][j];
}
for(int i = h[x] , v ; v = to[i] , i ; i = nxt[i])
{
if(v == lst) continue;
dp[v][0] = w[v];
dp[v][1] = dp[x][0] + f[v][1];
for(int j = 2 ; j <= k ; j++)
dp[v][j] = dp[x][j - 1] - f[v][j - 2] + f[v][j];
dfs2(v , x);
}
}
int main()
{
scanf("%d%d" , &n , &k);
for(int i = 1 , u , v ; i < n ; i++)
{
scanf("%d%d" , &u , &v);
add(u , v) ; add(v , u);
}
for(int i = 1 ; i <= n ; i++)
cin >> w[i];
dfs1(1 , 1);
dfs2(1 , 1);
for(int i = 1 ; i <= n ; i++)
{
int res = 0;
for(int j = 0 ; j <= k ; j++)
res += dp[i][j];
printf("%d\n" , res);
}
return 0;
}
有空再更新吧