题目大意:
给一颗树,每个点有点权w。你可以选择一个不超过T的非负整数C,然后给所有点的点权+C,然后所有点的点权对P取模。
这之后你要选择若干点不相交的链,假设链上的点权和S,选了k条,那么收益是
S
k
+
1
\frac{S}{k+1}
k+1S。求最大收益。
n
≤
5000
,
P
≤
1
0
5
n\le5000,P\le10^5
n≤5000,P≤105
题解:考虑确定C之后可以二分然后dp。
发现有用的C是O(n)的,即
{
P
−
1
−
w
i
,
T
}
\{P-1-w_i,T\}
{P−1−wi,T}
然后考虑设
F
[
C
]
F[C]
F[C]表示确定C的答案,那么相当于是找这个数组中的最大值。
然后确定一个位置需要
O
(
n
log
v
)
O(n\log v)
O(nlogv)的时间,直接做是
O
(
n
2
log
v
)
O(n^2\log v)
O(n2logv)的,但是注意到可以在
O
(
n
)
O(n)
O(n)的时间内判定
F
[
C
]
F[C]
F[C]和某个数的大小关系,因此将需要计算的C随机排列后,每次只需要计算比当前最大值更大的
F
[
C
]
F[C]
F[C]的值。在一个序列每次找下一个比当前数字大的期望是log次的。因此复杂度变为期望
O
(
n
2
+
n
log
2
v
)
O(n^2+n\log^2v)
O(n2+nlog2v)
#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define Rep(i,v) rep(i,0,(int)v.size()-1)
#define lint long long
#define ull unsigned lint
#define db double
#define pb push_back
#define mp make_pair
#define fir first
#define sec second
#define gc getchar()
#define debug(x) cerr<<#x<<"="<<x
#define sp <<" "
#define ln <<endl
using namespace std;
typedef pair<int,int> pii;
typedef set<int>::iterator sit;
inline int inn()
{
int x,ch;while((ch=gc)<'0'||ch>'9');
x=ch^'0';while((ch=gc)>='0'&&ch<='9')
x=(x<<1)+(x<<3)+(ch^'0');return x;
}
const db eps=5e-8,epsd=eps/2,inf=-6e9;
const int N=5010;
struct edges{
int to,pre;
}e[N<<1];int h[N],etop,tms[N],dfc,fa[N],v[N],nc[N],val[N];db f[N][3],tmp[3],g[N];
inline int add_edge(int u,int v) { return e[++etop].to=v,e[etop].pre=h[u],h[u]=etop; }
inline int dfs(int x)
{
tms[++dfc]=x;for(int i=h[x],y;i;i=e[i].pre) if((y=e[i].to)^fa[x]) fa[y]=x,dfs(y);return 0;
}
inline int check(db t,int n)
{
for(int qwq=n;qwq;qwq--)
{
int x=tms[qwq];f[x][0]=0,f[x][1]=val[x]-t,f[x][2]=inf;
for(int i=h[x],y;i;i=e[i].pre) if((y=e[i].to)^fa[x])
tmp[0]=f[x][0]+g[y],
tmp[1]=max(f[x][1]+g[y],f[x][0]+f[y][1]+val[x]),
tmp[2]=max(f[x][2]+g[y],f[x][1]+f[y][1]+t),
f[x][0]=tmp[0],f[x][1]=tmp[1],f[x][2]=tmp[2];
g[x]=max(max(f[x][0],f[x][1]),f[x][2]);
}
return g[1]>t;
}
int main()
{
int n=inn(),p=inn(),x,y;rep(i,1,n) v[i]=inn(),nc[i]=p-1-v[i];
rep(i,1,n-1) x=inn(),y=inn(),add_edge(x,y),add_edge(y,x);
int m=0,t=inn();rep(i,1,n) if(nc[i]<=t) nc[++m]=nc[i];
nc[++m]=t,sort(nc+1,nc+m+1),m=int(unique(nc+1,nc+m+1)-nc-1);
random_shuffle(nc+1,nc+m+1);db ans=0;dfs(1);
rep(i,1,m)
{
rep(j,1,n) val[j]=v[j]+nc[i],(val[j]>=p?val[j]-=p:0);
if(!check(ans,n)) continue;db L=ans,R=(db)n*p/2;
while(L+eps<=R)
{
db mid=(L+R)/2;
if(check(mid,n)) ans=mid,L=mid+epsd;
else R=mid-epsd;
}
}
return !printf("%.6lf\n",(double)ans);
}