import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Scanner;

// Note: the class must be named Main and must not declare any package (judge requirement).
public class Main {

    /**
     * Reads a rooted tree from standard input and prints its maximum path sum.
     *
     * <p>Input: {@code n}, then {@code n} node values, then {@code n} parent indices. Parent
     * indices are 1-based and node 0 is taken as the root, so {@code parent[0]} is read but
     * ignored.
     */
    public static void main(String[] args) {
        try (Scanner in = new Scanner(System.in)) {
            int n = in.nextInt();
            int[] value = new int[n];
            int[] parent = new int[n];
            for (int i = 0; i < n; i++) {
                value[i] = in.nextInt();
            }
            for (int i = 0; i < n; i++) {
                parent[i] = in.nextInt();
            }
            System.out.println(maxPathSum(value, parent));
        }
    }

    /**
     * Returns the largest sum of node values along any non-empty path in the tree. A path is a
     * chain of distinct nodes joined by parent/child edges; it may be a single node and need
     * not pass through the root. Nodes of any arity are supported. Nodes not reachable from
     * node 0 are ignored.
     *
     * @param value  value of each node
     * @param parent 1-based parent index of each node; entry 0 (the root) is ignored
     * @return the maximum path sum, as a {@code long} so long paths of ints cannot overflow
     * @throws IllegalArgumentException if the tree is empty, the arrays differ in length, or a
     *     parent index is out of range
     */
    static long maxPathSum(int[] value, int[] parent) {
        int n = value.length;
        if (n == 0) {
            throw new IllegalArgumentException("tree must contain at least one node");
        }
        if (parent.length != n) {
            throw new IllegalArgumentException(
                    "value and parent lengths differ: " + n + " vs " + parent.length);
        }

        List<List<Integer>> children = buildChildren(parent);
        int[] order = topDownOrder(children);

        // gain[v] = best sum of a downward path starting at v (v plus at most one child branch).
        long[] gain = new long[n];
        long best = Long.MIN_VALUE;
        // Walk bottom-up so every child's gain is known before its parent is processed.
        for (int k = order.length - 1; k >= 0; k--) {
            int node = order[k];
            // Two largest child gains, floored at 0: a negative branch is better left out.
            long top1 = 0;
            long top2 = 0;
            for (int child : children.get(node)) {
                long g = gain[child];
                if (g > top1) {
                    top2 = top1;
                    top1 = g;
                } else if (g > top2) {
                    top2 = g;
                }
            }
            // Best path that "bends" at this node (covers the lone-node case when both are 0).
            best = Math.max(best, value[node] + top1 + top2);
            gain[node] = value[node] + top1;
        }
        return best;
    }

    /** Builds child lists from 1-based parent indices, skipping the root (node 0). */
    private static List<List<Integer>> buildChildren(int[] parent) {
        int n = parent.length;
        List<List<Integer>> children = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            children.add(new ArrayList<>());
        }
        for (int i = 1; i < n; i++) {
            int p = parent[i] - 1;
            if (p < 0 || p >= n) {
                throw new IllegalArgumentException(
                        "parent[" + i + "] out of range: " + parent[i]);
            }
            children.get(p).add(i);
        }
        return children;
    }

    /**
     * Returns the nodes reachable from node 0 in breadth-first order, so every parent appears
     * before its children. Iterative to avoid stack overflow on degenerate (chain) trees.
     */
    private static int[] topDownOrder(List<List<Integer>> children) {
        int[] order = new int[children.size()];
        int head = 0;
        int tail = 0;
        order[tail++] = 0;
        while (head < tail) {
            int node = order[head++];
            for (int child : children.get(node)) {
                // Each non-root node sits in exactly one child list, so it is enqueued once.
                order[tail++] = child;
            }
        }
        return Arrays.copyOf(order, tail);
    }
}