import java.util.*;

/*
 * public class TreeNode {
 *   int val = 0;
 *   TreeNode left = null;
 *   TreeNode right = null;
 *   public TreeNode(int val) {
 *     this.val = val;
 *   }
 * }
 */

public class Solution {
    /**
     * Computes the maximum width of a binary tree.
     *
     * <p>The width of a level is the number of positions from its leftmost to its rightmost
     * non-null node, inclusive. Missing (null) positions in between are counted as if the level
     * belonged to a complete binary tree. The result is the largest width over all levels.
     *
     * <p>The class name, method name and parameter name are fixed by the judge and must not be
     * changed; just return the value the method specifies.
     *
     * <p>Implementation: a single level-order traversal. Each node carries a heap-style index:
     * the left child of index {@code i} gets {@code 2i} and the right child gets {@code 2i + 1}.
     * Indices are renormalized so the first node of every level sits at index 0. Only real nodes
     * are enqueued, so each node is visited exactly once and no null placeholders are built.
     *
     * @param root root of the tree; may be {@code null}
     * @return the maximum level width, or 0 for an empty tree
     */
    public int widthOfBinaryTree(TreeNode root) {
        if (root == null) {
            return 0;
        }

        // Parallel FIFO queues: the BFS frontier and each node's level-relative index.
        Deque<TreeNode> nodes = new ArrayDeque<>();
        Deque<Long> indices = new ArrayDeque<>();
        nodes.add(root);
        indices.add(0L);

        long maxWidth = 0;
        while (!nodes.isEmpty()) {
            int levelSize = nodes.size();
            // Index of the leftmost node on this level. Subtracting it keeps indices small for
            // narrow levels. Java long arithmetic wraps on overflow, but the differences stay
            // exact modulo 2^64, so the width is still correct whenever it fits in a long.
            long levelStart = indices.peekFirst();
            long lastIndex = 0;
            for (int i = 0; i < levelSize; i++) {
                TreeNode node = nodes.poll();
                long index = indices.poll() - levelStart;
                lastIndex = index;
                if (node.left != null) {
                    nodes.add(node.left);
                    indices.add(2 * index);
                }
                if (node.right != null) {
                    nodes.add(node.right);
                    indices.add(2 * index + 1);
                }
            }
            // The first node of the level has index 0, so the level's width is lastIndex + 1.
            maxWidth = Math.max(maxWidth, lastIndex + 1);
        }
        // NOTE(review): assumes the answer fits in an int, as the int return type implies.
        return (int) maxWidth;
    }
}