// Tencent written exam (Sept. 26), problem 5

import java.util.*;

/**
 * Counts (ancestor, descendant) node pairs of a rooted tree whose values multiply to a
 * perfect square.
 *
 * <p>Input read from standard input by {@link #main(String[])}:
 * <ol>
 *   <li>{@code n} - the number of nodes;</li>
 *   <li>{@code n} node values (node 1 is the root);</li>
 *   <li>{@code n - 1} one-based parent indices, for nodes {@code 2..n} in that order.</li>
 * </ol>
 *
 * <p>Only values in {@code 0..100} are scanned when counting pairs; the per-node histogram
 * has a few spare slots beyond that, and any value outside the histogram fails with an
 * {@link ArrayIndexOutOfBoundsException}.
 */
public class Main {
    /** Largest node value considered when counting pairs (values {@code 0..MAX_VAL}). */
    private static final int MAX_VAL = 100;

    /** Node values in input order; replaced by a larger array when the input needs more room. */
    static int[] arr = new int[100009];

    /**
     * Number of matching pairs found by {@link #dfs(Node)}. Held as a {@code long} because a
     * tree with n nodes can contain up to n(n-1)/2 pairs, which exceeds the {@code int} range
     * for n above 65536.
     */
    static long res = 0;

    /** A tree node together with a histogram of the values found in its subtree. */
    static class Node {
        /** Value stored in this node. */
        int val;

        /** Direct children, in the order they were attached. */
        List<Node> childs = new ArrayList<>();

        /**
         * Once filled by {@code main}: {@code table[v]} is the number of nodes in this node's
         * subtree (the node itself included) whose value is {@code v}, for {@code v} in 0..100.
         */
        int[] table = new int[109];

        public Node(int val) {
            this.val = val;
        }
    }

    /**
     * Adds to {@link #res}, for every node in the subtree of {@code rt}, the number of its
     * proper descendants whose value times the node's value is a perfect square.
     *
     * <p>The subtree histograms ({@code Node.table}) must already be populated. The traversal
     * uses an explicit stack rather than recursion so that very deep (chain-shaped) trees do
     * not exhaust the call stack.
     *
     * @param rt root of the subtree to process
     */
    static void dfs(Node rt) {
        Deque<Node> pending = new ArrayDeque<>();
        pending.push(rt);
        boolean[] squareWith = new boolean[MAX_VAL + 1];

        while (!pending.isEmpty()) {
            Node cur = pending.pop();
            if (cur.childs.isEmpty()) {
                continue;
            }
            // Which partner values v make cur.val * v a perfect square depends only on cur,
            // so work it out once per node instead of once per child.
            for (int v = 0; v <= MAX_VAL; ++v) {
                squareWith[v] = isPerfectSquare((long) cur.val * v);
            }
            for (Node child : cur.childs) {
                for (int v = 0; v <= MAX_VAL; ++v) {
                    if (squareWith[v]) {
                        res += child.table[v];
                    }
                }
                pending.push(child);
            }
        }
    }

    /** Returns whether {@code x} is the square of a non-negative integer. */
    private static boolean isPerfectSquare(long x) {
        if (x < 0) {
            return false;
        }
        long r = (long) Math.sqrt((double) x);
        return r * r == x || (r + 1) * (r + 1) == x;
    }

    /**
     * Reads the tree from standard input and prints the number of matching pairs.
     *
     * @param args unused
     * @throws IllegalArgumentException if a parent index is outside {@code 1..n}
     */
    public static void main(String[] args) {
        Scanner scan = new Scanner(System.in);
        int n = scan.nextInt();
        res = 0; // static state: start every run from zero
        if (n <= 1) {
            // A tree with at most one node has no ancestor/descendant pair.
            System.out.println(0);
            return;
        }
        if (n > arr.length) {
            arr = new int[n];
        }
        for (int i = 0; i < n; ++i) {
            arr[i] = scan.nextInt();
        }

        // Create every node before linking so a parent may be listed after its child.
        List<Node> tree = new ArrayList<>(n);
        for (int i = 0; i < n; ++i) {
            tree.add(new Node(arr[i]));
        }
        for (int i = 1; i < n; ++i) {
            int p = scan.nextInt();
            if (p < 1 || p > n) {
                throw new IllegalArgumentException(
                        "parent index " + p + " of node " + (i + 1) + " is outside 1.." + n);
            }
            tree.get(p - 1).childs.add(tree.get(i));
        }
        Node root = tree.get(0);

        // Breadth-first order from the root: every parent precedes all of its descendants.
        List<Node> order = new ArrayList<>(n);
        order.add(root);
        for (int head = 0; head < order.size(); ++head) {
            order.addAll(order.get(head).childs);
        }

        // Walk that order backwards so each child's histogram is complete before it is
        // folded into its parent's.
        for (int i = order.size() - 1; i >= 0; --i) {
            Node cur = order.get(i);
            cur.table[cur.val]++;
            for (Node c : cur.childs) {
                for (int v = 0; v <= MAX_VAL; ++v) {
                    cur.table[v] += c.table[v];
                }
            }
        }

        dfs(root);
        System.out.println(res);
    }
}