题意:给出一棵数,树上有些点有商店,要你建一个商店使你可以得到最多的顾客,顾客选择去离他最近的商店(距离相同去编号小的)。
首先预处理一下所有点离他最近的商店,记录下来编号和距离,用一个pair就好了。first记录距离,second记录编号,排序也刚好符合题意。。预处理感觉和spfa差不多那样弄一下(其他方法我不会。。)。。
然后就是点分治了,找到每一次重心后,求出到这个点的距离,用dis[u]表示,然后假定必须经过重心,所以两点的距离就是dis[u]+dis[v],如果要别人到自己商店来,就要满足dis[u]+dis[v]<=f[v].first(f[v]为v点到初始离他距离最近的商店),所以dis[u]<=f[v],first-dis[v],这个过程,我们可以把这个树上的f[v].first-dis[v]加入一个数组中,然后排个序,然后就可以枚举我们要建的商店的位置然后二分直接求答案了(这里也要注意编号最小,在这WA了好多发),有可能在同一条支路上,后面按照一般点分治的方法处理好就是了。。
#pragma comment(linker, "/STACK:102400000,102400000")
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<queue>
using namespace std;
const int MAXN=100010;
const int INF=0x3f3f3f3f;
typedef pair<int,int> Point;
struct EDGE
{
int v,next;
int dist;
}edge[MAXN<<1];
int head[MAXN],size;
void init()
{
memset(head,-1,sizeof(head));
size=0;
}
void add_edge(int u,int v,int c)
{
edge[size].v=v;
edge[size].dist=c;
edge[size].next=head[u];
head[u]=size++;
}
Point f[MAXN];
queue<int> q;
bool vis[MAXN];
void bfs()
{
while(!q.empty())
{
int u=q.front();
q.pop();
vis[u]=0;
for(int i=head[u];i!=-1;i=edge[i].next)
{
int v=edge[i].v;
Point temp;
temp.first=f[u].first+edge[i].dist;
temp.second=f[u].second;
if(temp<f[v])
{
f[v]=temp;
if(!vis[v])
{
vis[v]=1;
q.push(v);
}
}
}
}
}
int siz[MAXN],num[MAXN],tot_size,root;
void get_root(int u,int fa)
{
siz[u]=1;
num[u]=0;
for(int i=head[u];i!=-1;i=edge[i].next)
{
int v=edge[i].v;
if(vis[v]||v==fa)
continue;
get_root(v,u);
siz[u]+=siz[v];
num[u]=max(num[u],siz[v]);
}
num[u]=max(num[u],tot_size-num[u]);
if(num[root]>num[u])
root=u;
}
void get_size(int u,int fa)
{
siz[u]=1;
for(int i=head[u];i!=-1;i=edge[i].next)
{
int v=edge[i].v;
if(v==fa||vis[v])
continue;
get_size(v,u);
siz[u]+=siz[v];
}
}
int dep[MAXN],cnt,b[MAXN],ans[MAXN];
Point a[MAXN];
bool ism[MAXN];
void get_dep(int u,int fa)
{
for(int i=head[u];i!=-1;i=edge[i].next)
{
int v=edge[i].v;
if(v==fa||vis[v])
continue;
dep[v]=dep[u]+edge[i].dist;
if(!ism[v])
{
b[cnt]=v;
a[cnt].first=f[v].first-dep[v];
a[cnt++].second=f[v].second;
}
get_dep(v,u);
}
}
int findx(Point val)
{
int l=0,r=cnt-1;
while(l<=r)
{
int mid=(l+r)>>1;
if(a[mid]<=val)
l=mid+1;
else
r=mid-1;
}
return l;
}
void get_num(int u,int val,int mark)
{
dep[u]=val;
cnt=0;
if(!ism[u])
{
b[cnt]=u;
a[cnt].first=f[u].first-dep[u];
a[cnt++].second=f[u].second;
}
get_dep(u,-1);
sort(a,a+cnt);
for(int i=0;i<cnt;i++)
{
if(ism[b[i]])
continue;
if(!mark)
ans[b[i]]+=cnt-findx(make_pair(dep[b[i]],b[i]));
else
ans[b[i]]-=cnt-findx(make_pair(dep[b[i]],b[i]));
}
}
void dfs(int u)
{
vis[u]=1;
get_num(u,0,0);
for(int i=head[u];i!=-1;i=edge[i].next)
{
int v=edge[i].v;
if(vis[v])
continue;
get_num(v,edge[i].dist,1);
}
for(int i=head[u];i!=-1;i=edge[i].next)
{
int v=edge[i].v;
if(vis[v])
continue;
root=0;
get_size(v,-1);
tot_size=siz[v];
get_root(v,-1);
dfs(root);
}
}
int main()
{
int n,i;
while(scanf("%d",&n)==1)
{
init();
int u,v,c;
for(i=1;i<n;i++)
{
scanf("%d%d%d",&u,&v,&c);
add_edge(u,v,c);
add_edge(v,u,c);
}
memset(ism,0,sizeof(ism));
for(i=1;i<=n;i++)
{
scanf("%d",&c);
if(c)
ism[i]=1;
}
for(i=1;i<=n;i++)
{
if(ism[i])
{
f[i].first=0;
f[i].second=i;
q.push(i);
vis[i]=1;
}
else
{
f[i].first=INF;
f[i].second=0;
vis[i]=0;
}
}
bfs();
memset(vis,0,sizeof(vis));
memset(ans,0,sizeof(ans));
memset(num,0,sizeof(num));
root=0;
num[root]=INF;
tot_size=n;
get_root(1,-1);
dfs(root);
int aans=0;
for(i=1;i<=n;i++)
aans=max(aans,ans[i]);
printf("%d\n",aans);
}
return 0;
}