import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.StreamTokenizer;
import java.io.UncheckedIOException;
import java.util.NoSuchElementException;
import java.util.Scanner;
/**
 * Groups the nodes of an undirected graph ({@code n} nodes, {@code n - 1} edges, presumably a
 * tree) into components joined by edges whose two endpoints carry equal labels, then reports every
 * node that belongs to a largest such component.
 *
 * <p>Input (whitespace separated): {@code n}, the labels {@code a[1..n]}, then {@code n - 1} edges
 * {@code u v}. Output: the number of reported nodes, a newline, then the reported node indices in
 * ascending order separated by single spaces (no trailing newline).
 */
public class Main {
    /** Node labels, 1-indexed; reallocated on every run of {@link #main}. */
    static int[] a = new int[0];
    /** Union-find parent links, 1-indexed; {@code f[x] == x} marks a root. */
    static int[] f = new int[0];
    /** Component sizes; only meaningful at union-find roots. */
    static int[] siz = new int[0];
    /** Number of nodes in the current input. */
    static int n;
    /** Size of the largest component found by the most recent run. */
    static int maxn = 0;

    /**
     * Reads the graph from standard input and prints the nodes of the largest equal-label
     * components to standard output.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        StreamTokenizer in =
                new StreamTokenizer(new BufferedReader(new InputStreamReader(System.in)));
        n = nextInt(in);

        // Size storage to the actual input and reset all state, so repeated runs are independent.
        a = new int[n + 1];
        f = new int[n + 1];
        siz = new int[n + 1];
        maxn = 0;
        for (int i = 1; i <= n; i++) {
            a[i] = nextInt(in);
            f[i] = i;
            siz[i] = 1;
        }

        for (int i = 1; i < n; i++) {
            int u = nextInt(in);
            int v = nextInt(in);
            if (a[u] == a[v]) {
                unite(u, v);
            }
        }

        // Only roots hold valid component sizes.
        for (int i = 1; i <= n; i++) {
            if (f[i] == i) {
                maxn = Math.max(maxn, siz[i]);
            }
        }

        int res = 0;
        StringBuilder nodes = new StringBuilder();
        for (int i = 1; i <= n; i++) {
            if (siz[getfa(i)] == maxn) {
                if (res > 0) {
                    nodes.append(' ');
                }
                nodes.append(i);
                res++;
            }
        }
        // Joining separators explicitly avoids the substring(1) crash when nothing is collected.
        System.out.print(res + "\n" + nodes);
        System.out.flush();
    }

    /**
     * Returns the union-find root of {@code x}, compressing the path it walked.
     *
     * <p>Iterative on purpose: a recursive find can overflow the thread stack when parent chains
     * are long (up to ~2e5 nodes here).
     *
     * @param x a node index in {@code [1, n]}
     * @return the root of the component containing {@code x}
     */
    public static int getfa(int x) {
        int root = x;
        while (f[root] != root) {
            root = f[root];
        }
        while (x != root) {
            int parent = f[x];
            f[x] = root;
            x = parent;
        }
        return root;
    }

    /** Merges the components of {@code u} and {@code v}, attaching the smaller under the larger. */
    private static void unite(int u, int v) {
        int ru = getfa(u);
        int rv = getfa(v);
        if (ru == rv) {
            return;
        }
        // Union by size keeps trees shallow (depth O(log n)).
        if (siz[ru] < siz[rv]) {
            int tmp = ru;
            ru = rv;
            rv = tmp;
        }
        f[rv] = ru;
        siz[ru] += siz[rv];
    }

    /**
     * Reads the next integer token.
     *
     * @throws NoSuchElementException if input ends or the next token is not a number
     * @throws UncheckedIOException if reading from the underlying stream fails
     */
    private static int nextInt(StreamTokenizer in) {
        try {
            if (in.nextToken() != StreamTokenizer.TT_NUMBER) {
                throw new NoSuchElementException("expected an integer token");
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return (int) in.nval;
    }
}