// Tencent written test (Sept 26), problem 5
import java.util.*;
public class Main {
    /** Node values read from the input; grown in {@link #main} when n exceeds its length. */
    static int[] arr = new int[100009];

    /**
     * Number of (ancestor, descendant) pairs whose value product is a perfect square. A
     * {@code long} because the count can reach n * (n - 1) / 2, which overflows an {@code int}
     * once n exceeds roughly 65,536.
     */
    static long res = 0;

    /** A tree node: its value and its direct children. */
    static class Node {
        int val;
        List<Node> childs = new ArrayList<>();

        public Node(int val) {
            this.val = val;
        }
    }

    /**
     * Adds to {@link #res} the number of pairs (u, w) inside the subtree of {@code rt} such that
     * u is a proper ancestor of w and {@code u.val * w.val} is a perfect square.
     *
     * <p>The subtree is walked with explicit stacks instead of recursion, so a deep chain cannot
     * overflow the call stack. A histogram of the values on the current root-to-node path is
     * maintained; each node, when entered, is paired with every compatible value already on the
     * path. Time and memory grow with the largest value in the subtree.
     *
     * @param rt root of the subtree; every value in it must be non-negative
     */
    static void dfs(Node rt) {
        int[] onPath = new int[maxValue(rt) + 1];
        Deque<Node> path = new ArrayDeque<>();
        Deque<Iterator<Node>> pending = new ArrayDeque<>();

        enter(rt, onPath, path, pending);
        while (!pending.isEmpty()) {
            Iterator<Node> next = pending.peek();
            if (next.hasNext()) {
                enter(next.next(), onPath, path, pending);
            } else {
                // Every child has been visited: leave this node and drop it from the histogram.
                pending.pop();
                onPath[path.pop().val]--;
            }
        }
    }

    /** Counts {@code node} against its ancestors, then pushes it onto the current path. */
    private static void enter(
            Node node, int[] onPath, Deque<Node> path, Deque<Iterator<Node>> pending) {
        res += countCompatible(onPath, node.val);
        onPath[node.val]++;
        path.push(node);
        pending.push(node.childs.iterator());
    }

    /** Returns how many path values {@code i} make {@code val * i} a perfect square. */
    private static long countCompatible(int[] onPath, int val) {
        long total = 0;
        for (int i = 0; i < onPath.length; ++i) {
            if (onPath[i] > 0 && isPerfectSquare((long) val * i)) {
                total += onPath[i];
            }
        }
        return total;
    }

    /** Returns the largest value in the subtree of {@code rt} (iterative traversal). */
    private static int maxValue(Node rt) {
        int max = rt.val;
        Deque<Node> stack = new ArrayDeque<>();
        stack.push(rt);
        while (!stack.isEmpty()) {
            Node cur = stack.pop();
            max = Math.max(max, cur.val);
            for (Node c : cur.childs) {
                stack.push(c);
            }
        }
        return max;
    }

    /** Exact perfect-square test; corrects any rounding in the floating-point estimate. */
    private static boolean isPerfectSquare(long x) {
        if (x < 0) {
            return false;
        }
        long r = (long) Math.sqrt((double) x);
        while (r * r > x) {
            r--;
        }
        while ((r + 1) * (r + 1) <= x) {
            r++;
        }
        return r * r == x;
    }

    /**
     * Reads the tree from standard input and prints the number of (ancestor, descendant) pairs
     * whose value product is a perfect square.
     *
     * <p>Input: n, then n node values, then for nodes 2..n their 1-based parent index. Node 1 is
     * the root. Parents may appear in any order relative to their children.
     *
     * @throws IllegalArgumentException if a node value is negative
     */
    public static void main(String[] args) {
        Scanner scan = new Scanner(System.in);
        res = 0; // main may run more than once in one JVM (e.g. from tests)
        int n = scan.nextInt();
        if (n <= 1) {
            System.out.println(0);
            return;
        }
        if (arr.length < n) {
            arr = new int[n];
        }
        for (int i = 0; i < n; ++i) {
            arr[i] = scan.nextInt();
            if (arr[i] < 0) {
                throw new IllegalArgumentException(
                        "node values must be non-negative, got " + arr[i]);
            }
        }
        // Create every node before linking, so a parent may be listed after its child.
        List<Node> tree = new ArrayList<>(n);
        for (int i = 0; i < n; ++i) {
            tree.add(new Node(arr[i]));
        }
        for (int i = 1; i < n; ++i) {
            int p = scan.nextInt(); // 1-based parent of node i + 1
            tree.get(p - 1).childs.add(tree.get(i));
        }
        dfs(tree.get(0));
        System.out.println(res);
    }
}