//http://www.cnblogs.com/IMGavin/
//http://hihocoder.com/problemset/problem/1455
//https://media.hihocoder.com/contests/challenge25/solution.pdf
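//Techniques: bitset + Mo's algorithm over the DFS order
/*
in1[] is the bitset of values present inside the subtree currently being processed; in2[] is in1 mirrored (bit x stored at index n+1-x).
out1[] is the bitset of values present outside that subtree; out2[] is out1 mirrored the same way.
*/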
#include <iostream>
#include <stdio.h>
#include <cstdlib>
#include <cstring>
#include <queue>
#include <vector>
#include <map>
#include <stack>
#include <set>
#include <bitset>
#include <cmath>
#include <algorithm>
using namespace std;
typedef long long LL;
const int INF = 0x3F3F3F3F;
const int N = 51200;    // capacity of every per-node / per-value array and bitset width
const int MOD = 1003;
int n;                  // number of tree nodes (read in main)
int val[N];             // val[u]: value written on node u
int ic[N];              // ic[x]: how many positions with value x are inside the Mo window
int oc[N];              // oc[x]: how many positions with value x are outside the Mo window
bitset<N> ans;          // ans[k] set once difference k has been found
bitset<N> in1;          // in1[x]  <=> ic[x] >= 1
bitset<N> in2;          // in2[n+1-x] <=> ic[x] >= 1 (mirror of in1)
bitset<N> out1;         // out1[x] <=> oc[x] >= 1
bitset<N> out2;         // out2[n+1-x] <=> oc[x] >= 1 (mirror of out1)
int head[N];            // head[u]: first edge index of u's adjacency list, -1 if none
int lf[N];              // lf[u]: DFS position of u
int ri[N];              // ri[u]: last DFS position inside u's subtree
int seq2node[N];        // seq2node[p]: node at DFS position p
int tot;                // number of edges stored in the edge pool
int lab;                // last DFS position handed out
int fa[N];              // fa[u]: parent of u in the DFS tree (-1 for the root)
// One directed half of an undirected tree edge, stored in a static pool.
// `to` is the target node; `next` is the index of the next edge leaving the
// same source node (-1 terminates the list, see init()).
struct Edge {
    int to;
    int next;
};
Edge edge[N * 2];   // two directed entries per undirected edge
void init(){
memset(head, -1, sizeof(head));
tot = 0;
}
void add(int st, int to){
edge[tot].to =to;
edge[tot].next = head[st];
head[st]= tot++;
}
// Depth-first traversal from u, whose parent is f, that assigns DFS positions.
// After it returns, for every node w it reached:
//   fa[w]       = parent of w (f for the start node u),
//   lf[w]       = position of w in visiting order (continues counting from lab),
//   ri[w]       = last position inside w's subtree, so the subtree is [lf[w], ri[w]],
//   seq2node[p] = node visited at position p.
// An explicit stack replaces recursion. Recursion depth equals the tree height,
// so a path-shaped tree with tens of thousands of nodes can overflow the call
// stack. Children are scanned in the same adjacency-list order as a recursive
// DFS, so the numbering is identical.
void dfs(int u, int f){
    fa[u] = f;
    lf[u] = ++lab;
    seq2node[lab] = u;
    // Parallel stacks: a node, and the next edge index of that node still to scan.
    vector<int> stkNode(1, u);
    vector<int> stkEdge(1, head[u]);
    while(!stkNode.empty()){
        int x = stkNode.back();
        int e = stkEdge.back();
        if(e == -1){
            // All children of x are done: its subtree ends at the current label.
            ri[x] = lab;
            stkNode.pop_back();
            stkEdge.pop_back();
            continue;
        }
        stkEdge.back() = edge[e].next;   // advance x's cursor before descending
        int v = edge[e].to;
        if(v == fa[x]){
            continue;                    // do not walk back to the parent
        }
        fa[v] = x;
        lf[v] = ++lab;
        seq2node[lab] = v;
        stkNode.push_back(v);
        stkEdge.push_back(head[v]);
    }
}
// One Mo query: the DFS-position interval [l, r] that main() fills with the
// subtree range of node u.
struct node{
    int l;
    int r;
    int u;
};
node q[N];   // queries, used 1-based
int blk;   // Mo block size, assigned in solve()

// Mo ordering: group queries by the block of their left end and, inside
// a block, order them by right end.
bool cmp(const node &a, const node &b){
    const int blockA = a.l / blk;
    const int blockB = b.l / blk;
    if(blockA == blockB){
        return a.r < b.r;
    }
    return blockA < blockB;
}
// Moves DFS position x from "outside" to "inside" the Mo window.
// When a value first appears inside, its in-bits are set; when its last
// outside occurrence leaves, its out-bits are cleared.
inline void add(int x){
    const int v = val[ seq2node[x] ];
    const int mirror = n + 1 - v;
    if(++ic[v] == 1){
        in1[v] = true;
        in2[mirror] = true;
    }
    if(--oc[v] == 0){
        out1[v] = false;
        out2[mirror] = false;
    }
}
// Moves DFS position x from "inside" back to "outside" the Mo window.
// When a value's last inside occurrence leaves, its in-bits are cleared;
// when it reappears outside, its out-bits are set.
inline void remove(int x){
    const int v = val[ seq2node[x] ];
    const int mirror = n + 1 - v;
    if(--ic[v] == 0){
        in1[v] = false;
        in2[mirror] = false;
    }
    if(++oc[v] == 1){
        out1[v] = true;
        out2[mirror] = true;
    }
}
void solve(){
blk = (int)sqrt(n + 0.5);
sort(q + 1, q + 1 + n, cmp);
for(int i = 1; i <= n; i++){
oc[val[i]]++;
if(oc[val[i]] == 1){
out1[val[i]] = 1;
out2[n + 1 - val[i]] = 1;
}
}
int l = 1, r = 0;
for(int i = 1; i <= n; i++){
while(l < q[i].l){
remove(l);
l++;
}
while(l > q[i].l){
l--;
add(l);
}
while(r < q[i].r){
r++;
add(r);
}
while(r > q[i].r){
remove(r);
r--;
}
if(fa[q[i].u] != -1){
int v = val[fa[q[i].u]];
ans |= ((in1 >> v) & (out2>>(n + 1 - v))) | ((out1>>v) & (in2>>(n + 1 - v)));
}
}
}
int main(){
cin >> n;
for(int i = 1; i <= n; i++){
scanf("%d", &val[i]);
}
init();
for(int i = 2; i <= n; i++){
int u, v;
scanf("%d %d", &u, &v);
add(u, v);
add(v, u);
}
lab = 0;
dfs(1, -1);
for(int i = 1; i <= n; i++){
q[i].u = i;
q[i].l = lf[i];
q[i].r = ri[i];
}
solve();
int cnt = 0;
for(int i = 1; i <= n; i++){
cnt += ans[i];
}
cout<<cnt<<endl;
return 0;
}