代码如下
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Map;
import java.util.Scanner;

/** A binary-tree node identified by its integer value. */
class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;

    public TreeNode(int val) {
        this.val = val;
    }
}

/**
 * Reads binary trees from standard input and prints, for each one, whether it is
 * height-balanced: at every node the heights of the left and right subtrees differ by at most 1.
 *
 * <p>Input per case: {@code num rootNum}, then {@code num} rows of
 * {@code value leftValue rightValue}, where a child value of {@code 0} means "no child".
 * Cases are read until the input is exhausted. Output per case: {@code true} or {@code false}
 * on its own line.
 */
public class Main {
    /** Value-to-node lookup for the tree currently being built; reset by {@link #constructTree}. */
    static Map<Integer, TreeNode> cache = new HashMap<>();

    public static void main(String[] args) {
        try (Scanner scanner = new Scanner(System.in)) {
            while (scanner.hasNext()) {
                int num = scanner.nextInt();
                int rootNum = scanner.nextInt();
                int[][] info = new int[num][3];
                for (int[] row : info) {
                    row[0] = scanner.nextInt();
                    row[1] = scanner.nextInt();
                    row[2] = scanner.nextInt();
                }
                TreeNode root = constructTree(rootNum, info);
                // One answer per line so that consecutive cases do not run together.
                System.out.println(isBalance(root));
            }
        }
    }

    /**
     * Returns whether the tree rooted at {@code root} is height-balanced; an empty tree is.
     *
     * <p>Heights are computed bottom-up with explicit stacks rather than recursion, so a
     * chain-shaped tree of any depth cannot overflow the call stack. Assumes the nodes form a
     * tree (no node reachable twice, no cycles).
     *
     * @param root root of the tree, may be {@code null}
     * @return {@code true} if every node's subtree heights differ by at most 1
     */
    private static boolean isBalance(TreeNode root) {
        if (root == null) {
            return true;
        }
        // Each node is pushed onto `childrenFirst` before its children are, so popping
        // `childrenFirst` afterwards yields every child before its parent.
        Deque<TreeNode> toVisit = new ArrayDeque<>();
        Deque<TreeNode> childrenFirst = new ArrayDeque<>();
        toVisit.push(root);
        while (!toVisit.isEmpty()) {
            TreeNode node = toVisit.pop();
            childrenFirst.push(node);
            if (node.left != null) {
                toVisit.push(node.left);
            }
            if (node.right != null) {
                toVisit.push(node.right);
            }
        }
        // Height of every processed node; a null child is never stored and counts as height 0.
        Map<TreeNode, Integer> heights = new IdentityHashMap<>();
        while (!childrenFirst.isEmpty()) {
            TreeNode node = childrenFirst.pop();
            int leftH = heights.getOrDefault(node.left, 0);
            int rightH = heights.getOrDefault(node.right, 0);
            if (Math.abs(leftH - rightH) > 1) {
                return false;
            }
            heights.put(node, Math.max(leftH, rightH) + 1);
        }
        return true;
    }

    /**
     * Builds the tree described by {@code info} into {@link #cache}.
     *
     * @param rootNum value of the root node
     * @param info rows of {@code {value, leftValue, rightValue}}; a child value of 0 means no child
     * @return the node whose value is {@code rootNum}, or {@code null} when {@code info} is
     *     null/empty or never mentions {@code rootNum}
     */
    private static TreeNode constructTree(int rootNum, int[][] info) {
        cache = new HashMap<>();
        if (info == null || info.length == 0) {
            return null;
        }
        for (int[] row : info) {
            TreeNode cur = getOrCreateNode(row[0]);
            cur.left = getOrCreateNode(row[1]);
            cur.right = getOrCreateNode(row[2]);
        }
        return cache.get(rootNum);
    }

    /**
     * Returns the node for value {@code num}, creating and caching it on first use.
     *
     * @param num node value; {@code 0} is the "no node" marker
     * @return the shared node for {@code num}, or {@code null} when {@code num == 0}
     */
    private static TreeNode getOrCreateNode(int num) {
        if (num == 0) {
            return null;
        }
        return cache.computeIfAbsent(num, TreeNode::new);
    }
}
后序遍历：如果根为空，则是平衡二叉树。先检查左子树：如果左子树不是平衡二叉树，则终止后续判断；如果是，则求出左子树高度。同理求出右子树高度。如果左右子树都是平衡二叉树，再判断左右子树高度差是否大于 1（大于 1 则不平衡），并得出当前树高（左右子树高度的较大值加 1）。