如果值域在 1000 以内，感觉这题还是能写出来的。
设 dp[i][j] 表示：从根 1 到结点 i 的路径值为 j 时，以 i 为根的子树能取得的最大贡献值。
可以发现 j 这一维可以离散化（只需考虑 0 和所有出现过的 b 值），离散化之后最优值不会改变，然后做一遍树形 DP 即可。
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#include <bits/stdc++.h>
#define LL long long
#define LD long double
#define ull unsigned long long
#define fi first
#define se second
#define mk make_pair
#define PLL pair<LL, LL>
#define PLI pair<LL, int>
#define PII pair<int, int>
#define SZ(x) ((int)x.size())
#define ALL(x) (x).begin(), (x).end()
#define fio ios::sync_with_stdio(false); cin.tie(0) ;
using namespace std;

const int N = 1000 + 7;
const int inf = 0x3f3f3f3f;
const LL INF = 0x3f3f3f3f3f3f3f3f;
const int mod = (int)1e9 + 7;
const double eps = 1e-8;
const double PI = acos(-1);

template<class T, class S> inline void add(T &a, S b) {a += b; if(a >= mod) a -= mod;}
template<class T, class S> inline void sub(T &a, S b) {a -= b; if(a < 0) a += mod;}
template<class T, class S> inline bool chkmax(T &a, S b) {return a < b ? a = b, true : false;}
template<class T, class S> inline bool chkmin(T &a, S b) {return a > b ? a = b, true : false;}

//mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

int n, m;
int hs[N], hscnt;   // 1-indexed, sorted, deduplicated candidate path values: 0 and every query b
int ans[N];         // ans[id]: weight assigned to edge id (difference of endpoint path values)
LL dp[N][N];        // dp[u][i]: best total from u's subtree when the path value at u is hs[i]
PLI sufMax[N][N];   // sufMax[u][i]: (max dp[u][k], argmax k) over k >= i; ties keep the larger k
vector<PII> G[N];   // adjacency list of (edge id, neighbour)
vector<int> V[N];   // V[u]: b values of the queries attached to node u, sorted descending

// Tree DP rooted at 1. Fills dp[u][*] and sufMax[u][*] for every node in u's
// subtree and adds u's suffix maxima into its parent's row.
// fa == 0 marks the root (node ids are 1..n).
void dfs(int u, int fa) {
    for (int i = 1; i <= hscnt; i++) {
        dp[u][i] = 0;
    }
    // Each child adds its best achievable value into dp[u] when it returns.
    for (auto &e : G[u]) {
        if (e.se == fa) continue;
        dfs(e.se, u);
    }
    // Own queries: walking i downward, pt counts the b values in V[u] that are
    // >= hs[i]; each of them contributes hs[i].
    for (int i = hscnt, pt = 0; i >= 1; i--) {
        while (pt < SZ(V[u]) && V[u][pt] >= hs[i]) pt++;
        dp[u][i] += 1LL * hs[i] * pt;
    }
    // A child's path value may not be below its parent's (edge weights are
    // differences hs[child] - hs[parent] >= 0), so the parent needs suffix maxima.
    for (int i = 1; i <= hscnt; i++) {
        sufMax[u][i] = mk(dp[u][i], i);
    }
    for (int i = hscnt - 1; i >= 1; i--) {
        chkmax(sufMax[u][i], sufMax[u][i + 1]);
    }
    // The root has no parent: do not accumulate into the never-reset dummy row
    // dp[0], which would otherwise grow across test cases.
    if (fa != 0) {
        for (int i = 1; i <= hscnt; i++) {
            dp[fa][i] += sufMax[u][i].fi;
        }
    }
}

// Reconstructs edge weights top-down: u has path value hs[p]; each child v
// takes the argmax recorded in sufMax[v][p], and the edge gets the difference.
void getPath(int u, int p, int fa) {
    for (auto &e : G[u]) {
        int id = e.fi, v = e.se;
        if (v == fa) continue;
        int q = sufMax[v][p].se;
        ans[id] = hs[q] - hs[p];
        getPath(v, q, u);
    }
}

// Clears per-test-case state for nodes 1..n.
void init() {
    hscnt = 0;
    for (int i = 1; i <= n; i++) {
        V[i].clear();
        G[i].clear();
    }
}

int main() {
    int T;
    scanf("%d", &T);
    while (T--) {
        scanf("%d%d", &n, &m);
        init();
        for (int i = 1; i < n; i++) {
            int u, v;
            scanf("%d%d", &u, &v);
            G[u].push_back(mk(i, v));
            G[v].push_back(mk(i, u));
        }
        hs[++hscnt] = 0;
        for (int i = 1; i <= m; i++) {
            int c, b;
            scanf("%d%d", &c, &b);
            V[c].push_back(b);
            hs[++hscnt] = b;
        }
        sort(hs + 1, hs + 1 + hscnt);
        hscnt = unique(hs + 1, hs + 1 + hscnt) - hs - 1;
        for (int i = 1; i <= n; i++) {
            sort(V[i].rbegin(), V[i].rend());
        }
        // The root's path value is 0. 0 is always in hs, so locate its index
        // rather than assuming hs[1] == 0 (which only holds when every b >= 0).
        int root = lower_bound(hs + 1, hs + 1 + hscnt, 0) - hs;
        dfs(1, 0);
        printf("%lld\n", dp[1][root]);
        getPath(1, root, 0);
        // NOTE(review): when n == 1 no edge line (not even a newline) is printed;
        // confirm against the expected output format.
        for (int i = 1; i < n; i++) {
            printf("%d%c", ans[i], " \n"[i == n - 1]);
        }
    }
    return 0;
}