Max Depth of A Tree - Depth First Search / DFS on Tree

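#include <algorithm>

// Assumed shape of the Node<T> type the original code relies on
// (adjust to match whatever Node definition your environment provides).
template <typename T>
struct Node {
    T val;
    Node<T>* left;
    Node<T>* right;
};

int max_depth = 0;

// Visit every node, recording the deepest level seen (the root is depth 0).
void dfs(Node<int>* root, int depth) {
    if (root) {
        max_depth = std::max(max_depth, depth);
        dfs(root->left, depth + 1);
        dfs(root->right, depth + 1);
    }
}

int tree_max_depth(Node<int>* root) {
    max_depth = 0;  // reset the shared state so repeated calls start fresh
    dfs(root, 0);
    return max_depth;
}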

I came up with the state version, which IMO is easier to reason about, but I think what they’re trying to accomplish here is identifying patterns/concepts from previous sections and then implementing them. That’s the ultimate goal when interviewing.

IMO it’s not clear in the explanation, but they put together a few concepts from previous chapters. Although this can be accomplished with state, from what I understand the solution counts “total nodes - 1” because:

  1. Depth = Number of edges from root to node
  2. Number of edges in a tree will always be n - 1, where n = number of nodes

Therefore, the max depth of a tree is the length of its longest root-to-leaf path, i.e. the number of edges from the root to the furthest leaf. A path is itself a tree, so a path with k nodes has k - 1 edges. Based on that, we can count the nodes on the longest downward path from each subtree, then subtract 1 from the final result to get the correct depth, as Relic stated. Then we just need to take the maximum value out of all those paths.
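A minimal Python sketch of the node-counting idea, assuming the same Node class as the Python snippets in this thread. The helper name count_nodes_on_longest_path is hypothetical and not taken from the course solution. Each call returns how many nodes lie on the longest downward path starting at that node, so the "counter" is the `1 +` in the return expression; subtracting 1 once at the very end turns nodes into edges.

```python
from typing import Optional


class Node:
    # Same minimal node shape used elsewhere in this thread.
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_nodes_on_longest_path(node: Optional[Node]) -> int:
    # Hypothetical helper: an empty subtree contributes 0 nodes.
    if node is None:
        return 0
    # This node (the "+ 1") plus the longer of its two downward paths.
    return 1 + max(count_nodes_on_longest_path(node.left),
                   count_nodes_on_longest_path(node.right))


def tree_max_depth(root: Optional[Node]) -> int:
    # A path with k nodes has k - 1 edges, and depth counts edges.
    # Assumption: an empty tree returns 0 (not -1) to match the other
    # versions in this thread.
    if root is None:
        return 0
    return count_nodes_on_longest_path(root) - 1
```

For example, Node(1, Node(2, Node(4)), Node(3)) has 3 nodes on its longest path (1 → 2 → 4), so the depth comes out as 2.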

To me, actually implementing this to get to the given solution is the confusing part. I would probably be off by one implementing this without practice.

I found this approach more readable:

class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def is_leaf_node(node):
    return not node.left and not node.right


def tree_max_depth(root: Node) -> int:
    def dfs(root, depth):
        # An empty tree or a leaf ends the path: report the depth reached.
        if not root or is_leaf_node(root):
            return depth

        # Only recurse into children that exist; a missing side reports 0,
        # which never beats the depth coming back from the existing child.
        left_depth = dfs(root.left, depth + 1) if root.left else 0
        right_depth = dfs(root.right, depth + 1) if root.right else 0

        return max(left_depth, right_depth)

    return dfs(root, 0)

I don’t understand how this counts the number of nodes. I don’t see any counter in there.

“There are n nodes and n - 1 edges in a tree, so if we traverse each once, then the total traversal is O(2n - 1), which is O(n).”
This should be O(2(n - 1)). There are n - 1 edges and you traverse each of them exactly twice: once going down and once coming back up. Either way, it's still O(n).
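To check the 2(n - 1) count, here is a small Python sketch that instruments a DFS to count every step along an edge, both down into a child and back up to the parent. The function names count_edge_traversals and count_nodes are hypothetical, written only for this illustration, and the Node class mirrors the one used in this thread.

```python
from typing import Optional


class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_edge_traversals(root: Optional[Node]) -> int:
    # Count every move the DFS makes along an edge.
    moves = 0

    def dfs(node: Node) -> None:
        nonlocal moves
        for child in (node.left, node.right):
            if child is not None:
                moves += 1  # walk down the edge parent -> child
                dfs(child)
                moves += 1  # walk back up the same edge

    if root is not None:
        dfs(root)
    return moves


def count_nodes(root: Optional[Node]) -> int:
    # Total number of nodes n in the tree.
    if root is None:
        return 0
    return 1 + count_nodes(root.left) + count_nodes(root.right)
```

On the five-node tree Node(1, Node(2, Node(4), Node(5)), Node(3)) this counts 8 moves, which matches 2 * (5 - 1).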

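def tree_max_depth(root: Node, depth=0) -> int:
    # Only reached for an empty tree, since children are checked before recursing.
    if root is None:
        return depth

    # Descend only into existing children; a missing child keeps the current depth.
    left_depth = tree_max_depth(root.left, depth + 1) if root.left else depth
    right_depth = tree_max_depth(root.right, depth + 1) if root.right else depth
    return max(left_depth, right_depth)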

This is mine. But I like the one Maysam_Gamini1 gave more.