题目分析:终于写出来了。。。1A。。
首先我们可以预处理出每个不是mart的点x所属与的mart的编号idx[x]以及到该mart的距离d[x],用一个结构体表示。
该部分最短路算法即可。
我们定义结构体变量x小于y即d[x]<d[y] ||d[x] == d[y]&&idx[x]<idx[y]。
然后便是树分治的舞台了~
首先找到树的重心,然后以该重心dfs一次得到所有点x到重心的距离dis[x]以及这棵树上的节点总数cnt,将d[x]-dis[x]以及idx[x]组成结构体丢进数组S,dfs以后对数组S排序,然后二分小于等于dis[x]的数目tmp,则该子树上能被x占领的数量为cnt-tmp。
为什么是cnt-tmp呢?因为如果x要占领y,则dis[x]+dis[y]<d[y]才行,所以有dis[x]<d[y]-dis[y],因为dis[x]>=d[y]-dis[y]的数目即tmp,节点总数即cnt,所以子树内的贡献为cnt-tmp。
并且,由于可能有的贡献是来自同一棵子树的,我们需要将其去重。
接下来就是递归求解了。
具体部分可以见代码,代码还是比较清晰的。
代码如下:
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
using namespace std ;
typedef long long LL ;
#pragma comment(linker, "/STACK:16777216")
#define Log( i , a , b ) for ( int i = a ; ( 1 << i ) <= b ; ++ i )
#define rep( i , a , b ) for ( int i = a ; i < b ; ++ i )
#define For( i , a , b ) for ( int i = a ; i <= b ; ++ i )
#define rev( i , a , b ) for ( int i = ( a ) ; i >= ( b ) ; -- i )
#define travel( e , H , u ) for ( Edge* e = H[u] ; e ; e = e -> next )
#define clr( a , x ) memset ( a , x , sizeof a )
#define cpy( a , x ) memcpy ( a , x , sizeof a )
const int MAXN = 100005 ;
const int MAXE = 200005 ;
const int INF = 0x3f3f3f3f ;
struct Edge {
int v , c ;
Edge* next ;
} E[MAXE] , *H[MAXN] , *edge ;
struct Node {
int dis ;
int idx ;
Node () {}
Node ( int dis , int idx ) : dis ( dis ) , idx ( idx ) {}
bool operator < ( const Node& a ) const {
if ( dis != a.dis ) return dis < a.dis ;
return idx < a.idx ;
}
bool operator > ( const Node& a ) const {
return a < *this ;
}
bool operator <= ( const Node& a ) const {
return !( a < *this ) ;
}
bool operator >= ( const Node& a ) const {
return !( *this < a ) ;
}
Node operator + ( const int& val ) const {
return Node ( dis + val , idx ) ;
}
Node operator - ( const int& val ) const {
return Node ( dis - val , idx ) ;
}
} node[MAXN] , S[MAXN] ;
int Q[MAXN] , head , tail ;
int vis[MAXN] ;
int Time ;
int siz[MAXN] ;
int num[MAXN] ;
int mart[MAXN] ;
int ans[MAXN] ;
int dis[MAXN] ;
int size ;
int root ;
int cnt ;
int n ;
void clear () {
edge = E ;
num[0] = n ;
clr ( H , 0 ) ;
head = tail = 0 ;
clr ( ans , 0 ) ;
}
void addedge ( int u , int v , int c ) {
edge->v = v ;
edge->c = c ;
edge->next = H[u] ;
H[u] = edge ++ ;
}
void spfa () {
while ( head != tail ) {
int u = Q[head ++] ;
if ( head == MAXN ) head = 0 ;
vis[u] = Time - 1 ;
travel ( e , H , u ) {
int v = e->v ;
Node tmp = node[u] + e->c ;
if ( node[v] > tmp ) {
node[v] = tmp ;
if ( vis[v] != Time ) {
vis[v] = Time ;
Q[tail ++] = v ;
if ( tail == MAXN ) tail = 0 ;
}
}
}
}
}
void get_siz ( int u , int fa = 0 ) {
siz[u] = 1 ;
travel ( e , H , u ) {
int v = e->v ;
if ( v != fa && vis[v] != Time ) {
get_siz ( v , u ) ;
siz[u] += siz[v] ;
}
}
}
void get_root ( int u , int fa = 0 ) {
num[u] = 0 ;
travel ( e , H , u ) {
int v = e->v ;
if ( v != fa && vis[v] != Time ) {
get_root ( v , u ) ;
num[u] = max ( num[u] , siz[v] ) ;
}
}
num[u] = max ( num[u] , size - siz[u] ) ;
if ( num[u] < num[root] ) root = u ;
}
void get_dis ( int u , int val , int fa = 0 ) {
dis[u] = val ;
S[++ cnt] = node[u] - val ;//丢进S数组
travel ( e , H , u ) {
int v = e->v ;
if ( v != fa && vis[v] != Time ) {
get_dis ( v , dis[u] + e->c , u ) ;
}
}
}
int search ( const Node& x ) {//二分查找,如果没有找到目标,则返回小于x的最大的数的下标
int l = 0 , r = cnt ;
while ( l < r ) {
int m = ( l + r + 1 ) >> 1 ;
if ( S[m] <= x ) l = m ;
else r = m - 1 ;
}
return r ;
}
void get_ans ( int u , int sign , int fa = 0 ) {
if ( !mart[u] ) {//为不是mart的点计算能被其占领的点的个数
int tmp = search ( Node ( dis[u] , u ) ) ;
ans[u] += ( cnt - tmp ) * sign ;
}
travel ( e , H , u ) {
int v = e->v ;
if ( v != fa && vis[v] != Time ) {
get_ans ( v , sign , u ) ;
}
}
}
void deal ( int u , int val , int sign ) {
cnt = 0 ;
get_dis ( u , val ) ;//得到dist
sort ( S + 1 , S + cnt + 1 ) ;
get_ans ( u , sign ) ;//计算贡献
}
void divide ( int u ) {//分治
get_siz ( u ) ;//得到子树规模
size = siz[u] ;
root = 0 ;
get_root ( u ) ;//寻找u所在的子树的重心
vis[root] = Time ;
deal ( root , 0 , 1 ) ;//得到该子树内所有的贡献,不管是不是同一棵子树的
travel ( e , H , root ) if ( vis[e->v] != Time ) deal ( e->v , e->c , -1 ) ;//去重,排除同一棵子树的贡献
travel ( e , H , root ) if ( vis[e->v] != Time ) divide ( e->v ) ;//递归处理
}
void solve () {
int x , y , c ;
clear () ;
rep ( i , 1 , n ) {
scanf ( "%d%d%d" , &x , &y , &c ) ;
addedge ( x , y , c ) ;
addedge ( y , x , c ) ;
}
For ( i , 1 , n ) {
scanf ( "%d" , &mart[i] ) ;
if ( mart[i] ) {
node[i] = Node ( 0 , i ) ;
Q[tail ++] = i ;
} else node[i] = Node ( INF , 0 ) ;
}
++ Time ;
spfa () ;
++ Time ;
divide ( 1 ) ;
int res = 0 ;
For ( i , 1 , n ) if ( ans[i] > res ) res = ans[i] ;
printf ( "%d\n" , res ) ;
}
int main () {
Time = 0 ;
clr ( vis , 0 ) ;//优化
while ( ~scanf ( "%d" , &n ) ) solve () ;
return 0 ;
}