场上瞎写暴力乱搞结果拿了一血(汗颜...不过还是得出题人谢罪)
考虑按点权从小到大依次插入每个点构建新树,发现每个点u能向任意一个a[v] < a[u]的v点连边,产生贡献为a[v] * dis(u, v)
然后我直接暴力按点权由小到大从每个点更新其他点,剪了点枝就过了...(想了一下应该随便构造条链卡下就能卡掉)
正解:
点分+李超树
李超树是维护线段的线段树,线段的定义是一些y=kx + b的直线限制其横坐标取值在一个区间,李超树可以实现nlog^2地维护某个横坐标x处y最高的线段y值(如果每个线段的横坐标定义域都是整个定义域,则复杂度为nlogn)。具体实现网上有很多。
这道题就可以点分治每次找到重心x,然后考察每个点u经过重心的所有路径(u - x - v)产生的最小贡献(u, v)然后更新u的答案。具体的就是贡献转化为(dis(u, x) + dis(x, v)) * a[v] = a[v] * dis(u, x) + a[v] * dis(x, v),x已经确定,显然可以把a[v]视为斜率,a[v] * dis(x, v)视为截距,用李超树做。
为了保证复杂度,先求出当前分治层中所有点到中心距离,然后按关键字a从小到大排序,边询问边插入。
(可能当前得到的不是简单路径,但显然因为是取min之后分治下去会被简单路径的答案覆盖)
代码:
//#define LOCAL
#include<bits/stdc++.h>
#define pii pair<int,int>
#define fi first
#define sc second
#define pb push_back
#define ll long long
#define trav(v,x) for(auto v:x)
#define all(x) (x).begin(), (x).end()
#define VI vector<int>
#define VLL vector<ll>
#define pll pair<ll, ll>
#define double long double
//#define int long long
using namespace std;
const int N = 1e6 + 100;
const int inf = 1e9;
//const ll inf = 1e18
const ll mod = 998244353;//1e9 + 7
#ifdef LOCAL
void debug_out(){cerr << endl;}
template<typename Head, typename... Tail>
void debug_out(Head H, Tail... T)
{
cerr << " " << to_string(H);
debug_out(T...);
}
#define debug(...) cerr << "[" << #__VA_ARGS__ << "]:", debug_out(__VA_ARGS__)
#else
#define debug(...) 42
#endif
struct Line {
ll k, b;
ll operator () (ll x)
{
return k * x + b;
}
Line (){};
Line (ll k, ll b):k(k), b(b){};
};
const int m = 2e5 + 100;
struct lc_tree {
#define mid ((l + r) >> 1)
#define ls (k << 1)
#define rs (k << 1 | 1)
Line seg[m + m + 100];
bool hav[m + m + 100];
VI bin;
void ins(Line x, int L, int R, int l = 1, int r = m, int k = 1)
{
if(l > R || r < L)
return;
if(L <= l && r <= R)
{
if(!hav[k])
return (void)(hav[k] = 1, seg[k] = x, bin.pb(k));
if(seg[k](mid) > x(mid))
swap(seg[k], x);
if(x(l) < seg[k](l))
ins(x, L, R, l, mid, ls);
if(x(r) < seg[k](r))
ins(x, L, R, mid + 1, r, rs);
return;
}
ins(x, L, R, l, mid, ls);
ins(x, L, R, mid + 1, rs);
return;
}
ll ask(ll x, int l = 1, int r = m, int k = 1)
{
if(!hav[k])
return 1e18;
if(l == r)
return seg[k](x);
if(x <= mid)
return min(seg[k](x), ask(x, l, mid, ls));
else
return min(seg[k](x), ask(x, mid + 1, r, rs));
}
void init()
{
memset(hav, 0, sizeof hav);
bin.clear();
}
void clear()
{
trav(v, bin)
hav[v] = 0;
bin.clear();
}
}lc;
int n;
ll a[N];
VI adj[N];
int rt, tot, sz[N], mn_sz;
bool vis[N];
void dfs(int x, int ff)
{
int mx = 0;
sz[x] = 1;
trav(v, adj[x])
{
if(vis[v] || v == ff)
continue;
dfs(v, x);
mx = max(mx, sz[v]);
sz[x] += sz[v];
}
mx = max(mx, tot - sz[x]);
if(mx < mn_sz)
rt = x, mn_sz = mx;
}
void find_rt(int x)
{
rt = x;
mn_sz = 1e9;
dfs(x, 0);
}
VI buk;
ll dis[N], ans[N];
void work(int x, int dd, int ff)
{
buk.pb(x);
dis[x] = dd;
trav(v, adj[x])
{
if(vis[v] || v == ff)
continue;
work(v, dd + 1, x);
}
}
void sol(int x)
{
lc.clear();
buk.clear();
vis[x] = 1;
work(x, 0, 0);
sort(all(buk), [](int x, int y)
{
return a[x] < a[y];
});
for(int lp = 0, rp; lp < buk.size(); lp = rp + 1)
{
rp = lp;
while(a[buk[lp]] == a[buk[rp]] && rp < buk.size()) ++rp;
--rp;
for(int i = lp; i <= rp; i++)
{
int v = buk[i];
ans[v] = min(ans[v], lc.ask(dis[v] + 1));
}
for(int i = lp; i <= rp; i++)
{
int v = buk[i];
lc.ins(Line(a[v], dis[v] * a[v]), 1, m);
}
}
trav(v, adj[x])
{
if(vis[v])
continue;
tot = sz[v], find_rt(v);
sol(rt);
}
}
void mian()
{
cin >> n;
for(int i = 1; i <= n; i++)
{
cin >> a[i];
}
for(int i = 1; i < n; i++)
{
int x, y;
cin >> x >> y;
adj[x].pb(y);
adj[y].pb(x);
}
lc.init();
memset(ans, 63, sizeof ans);
tot = n;
find_rt(n);
sol(rt);
ll res = 0;
for(int i = 1; i <= n; i++)
{
if(ans[i] >= 1e18)
continue;
//cerr << ans[i] << '\n';
res += ans[i] + a[i];
}
cout << res << '\n';
}
signed main()
{
ios::sync_with_stdio(0);
cin.tie(0);
mian();
}