题目
首先可以比较容易的写出dp的转移,化简后发现是含有i和j的乘积形式,所以可以想到斜率优化。化简后可以发现决策点就是(2*dis[j], dp[j]-dis[j]*dis[j]),斜率是-(f[i]+dis[i])。
但是决策的范围是在i的子树,并不是传统的[1,i-1]。所以使用dfs序将子树问题转化为区间问题。注意斜率优化的决策点的横坐标需要单调递增,所以每个线段树节点需要维护当前区间的节点,并按横坐标排序,用这些决策点维护一个斜率单调递减的序列。但是横坐标还可能相同,根据贪心,如果横坐标相同那就去dp值大的那个决策。可是dp值一开始并不能知道,所以需要在查询时更新维护。这样就维护出了区间的一个单调序列,对于一次查询,用线段树查询就分成了一些区间,这些区间取最大值即可。需要注意的是,这里的斜率并不单调,不可以使用单调队列来优化,这里需要采用二分,找到斜率第一个比它小的决策。
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
typedef long long ll;
const int maxn = 1e5+5;
struct node {
int id;
ll val;
node(int a,ll b) {
id = a;
val = b;
}
};
vector<node> g[maxn];
ll v[maxn],f[maxn],dfn[maxn],num[maxn],siz[maxn],dis[maxn];
int cnt;
ll dp[maxn];
ll mod = 1e18;
double A(int i,int j) {
return (dp[i]-dis[i]*dis[i]) - (dp[j]-dis[j]*dis[j]);
}
double B(int i,int j) {
return 2*(dis[i] - dis[j]);
}
void dfs1(int x,ll d) {
cnt ++;
dfn[x] = cnt;
num[cnt] = x;
dis[x] = d;
siz[x] = 1;
for (int i = 0; i < g[x].size(); i++) {
int t = g[x][i].id;
dfs1(t,d+g[x][i].val);
siz[x] += siz[t];
}
}
struct nodex{
int l,r;
int flag;
vector<int> q;
vector<int> p;
}a[maxn*4];
void update(int x) {
int id1 = 0,id2 = 0;
int l = 2*x,r = 2*x+1;
while ( id1 < a[l].p.size() || id2 < a[r].p.size() ) {
int now;
if ( id1 == a[l].p.size() ) {
now = a[r].p[id2];
id2 ++;
}else if ( id2 == a[r].p.size() ) {
now = a[l].p[id1];
id1 ++;
}else {
if ( dis[a[l].p[id1]] < dis[a[r].p[id2]] ) {
now = a[l].p[id1];
id1 ++;
}else {
now = a[r].p[id2];
id2 ++;
}
}
a[x].p.push_back(now);
}
}
void build(int x,int l,int r) {
a[x].l = l,a[x].r = r;
a[x].flag = 0;
a[x].q.clear();
a[x].p.clear();
if ( l == r ) {
a[x].q.push_back(num[l]);
a[x].p.push_back(num[l]);
return;
}
int m = (l+r)>>1;
build(2*x,l,m),build(2*x+1,m+1,r);
update(x);
}
ll query(int x,int l,int r,int p) {
if ( a[x].l == l && a[x].r == r ) {
if ( a[x].flag == 0 ) {
for (int i = 0; i < a[x].p.size(); i++) {
int now = a[x].p[i];
int flagx = 0;
while ( a[x].q.size() > 1 ) {
int t1 = a[x].q.back();
a[x].q.pop_back();
int t2 = a[x].q.back();
if ( dis[now] == dis[t1] ) {
if ( dp[now] < dp[t1] ) {
flagx = 1;
a[x].q.push_back(t1);
break;
}
}else if ( A(now,t1) / B(now,t1) < A(t1,t2) / B(t1,t2) ) {
a[x].q.push_back(t1);
break;
}
}
if ( flagx == 0 ) a[x].q.push_back(now);
}
a[x].flag = 1;
}
int lx = 0,rx = a[x].q.size() - 2;
int res = a[x].q.size()-1;
while (lx <= rx) {
int mid = (lx+rx) >> 1;
int midx = mid + 1;
if ( A(a[x].q[midx],a[x].q[mid])/B(a[x].q[midx],a[x].q[mid]) < -(dis[p]+f[p]) ) {
res = mid;
rx = mid - 1;
}else lx = mid + 1;
}
res = a[x].q[res];
return max((ll)0,dp[res] + v[p] - (f[p]-(dis[res]-dis[p]))*(f[p]-(dis[res]-dis[p])));
}
int m = (a[x].l+a[x].r) >> 1;
if ( r <= m ) return query(2*x,l,r,p);
else if ( l > m ) return query(2*x+1,l,r,p);
else return max(query(2*x,l,m,p),query(2*x+1,m+1,r,p));
}
void dfs2(int x) {
if ( siz[x] == 1 ) {
dp[x] = 0;
return;
}
for (int i = 0; i < g[x].size(); i++) {
int t = g[x][i].id;
dfs2(t);
}
dp[x] = query(1,dfn[x]+1,dfn[x]+siz[x]-1,x);
dp[x] %= mod;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0),cout.tie(0);
int t;
cin >> t;
while ( t-- ) {
int n;
cin >> n;
cnt = 0;
for (int i = 1; i <= n; i++) {
g[i].clear();
dp[i] = 0;
}
for (int i = 1; i <= n; i++) {
int fa;
ll val;
cin >> fa >> val >> v[i] >> f[i];
if ( fa == 0 ) continue;
else g[fa].push_back(node(i,val));
}
dfs1(1,0);
build(1,1,n);
dfs2(1);
ll ans = 0;
for (int i = 1; i <= n; i++) {
ans += dp[i];
ans %= mod;
}
cout << ans << '\n';
}
return 0;
}