import java.util.*;
/**
 * Reads a vertex-weighted graph with n vertices and n-1 edges and prints how many unordered
 * vertex pairs {x, y} have a connecting path that contains at least one vertex with weight
 * {@code <= a} AND at least one vertex with weight {@code >= b}.
 *
 * <p>Computed by inclusion-exclusion over three "forbidden" vertex sets, each counted with a
 * disjoint-set union (DSU): total - |all weights > a| - |all weights < b| + |all weights in (a, b)|.
 * The component-based counting assumes the n-1 edges form a tree (unique paths).
 */
public class Main{
    static Scanner in = new Scanner(System.in);

    /** DSU parent links (1-based); grown on demand when n exceeds the current capacity. */
    static int[] fa = new int[500005];
    /** DSU component sizes, meaningful only at roots (where {@code fa[i] == i}). */
    static int[] siz = new int[500005];

    /**
     * Returns the DSU root of {@code x}, compressing the walked path so every node on it points
     * directly at the root.
     *
     * <p>Iterative on purpose: a recursive find can recurse Θ(n) deep on a long parent chain and
     * overflow the thread stack for large inputs.
     *
     * @param x a vertex index in {@code [1, n]} that has been initialised by the current pass
     * @return the representative of x's component
     */
    public static int getfa(int x){
        int root = x;
        while (fa[root] != root) {
            root = fa[root];
        }
        // Second pass: re-point every node on the path straight at the root.
        while (fa[x] != root) {
            int next = fa[x];
            fa[x] = root;
            x = next;
        }
        return root;
    }

    /**
     * Entry point. Reads {@code n a b}, then n weights (vertices 1..n), then n-1 edges as
     * {@code u v} pairs, and prints the pair count computed by {@link #solve}.
     */
    public static void main(String[] args){
        int n = in.nextInt();
        int a = in.nextInt();
        int b = in.nextInt();
        int[] w = new int[n + 5];
        int[] u = new int[n + 5];
        int[] v = new int[n + 5];
        for (int i = 1; i <= n; i++) {
            w[i] = in.nextInt();
        }
        for (int i = 1; i < n; i++) {
            u[i] = in.nextInt();
            v[i] = in.nextInt();
        }
        System.out.println(solve(n, a, b, w, u, v));
    }

    /**
     * Counts unordered pairs whose path contains a vertex with weight {@code <= a} and a vertex
     * with weight {@code >= b}.
     *
     * @param n number of vertices, indexed 1..n
     * @param a lower threshold; a vertex is "low" when {@code w[i] <= a}
     * @param b upper threshold; a vertex is "high" when {@code w[i] >= b}
     * @param w vertex weights, read at indices 1..n
     * @param u edge endpoints, read at indices 1..n-1
     * @param v edge endpoints, read at indices 1..n-1
     * @return the number of qualifying pairs
     */
    static long solve(int n, int a, int b, int[] w, int[] u, int[] v) {
        boolean[] noLow = new boolean[n + 1];   // w > a: path through only these has no low vertex
        boolean[] noHigh = new boolean[n + 1];  // w < b: path through only these has no high vertex
        boolean[] neither = new boolean[n + 1]; // a < w < b: both conditions fail
        for (int i = 1; i <= n; i++) {
            noLow[i] = w[i] > a;
            noHigh[i] = w[i] < b;
            neither[i] = noLow[i] && noHigh[i];
        }
        long total = (long) n * (n - 1) / 2;
        long withoutLow = countInternalPairs(n, u, v, noLow);
        long withoutHigh = countInternalPairs(n, u, v, noHigh);
        long withoutEither = countInternalPairs(n, u, v, neither);
        // Inclusion-exclusion: pairs lacking a low vertex or lacking a high vertex are excluded.
        return total - withoutLow - withoutHigh + withoutEither;
    }

    /**
     * Returns the number of unordered vertex pairs lying in the same component of the subgraph
     * induced by the vertices marked {@code allowed}. Resets and reuses {@link #fa}/{@link #siz}.
     */
    private static long countInternalPairs(int n, int[] u, int[] v, boolean[] allowed) {
        ensureCapacity(n);
        for (int i = 1; i <= n; i++) {
            fa[i] = i;
            siz[i] = 1;
        }
        for (int i = 1; i < n; i++) {
            if (!allowed[u[i]] || !allowed[v[i]]) {
                continue;
            }
            int ru = getfa(u[i]);
            int rv = getfa(v[i]);
            if (ru == rv) {
                continue; // already connected; never happens on a tree, but keeps sizes correct
            }
            // Union by size keeps trees shallow regardless of edge order.
            if (siz[ru] > siz[rv]) {
                int tmp = ru;
                ru = rv;
                rv = tmp;
            }
            fa[ru] = rv;
            siz[rv] += siz[ru];
        }
        long pairs = 0;
        for (int i = 1; i <= n; i++) {
            if (fa[i] == i) {
                pairs += (long) siz[i] * (siz[i] - 1) / 2;
            }
        }
        return pairs;
    }

    /** Grows the DSU arrays so indices 1..n are valid (the original fixed size capped n). */
    private static void ensureCapacity(int n) {
        if (fa.length <= n) {
            fa = new int[n + 5];
            siz = new int[n + 5];
        }
    }
}