All Nodes Distance K in Binary Tree

Sanjeev Sharma
3 min read

Advertisement

Problem

Given the root of a binary tree, a target node in that tree, and an integer K, return the values of all nodes that are exactly K edges away from the target (in any order).

Key insight: treat the tree as an undirected graph. Record each node's parent so the search can move upward as well as downward, then run a breadth-first search (BFS) starting at the target.

Approach — Parent Map + BFS

  1. DFS to build parent map
  2. BFS from the target, marking nodes as visited so none is counted twice; when the BFS reaches depth K, the queue holds exactly the answer

Solutions

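// C — no C version is included here: C has no standard hash map for the parent lookup
// (a parent array indexed by node id would work; use a real hash map in practice). See the C++ version below.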
// C++
// Child -> parent links written by buildParent (distanceK seeds the root with nullptr).
unordered_map<TreeNode*, TreeNode*> parent;
// Pre-order walk: stores `par` as the parent of `node`, then visits the left
// subtree followed by the right subtree. A null node is a no-op.
void buildParent(TreeNode* node, TreeNode* par) {
    if (node == nullptr) {
        return;
    }
    parent[node] = par;
    for (TreeNode* child : {node->left, node->right}) {
        buildParent(child, node);
    }
}
// Returns the values of every node exactly `k` edges away from `target`.
// The tree is walked as an undirected graph: children via left/right, the
// parent via the global `parent` map filled by buildParent. BFS advances one
// level per outer iteration, so when `dist` reaches k the queue holds exactly
// the nodes at that distance (in BFS order).
// Returns an empty vector when `target` is null, `k` is negative, or no node
// is that far away.
vector<int> distanceK(TreeNode* root, TreeNode* target, int k) {
    if (!target || k < 0) return {};
    // `parent` is global: drop entries left over from earlier calls so it does
    // not keep growing with (possibly freed) nodes of previous trees.
    parent.clear();
    buildParent(root, nullptr);

    queue<TreeNode*> q;
    unordered_set<TreeNode*> visited;
    q.push(target);
    visited.insert(target);
    for (int dist = 0; !q.empty(); ++dist) {
        if (dist == k) {
            vector<int> res;
            res.reserve(q.size());
            while (!q.empty()) {
                res.push_back(q.front()->val);
                q.pop();
            }
            return res;
        }
        // Expand exactly one BFS level.
        for (size_t sz = q.size(); sz > 0; --sz) {
            TreeNode* node = q.front();
            q.pop();
            // find() rather than operator[]: a lookup must not insert entries.
            auto it = parent.find(node);
            TreeNode* up = (it != parent.end()) ? it->second : nullptr;
            for (TreeNode* nb : {node->left, node->right, up}) {
                // insert().second is true only the first time nb is seen.
                if (nb && visited.insert(nb).second) q.push(nb);
            }
        }
    }
    return {};
}
// Java
// Child -> parent links written by buildParent (distanceK seeds the root with null).
Map<TreeNode, TreeNode> parent = new HashMap<>();
/**
 * Pre-order walk: records {@code par} as the parent of {@code node}, then
 * visits the left subtree followed by the right subtree. A null node is a no-op.
 */
void buildParent(TreeNode node, TreeNode par) {
    if (node != null) {
        parent.put(node, par);
        buildParent(node.left, node);
        buildParent(node.right, node);
    }
}
/**
 * Returns the values of all nodes exactly {@code k} edges away from {@code target}.
 * The tree is treated as an undirected graph (children plus the {@code parent} map)
 * and searched level by level; when {@code dist} reaches {@code k} the queue holds
 * exactly the answer, in BFS order.
 *
 * @param root   root of the tree
 * @param target node to measure distances from
 * @param k      required distance in edges
 * @return matching node values; empty if {@code target} is null, {@code k} is
 *         negative, or no node is that far away
 */
public List<Integer> distanceK(TreeNode root, TreeNode target, int k) {
    if (target == null || k < 0) return new ArrayList<>();
    // parent is a field: discard entries from earlier calls so it does not keep
    // growing with nodes of previous trees.
    parent.clear();
    buildParent(root, null);

    // ArrayDeque is safe here: null is never enqueued (target and nb are checked).
    Deque<TreeNode> q = new ArrayDeque<>();
    Set<TreeNode> visited = new HashSet<>();
    q.add(target);
    visited.add(target);
    for (int dist = 0; !q.isEmpty(); dist++) {
        if (dist == k) {
            List<Integer> res = new ArrayList<>(q.size());
            for (TreeNode n : q) res.add(n.val);
            return res;
        }
        // Expand exactly one BFS level.
        for (int sz = q.size(); sz > 0; sz--) {
            TreeNode node = q.poll();
            for (TreeNode nb : new TreeNode[]{node.left, node.right, parent.get(node)}) {
                // Set.add returns true only the first time nb is seen.
                if (nb != null && visited.add(nb)) q.add(nb);
            }
        }
    }
    return new ArrayList<>();
}
// JavaScript
/**
 * Returns the values of all nodes exactly `k` edges away from `target`.
 * The tree is treated as an undirected graph (children plus a parent map) and
 * searched level by level; when `dist` reaches `k` the current level is the answer.
 *
 * @param {object|null} root   tree root; nodes expose `val`, `left`, `right`
 * @param {object|null} target node to measure distances from
 * @param {number} k           required distance in edges
 * @returns {Array} matching values in BFS order; [] when `target` is missing,
 *          `k` is negative, or no node is that far away
 */
function distanceK(root, target, k) {
    if (!target || k < 0) return [];

    // Record each node's parent with an explicit stack instead of recursion,
    // so very deep (list-shaped) trees cannot overflow the call stack.
    const parent = new Map();
    const stack = root ? [[root, null]] : [];
    while (stack.length) {
        const [node, par] = stack.pop();
        parent.set(node, par);
        if (node.right) stack.push([node.right, node]);
        if (node.left) stack.push([node.left, node]);
    }

    const visited = new Set([target]);
    let queue = [target];
    for (let dist = 0; queue.length; dist++) {
        if (dist === k) return queue.map(n => n.val);
        const next = [];
        for (const node of queue) {
            for (const nb of [node.left, node.right, parent.get(node)]) {
                if (nb && !visited.has(nb)) {
                    visited.add(nb);
                    next.push(nb);
                }
            }
        }
        queue = next;
    }
    return [];
}
# Python
from collections import defaultdict, deque

def distanceK(root, target, k):
    """Return the values of all nodes exactly ``k`` edges away from ``target``.

    The tree is treated as an undirected graph (children plus a parent map)
    and searched level by level; when ``dist`` reaches ``k`` the queue holds
    exactly the answer, in BFS order.

    Args:
        root: tree root; nodes expose ``val``, ``left`` and ``right``.
        target: node to measure distances from.
        k: required distance in edges.

    Returns:
        A list of node values; empty when ``target`` is None, ``k`` is
        negative, or no node is that far away.
    """
    if target is None or k < 0:
        return []

    # Map each node to its parent with an explicit stack rather than recursion,
    # so deep (e.g. list-shaped) trees do not hit Python's recursion limit.
    parent = {}
    stack = [(root, None)] if root is not None else []
    while stack:
        node, par = stack.pop()
        parent[node] = par
        for child in (node.right, node.left):
            if child is not None:
                stack.append((child, node))

    visited = {target}
    q = deque([target])
    dist = 0
    while q:
        if dist == k:
            return [node.val for node in q]
        for _ in range(len(q)):  # expand exactly one BFS level
            node = q.popleft()
            for nb in (node.left, node.right, parent.get(node)):
                # Explicit None check instead of relying on node truthiness.
                if nb is not None and nb not in visited:
                    visited.add(nb)
                    q.append(nb)
        dist += 1
    return []

Complexity

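  • Time: O(n) — building the parent map touches every node once, and the BFS visits each node at most once
  • Space: O(n) — for the parent map, the visited set, and the BFS queue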

Key Insight

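With only child pointers, a tree is a directed graph whose edges point downward. Adding parent pointers makes every edge traversable in both directions (an undirected graph), so a BFS from the target yields exact distances.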

Advertisement

Sanjeev Sharma

Written by

Sanjeev Sharma

Full Stack Engineer · E-mopro