/*
public class TreeNode {
    int val = 0;
    TreeNode left = null;
    TreeNode right = null;

    public TreeNode(int val) {
        this.val = val;

    }

}
*/
import java.util.*;

public class Solution {
    /**
     * Returns the node holding the k-th smallest value (1-based) in the tree rooted at
     * {@code pRoot}.
     *
     * <p>Every node is visited breadth-first, so the tree is not required to be a binary
     * search tree. A max-heap bounded to {@code k} entries keeps the k smallest nodes seen so
     * far; once the whole tree has been scanned, the heap's top is the k-th smallest. This
     * costs O(n log k) time and O(k + width) extra space.
     *
     * @param pRoot root of the tree; may be {@code null}
     * @param k 1-based rank of the node to find
     * @return the node with the k-th smallest value, or {@code null} if {@code pRoot} is
     *     {@code null}, {@code k <= 0}, or the tree has fewer than {@code k} nodes
     */
    TreeNode KthNode(TreeNode pRoot, int k) {
        // k <= 0 (not just k == 0): a negative rank is meaningless and must not reach
        // any capacity-taking constructor.
        if (pRoot == null || k <= 0) {
            return null;
        }

        // Max-heap on val: the top is the largest of the k smallest values kept so far.
        // Integer.compare avoids the overflow that subtracting two ints can cause.
        // No initial capacity is passed, so a huge k does not pre-allocate a huge array.
        PriorityQueue<TreeNode> kSmallest =
                new PriorityQueue<>((a, b) -> Integer.compare(b.val, a.val));

        Deque<TreeNode> pending = new ArrayDeque<>();
        pending.add(pRoot);

        while (!pending.isEmpty()) {
            TreeNode node = pending.poll();

            if (kSmallest.size() < k) {
                kSmallest.add(node);
            } else if (node.val < kSmallest.peek().val) {
                // This node beats the current k-th smallest: evict the largest kept node.
                kSmallest.poll();
                kSmallest.add(node);
            }

            if (node.left != null) {
                pending.add(node.left);
            }
            if (node.right != null) {
                pending.add(node.right);
            }
        }

        // Fewer than k nodes in the whole tree means there is no k-th smallest.
        return kSmallest.size() < k ? null : kSmallest.peek();
    }
}