import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.NoSuchElementException;
import java.util.Scanner;
import java.util.StringTokenizer;

/**
 * Groups nodes into components joined (transitively) by edges whose two endpoints carry
 * equal values, then reports the nodes that belong to a largest such component.
 *
 * <p>Input: {@code n}, then {@code n} values {@code a[1..n]}, then {@code n - 1} edges
 * {@code u v} (presumably a tree; the code does not rely on it). Output: the number of nodes
 * lying in a component of maximum size, a newline, then those node indices in ascending
 * order separated by single spaces (no trailing newline).
 */
public class Main {
    /** Default array capacity; arrays are reallocated when an input needs more room. */
    static final int N = 200005;
    /** Node values, 1-indexed: {@code a[1..n]}. */
    static int[] a = new int[N];
    /** Component sizes; only the entries of current union-find roots are meaningful. */
    static int[] siz = new int[N];
    /** Union-find parent links, 1-indexed; {@code f[x] == x} marks a root. */
    static int[] f = new int[N];
    /** Number of nodes in the input being processed. */
    static int n;
    /** Size of the largest equal-value component found by the most recent run. */
    static int maxn = 0;

    /**
     * Reads the graph from standard input, merges equal-value neighbours and prints the answer.
     *
     * @param args ignored
     * @throws NoSuchElementException if the input ends before all numbers are read
     * @throws NumberFormatException if a token is not a valid {@code int}
     * @throws UncheckedIOException if reading standard input fails
     */
    public static void main(String[] args) {
        // Scanner is far too slow for ~3n integers at n ~ 2e5; a buffered tokenizer is not.
        // System.in is deliberately left open: it is owned by the JVM, not by this method.
        FastReader in = new FastReader(System.in);
        n = in.nextInt();
        ensureCapacity(n + 1);
        for (int i = 1; i <= n; i++) {
            a[i] = in.nextInt();
        }
        for (int i = 1; i <= n; i++) {
            f[i] = i;
            siz[i] = 1;
        }
        for (int i = 1; i < n; i++) {
            int u = in.nextInt();
            int v = in.nextInt();
            // Only edges whose endpoints share a value join components.
            if (a[u] == a[v]) {
                union(u, v);
            }
        }
        System.out.print(formatAnswer());
        System.out.flush();
    }

    /**
     * Returns the root of {@code x}'s set, compressing the path so every visited node then
     * points straight at the root. Iterative, so deep parent chains cannot overflow the stack.
     *
     * @param x a node index in {@code 1..n}
     * @return the representative (root) of the set containing {@code x}
     */
    public static int getfa(int x) {
        int root = x;
        while (f[root] != root) {
            root = f[root];
        }
        while (x != root) {
            int parent = f[x];
            f[x] = root;
            x = parent;
        }
        return root;
    }

    /** Merges the sets of {@code u} and {@code v}, attaching the smaller tree under the larger. */
    private static void union(int u, int v) {
        int ru = getfa(u);
        int rv = getfa(v);
        if (ru == rv) {
            return;
        }
        // Union by size keeps tree height logarithmic even before path compression kicks in.
        if (siz[ru] < siz[rv]) {
            int tmp = ru;
            ru = rv;
            rv = tmp;
        }
        f[rv] = ru;
        siz[ru] += siz[rv];
    }

    /** Grows the per-node arrays so that indices {@code 0..size-1} are valid. */
    private static void ensureCapacity(int size) {
        if (a.length < size) {
            // Contents need not be preserved: every used slot is rewritten before it is read.
            a = new int[size];
            f = new int[size];
            siz = new int[size];
        }
    }

    /**
     * Computes {@link #maxn} over all roots and formats the output.
     *
     * @return the count of nodes in a largest component, a newline, then those node indices in
     *     ascending order separated by spaces; just {@code "0\n"} when {@code n <= 0}
     */
    private static String formatAnswer() {
        maxn = 0;
        for (int i = 1; i <= n; i++) {
            if (f[i] == i) {
                maxn = Math.max(maxn, siz[i]);
            }
        }
        StringBuilder nodes = new StringBuilder();
        int res = 0;
        for (int i = 1; i <= n; i++) {
            if (siz[getfa(i)] == maxn) {
                if (res > 0) {
                    nodes.append(' ');
                }
                nodes.append(i);
                res++;
            }
        }
        return res + "\n" + nodes;
    }

    /** Minimal whitespace-separated int reader over a buffered stream. */
    private static final class FastReader {
        private final BufferedReader reader;
        private StringTokenizer tokens = new StringTokenizer("");

        FastReader(InputStream stream) {
            reader = new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8));
        }

        /**
         * Returns the next integer token, reading further lines as needed.
         *
         * @throws NoSuchElementException if the input is exhausted
         * @throws UncheckedIOException if the underlying read fails
         */
        int nextInt() {
            while (!tokens.hasMoreTokens()) {
                String line;
                try {
                    line = reader.readLine();
                } catch (IOException e) {
                    throw new UncheckedIOException(e);
                }
                if (line == null) {
                    throw new NoSuchElementException("unexpected end of input");
                }
                tokens = new StringTokenizer(line);
            }
            return Integer.parseInt(tokens.nextToken());
        }
    }
}