/**
* struct TreeNode {
* int val;
* struct TreeNode *left;
* struct TreeNode *right;
* TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}
* };
*/
class Solution {
private:
void buildGraph(TreeNode* parent, TreeNode* child,
unordered_map<TreeNode*, vector<TreeNode*>>& graph) {
if (child == nullptr) return;
if (parent != nullptr) {
graph[parent].push_back(child);
graph[child].push_back(parent);
}
buildGraph(child, child->left, graph);
buildGraph(child, child->right, graph);
}
TreeNode* findTarget(TreeNode* node, int target) {
if (node == nullptr) return nullptr;
if (node->val == target) return node;
TreeNode* leftResult = findTarget(node->left, target);
if (leftResult != nullptr) return leftResult;
return findTarget(node->right, target);
}
public:
/**
* 代码中的类名、方法名、参数名已经指定,请勿修改,直接返回方法规定的值即可
*
*
* @param root TreeNode类
* @param target int整型
* @param k int整型
* @return int整型vector
*/
vector<int> distanceKnodes(TreeNode* root, int target, int k) {
// write code here
unordered_map<TreeNode*, vector<TreeNode*>> graph;
buildGraph(nullptr, root, graph);
TreeNode* targetNode = findTarget(root, target);
vector<int> result;
if (!targetNode) return result;
// BFS to find all nodes at distance K
queue<TreeNode*> q;
unordered_map<TreeNode*, int> visited;
q.push(targetNode);
visited[targetNode] = 0;
while (!q.empty()) {
TreeNode* current = q.front();
q.pop();
int currDistance = visited[current];
if (currDistance == k) {
result.push_back(current->val);
}
for (TreeNode* neighbor : graph[current]) {
if (visited.find(neighbor) == visited.end()) {
visited[neighbor] = currDistance + 1;
q.push(neighbor);
}
}
}
return result;
}
};