import java.util.*;

public class Solution {
    public TreeNode pruneLeaves (TreeNode root) {
        // 预处理
        if (root == null || needPrune(root)) return null;
        // 修剪树
        prune(root);
        // 返回修剪后的树
        return root;
    }

    // 修剪传入的树
    public static void prune(TreeNode root) {
        // 预处理
        if (root == null) return;
        // 是否需要修剪左子树
        if (root.left != null && needPrune(root.left)) {
            root.left = null; // 修剪左子树
        } else {
            prune(root.left); // 递归修剪左子树
        }
        // 是否需要修剪右子树
        if (root.right != null && needPrune(root.right)) {
            root.right = null; // 修剪右子树
        } else {
            prune(root.right); // 递归修剪右子树
        }
    }

    // 判断当前结点是否需要修剪
    public static boolean needPrune(TreeNode node) {
        // 预处理
        if (node == null) return false;
        // 存在左叶子结点
        if (node.left != null && isLeaf(node.left)) return true;
        // 存在右叶子结点
        if (node.right != null && isLeaf(node.right)) return true;
        // 其子结点中不存在叶子结点
        return false;
    }

    // 判断当前结点是否是叶子结点
    public static boolean isLeaf(TreeNode node) {
        // 预处理
        if (node == null) return false;
        // 当前结点没有子结点,则其为叶子结点
        return node.left == null && node.right == null;
    }
}