题目定义B(u)为关于点u的兴趣值 给出每个点一个rank值 如果 点 u 对 v感兴趣 则不存在一个点 k使得
rank(k) > rank(v) 且 dist(u,k) < dist(u,v);
首先最简单的思想就是 n遍最短路找每个点 rank为1…10 ~ 10的最短路dist[n][11] 然后再走每个点的最短路 dis[j] 倘若
点 i 对点 j感兴趣 那么 dis[j] < dist[j][a[i] + 1] 则ans++
首先 第一步那么我们就以rank为起点 跑10遍最短路即可
第二个我们只需要剪枝
假设存在一条边 x y z
dis[y] < dis[x] + z;
dis[y] = dis[x] + z;
if(dis[y] < dist[y][a[s] + 1]) 那么是符合条件的一定更新的
那么我们继续讨论如果 dis[y] >= dist[y][a[s] + 1]
是否存在一个点 k 需要 点 y去更新 使得 dis[k]符合条件
dis[k] < dis[y] + w[i] >= dist[y][a[s] + 1] + w[i];
dist[k][a[s] + 1] >= dist[y][a[s] + 1] + dist[k][a[y] + 1]
又因为 w[i] <= dist[k][a[y] + 1]
所以不符合更新条件
#include<iostream>
#include<cstring>
#include<cstdio>
#include<utility>
#include<queue>
#include<vector>
#define x first
#define y second
using namespace std;
const int N = 3e4 + 10,M = N * 10;
typedef pair<int,int> PII;
vector<int>v[12];
int head[N],to[M],last[M],w[M],cnt;
void add(int a,int b,int c){
to[++cnt] = b;
w[cnt] = c;
last[cnt] = head[a];
head[a] = cnt;
}
int n,m,a[N];
int dist[11][N],flag[N];
void dij(int x){
memset(flag,0,sizeof flag);
priority_queue<PII,vector<PII>,greater<PII > > q;
for(int i = 0; i < v[x].size(); i++){
int j = v[x][i];
q.push({0,j});
dist[x][j] = 0;
}
while(q.size()){
PII p = q.top();
q.pop();
if(flag[p.y]) continue;
flag[p.y] = 1;
for(int i = head[p.y]; i != -1; i = last[i]){
int j = to[i];
if(dist[x][j] > dist[x][p.y] + w[i]){
dist[x][j] = dist[x][p.y] + w[i];
q.push({dist[x][j],j});
}
}
}
}
int dis[N],vis[N],ok[N],ans;
void dij_2(int x){ //此点到别的点的最短路 若dis < dist[a[x] + 1][j] 则++
memset(ok,0,sizeof ok);
memset(vis,0,sizeof vis);
memset(dis,0x3f,sizeof dis);
dis[x] = 0;
priority_queue<PII,vector<PII>,greater<PII > > q;
q.push({0,x});
while(q.size()){
PII p = q.top();
q.pop();
if(vis[p.y]) continue;
vis[p.y] = 1;
if(!ok[p.y]){
++ans;
ok[p.y] = 1;
}
for(int i = head[p.y]; i != -1; i = last[i]){
int j = to[i];
if(dis[j] > dis[p.y] + w[i]){
dis[j] = dis[p.y] + w[i];
if(dis[j] < dist[a[x] + 1][j]){
q.push({dis[j],j});
}
}
}
}
}
int main(){
cin >> n >> m;
memset(head,-1,sizeof head);
memset(dist,0x3f,sizeof dist);
for(int i = 1; i <= n; i++){
scanf("%d",&a[i]);
v[a[i]].push_back(i);
}
for(int i = 1; i <= m; i++){
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
add(x,y,z);
add(y,x,z);
}
for(int i = 10; i >= 1; i--){
dij(i);
memcpy(dist[i - 1],dist[i],sizeof dist[i]);
}
for(int i = 1; i <= n; i++){
dij_2(i);
}
cout << ans << endl;
return 0;
}