// Tencent written exam (Sept 26), problem 5.
//
// Input: n, then the n node values, then for nodes 2..n the 1-based index of
// each node's parent (node 1 is the root).
// Output: the number of (ancestor, descendant) pairs whose value product is a
// perfect square.
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Scanner;

public class Main {

    /**
     * Reads the tree from standard input and prints the number of
     * (ancestor, descendant) pairs whose values multiply to a perfect square.
     */
    public static void main(String[] args) {
        try (Scanner scan = new Scanner(System.in)) {
            int n = scan.nextInt();
            if (n <= 1) {
                // A single node (or an empty tree) has no ancestor/descendant pairs.
                System.out.println(0);
                return;
            }
            int[] values = new int[n];
            for (int i = 0; i < n; ++i) {
                values[i] = scan.nextInt();
            }
            int[] parent = new int[n];
            parent[0] = -1;
            for (int i = 1; i < n; ++i) {
                parent[i] = scan.nextInt() - 1; // input parents are 1-based
            }
            System.out.println(countSquarePairs(values, parent));
        }
    }

    /**
     * Counts pairs (u, v) where u is a proper ancestor of v and
     * {@code values[u] * values[v]} is a perfect square.
     *
     * <p>Two positive integers multiply to a perfect square exactly when their
     * square-free kernels are equal, so one iterative DFS that keeps a count of
     * the kernels on the current root-to-node path answers each node with a
     * single lookup. A zero value pairs with every other value (0 = 0^2).
     * The traversal is iterative so deep (chain-shaped) trees cannot overflow
     * the call stack.
     *
     * @param values non-negative node values; node 0 is the root
     * @param parent {@code parent[i]} is the 0-based parent of node i for
     *     {@code i >= 1}; {@code parent[0]} is ignored. Parents may be listed
     *     in any order relative to their children.
     * @return the number of qualifying pairs (may exceed {@code Integer.MAX_VALUE})
     * @throws IllegalArgumentException if the arrays differ in length or a value is negative
     */
    static long countSquarePairs(int[] values, int[] parent) {
        int n = values.length;
        if (parent.length != n) {
            throw new IllegalArgumentException(
                    "values.length=" + n + " but parent.length=" + parent.length);
        }
        if (n == 0) {
            return 0;
        }

        // Square-free kernel of every node value; repeated values share one computation.
        int[] kernel = new int[n];
        Map<Integer, Integer> kernelCache = new HashMap<>();
        for (int i = 0; i < n; ++i) {
            if (values[i] < 0) {
                throw new IllegalArgumentException(
                        "negative value at node " + (i + 1) + ": " + values[i]);
            }
            kernel[i] = kernelCache.computeIfAbsent(values[i], Main::squareFreeKernel);
        }

        // Child lists as singly linked lists: firstChild[v] -> nextSibling[..] -> ... -> -1.
        int[] firstChild = new int[n];
        int[] nextSibling = new int[n];
        Arrays.fill(firstChild, -1);
        for (int i = n - 1; i >= 1; --i) {
            nextSibling[i] = firstChild[parent[i]];
            firstChild[parent[i]] = i;
        }

        // Iterative DFS. cursor[v] is the next child of v still to be visited.
        // Every node has exactly one parent, so each reachable node is pushed
        // at most once and a stack of size n suffices.
        PathCounter path = new PathCounter();
        int[] cursor = firstChild.clone();
        int[] stack = new int[n];
        int top = 0;
        stack[top++] = 0;
        long pairs = path.enter(kernel[0]);
        while (top > 0) {
            int v = stack[top - 1];
            int child = cursor[v];
            if (child >= 0) {
                cursor[v] = nextSibling[child];
                pairs += path.enter(kernel[child]);
                stack[top++] = child;
            } else {
                path.leave(kernel[v]);
                top--;
            }
        }
        return pairs;
    }

    /**
     * Returns the square-free kernel of {@code x}: the product of the primes
     * that divide {@code x} an odd number of times. Returns 0 for 0.
     */
    private static int squareFreeKernel(int x) {
        if (x == 0) {
            return 0;
        }
        int result = 1;
        for (int d = 2; (long) d * d <= x; ++d) {
            int parity = 0;
            while (x % d == 0) {
                x /= d;
                parity ^= 1;
            }
            if (parity == 1) {
                result *= d;
            }
        }
        // Whatever remains is 1 or a single prime factor larger than sqrt(original x).
        return result * x;
    }

    /** Multiset of square-free kernels on the current root-to-node path. */
    private static final class PathCounter {
        private final Map<Integer, Integer> countByKernel = new HashMap<>();
        private int zeros; // nodes on the path whose value is 0 (kernel 0)
        private int size;  // total nodes on the path

        /**
         * Returns how many nodes already on the path form a perfect-square
         * product with a node of kernel {@code k}, then adds that node.
         */
        long enter(int k) {
            long matches;
            if (k == 0) {
                // 0 * anything is 0, a perfect square.
                matches = size;
                zeros++;
            } else {
                matches = zeros + countByKernel.getOrDefault(k, 0);
                countByKernel.merge(k, 1, Integer::sum);
            }
            size++;
            return matches;
        }

        /** Removes one node of kernel {@code k} from the path. */
        void leave(int k) {
            if (k == 0) {
                zeros--;
            } else {
                countByKernel.merge(k, -1, Integer::sum);
            }
            size--;
        }
    }
}