import java.util.*;
/*
* public class TreeNode {
* int val = 0;
* TreeNode left = null;
* TreeNode right = null;
* }
*/
public class Solution {
    /**
     * Finds the lowest common ancestor of the two nodes holding the values
     * {@code o1} and {@code o2} in a binary tree and returns that ancestor's value.
     *
     * <p>The tree is scanned breadth-first while recording each visited node's
     * parent. The ancestor chain of the first target is then collected, and the
     * chain of the second target is walked upwards until it meets the first one.
     * Ancestors are compared by node identity, not by value, so nodes that happen
     * to share a value are never confused with one another.
     *
     * <p>If a target value occurs more than once, the first match in breadth-first
     * order is used. If a target value is not present in the tree, the root is used
     * in its place, so the result is then {@code root.val}.
     *
     * @param root root of the binary tree; may be {@code null}
     * @param o1   value of the first target node
     * @param o2   value of the second target node
     * @return the value stored in the lowest common ancestor, or {@code -1} when
     *         {@code root} is {@code null}
     */
    public int lowestCommonAncestor (TreeNode root, int o1, int o2) {
        if (root == null) {
            return -1;
        }

        // child -> parent links for every node discovered so far (identity-based keys).
        Map<TreeNode, TreeNode> parentOf = new IdentityHashMap<TreeNode, TreeNode>();
        parentOf.put(root, null);

        TreeNode first = null;
        TreeNode second = null;
        Deque<TreeNode> pending = new ArrayDeque<TreeNode>();
        pending.add(root);

        // Breadth-first scan; stops as soon as both targets have been located.
        // A node's parent link is stored before the node itself is examined, so
        // the full ancestor chain of every located target is already recorded.
        while (!pending.isEmpty() && (first == null || second == null)) {
            TreeNode node = pending.poll();
            if (first == null && node.val == o1) {
                first = node;
            }
            if (second == null && node.val == o2) {
                second = node;
            }
            if (node.left != null) {
                parentOf.put(node.left, node);
                pending.add(node.left);
            }
            if (node.right != null) {
                parentOf.put(node.right, node);
                pending.add(node.right);
            }
        }

        // A value that does not occur in the tree falls back to the root.
        if (first == null) {
            first = root;
        }
        if (second == null) {
            second = root;
        }

        // Every node on the path from the first target up to the root.
        Set<TreeNode> firstChain =
                Collections.newSetFromMap(new IdentityHashMap<TreeNode, Boolean>());
        for (TreeNode node = first; node != null; node = parentOf.get(node)) {
            firstChain.add(node);
        }

        // Climb from the second target; the first node shared with the other
        // chain is the deepest common ancestor.
        for (TreeNode node = second; node != null; node = parentOf.get(node)) {
            if (firstChain.contains(node)) {
                return node.val;
            }
        }

        // Not reachable in practice: both chains end at the root.
        return root.val;
    }
}