Iterative solution for finding prefix sum in a binary tree, that equals a target sum


I am having trouble coming up with an iterative solution for the following problem (i.e. no recursion and no use of the call stack), so I am reaching out to the community here for help. :)

Problem: Given the root of a binary tree and an integer targetSum, return the number of paths where the sum of the values along the path equals targetSum. The path does not need to start or end at the root or a leaf, but it must go downwards (i.e., traveling only from parent nodes to child nodes).

For example: root = [10,5,-3,3,2,null,11,3,-2,null,1], sum = 8

      10
     /  \
    5   -3
   / \    \
  3   2   11
 / \   \
3  -2   1

Input: root = [10,5,-3,3,2,null,11,3,-2,null,1], targetSum = 8
Output: 3
Explanation: The paths that sum to 8 are shown.

Other Test Cases:

  • PASS Input: root = [5,4,8,11,null,13,4,7,2,null,null,5,1], targetSum = 22, Expected Output: 3
  • PASS Input: root = [-1,-2,-3], targetSum = -1, Expected Output: 1
  • PASS Input: root = [1,2,-3], targetSum = -1, Expected Output: 1
  • PASS Input: root = [1,-2,-3,1,3,-2,null,-1], targetSum = 3, Expected Output: 1
  • PASS Input: root = [1,2], targetSum = 2, Expected Output: 1
  • PASS Input: root = [1,2,null], targetSum = 2, Expected Output: 1
  • FAILED Input: root = [-2, null, -3], targetSum = -3, Output: 0, Expected Output: 1

Failed test case picture:

  -2
   \
   -3

      

All of the above tests pass except the last one. My code returns 0, but the expected output is 1: the single node with value -3 is itself a path whose sum equals targetSum, so the count of paths that sum to targetSum should be 1.

I am having a hard time fixing my solution so that it works for the last test case without breaking any of the other test cases.

My Iterative Solution:

/**
 * Definition for a binary tree node.
 * function TreeNode(val, left, right) {
 *     this.val = (val===undefined ? 0 : val)
 *     this.left = (left===undefined ? null : left)
 *     this.right = (right===undefined ? null : right)
 * }
 */
/**
 * @param {TreeNode} root
 * @param {number} targetSum
 * @return {number}
 */
var pathSum = function(root, targetSum) {
    let count = 0, cache = new Map();
    let stack = [[root, 0]];
    while (stack.length) {
        let [node, currSum] = stack.pop();
        if (node) {
            currSum += node.val;
            if (currSum === targetSum) {
                count++;
            }
            count += cache.get(currSum-targetSum) || 0;
            cache.set(currSum, (cache.get(currSum) || 0) + 1);
            stack.push([node.right, currSum]);
            stack.push([node.left, currSum]);

            if (!node.left) {
                cache.set(currSum, cache.get(currSum) - 1);
            }
        }
        
    }
    return count;
};

I already have a working recursive solution, but this question is only about an iterative, traversal-based solution. I feel like I am missing something about how backtracking works when a stack is used for iterative traversal of the binary tree. If you spot any error, please let me know!

Answer by trincot:

The problem is that you don't have a good way to know when to remove counts from your cache. The one place where you do remove a count comes immediately after that count was added, which cannot be right: the push instructions in between don't touch the cache. The effect is the same as if, in the !node.left case, you had never added the count in the first place.
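To make this concrete, here is a minimal hand-unrolled trace of the cache for the failing tree [-2, null, -3] with targetSum = -3. It is only a sketch that replays the relevant statements from the question's loop by hand; the variable `found` is introduced here for illustration.

```javascript
const targetSum = -3;
const cache = new Map();

// Visit the root (-2): the running sum becomes -2 and is cached.
let currSum = -2;
cache.set(currSum, (cache.get(currSum) || 0) + 1); // cache: -2 => 1

// The root has no left child, so the question's code decrements right away,
// before the right child has been visited.
cache.set(currSum, cache.get(currSum) - 1);        // cache: -2 => 0

// Visit the right child (-3): the running sum becomes -5.
currSum += -3;

// The lookup key is -5 - (-3) = -2, but its count was already reset to 0.
const found = cache.get(currSum - targetSum) || 0;
console.log(found); // 0, so the path consisting of just the node -3 is missed
```

The prefix sum of the root is still needed while its right subtree is being processed, so it may only be removed after both subtrees are done.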

Knowing when to remove a count is trickier to keep track of with an explicit stack than with recursion.

I would suggest adding another structure that records the tree depth at which each count was added to the cache. Then, when you pop a node at a certain depth, you know that any cache entries added at that depth or deeper belong to an already finished subtree and should be removed.

Here is your code adapted with that idea, which I didn't alter more than necessary. Comments indicate where I made changes:

var pathSum = function(root, targetSum) {
    let count = 0, cache = new Map();
    const depthLog = []; // Keep track of cache actions per depth (index in array)
    let stack = [[root, 0, 0]]; // Add the depth of the node as 3rd member
    while (stack.length) {
        let [node, currSum, depth] = stack.pop(); // Expect the 3rd member
        if (node) {
            // Remove cache that is no longer relevant
            while (depthLog.length > depth) {
                for (const sum of depthLog.pop()) {
                    cache.set(sum, cache.get(sum) - 1);
                }
            }
            //
            currSum += node.val;
            if (currSum === targetSum) {
                count++;
            }
            count += cache.get(currSum-targetSum) || 0;
            cache.set(currSum, (cache.get(currSum) || 0) + 1);
            (depthLog[depth] ??= []).push(currSum); // Also log the act of caching this sum
            stack.push([node.right, currSum, depth+1]); // Add depth
            stack.push([node.left, currSum, depth+1]);  // ...

            // Removed the `if` block at this place
        }
        
    }
    return count;
};

This extension does not impact the time complexity: only what was added before can be removed, so each cached sum is removed at most once and the removals cost no more in total than the additions. The overall running time stays O(n) for a tree with n nodes.