Binary Tree – Paths with Sum

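Given a binary tree whose nodes hold integers, write a function that calculates the number of paths whose values sum to a given value. A path may start at any node and end at that same node or at any node below it, but it must go downward (from parent to child).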
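Example:
For the tree below and a target sum of 13, there are two such paths: 1 -> 2 -> 10 and 2 -> 11.

        1
       / \
      2   3
     / \
   10   11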

Approach

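For each node, traverse its subtree depth-first and count all the paths starting at that node that add up to the sum. Adding up the counts for all the nodes gives the result.
We use two recursive functions. The first, calcPathsToSumRecursive, iterates through all the nodes of the tree. For each node, it uses the second function, getPathsRecursive, to count all the paths that start at that node and add up to the expected sum. Because every node starts its own traversal of its subtree, the running time is O(n·h), where n is the number of nodes and h is the height of the tree (O(n log n) for a balanced tree, O(n²) for a degenerate one).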

Tests

/**
 * Tests for {@link PathsCalculator#calcPathsToSum}.
 *
 * <p>Most tests use the fixture built by {@link #createTree()}:
 *
 * <pre>
 *         1
 *        / \
 *       2   3
 *      / \
 *    10   11
 * </pre>
 */
public class PathsCalculatorTest {
    @Test
    public void calcPathsToSum_twoNodes_oneSum() {
        // Only 1 -> 3 adds up to 4.
        final BinaryTreeNode node = createTree();
        final int sum = PathsCalculator.calcPathsToSum(node, 4);
        assertEquals(1, sum);
    }

    @Test
    public void calcPathsToSum_allNodes_twoSums() {
        // 1 -> 2 -> 10 and 2 -> 11 both add up to 13.
        final BinaryTreeNode node = createTree();
        final int sum = PathsCalculator.calcPathsToSum(node, 13);
        assertEquals(2, sum);
    }

    @Test
    public void calcPathsToSum_zeroSum() {
        // All values are positive, so no path can add up to 0.
        final BinaryTreeNode node = createTree();
        final int sum = PathsCalculator.calcPathsToSum(node, 0);
        assertEquals(0, sum);
    }

    @Test
    public void calcPathsToSum_emptyTree() {
        // A null root is an empty tree: it contains no paths at all.
        final int sum = PathsCalculator.calcPathsToSum(null, 0);
        assertEquals(0, sum);
    }

    @Test
    public void calcPathsToSum_singleZeroNode_oneSum() {
        // The single-node path [0] adds up to 0.
        final int sum = PathsCalculator.calcPathsToSum(new BinaryTreeNode(0), 0);
        assertEquals(1, sum);
    }

    /** Builds the fixture tree drawn in the class comment. */
    private static BinaryTreeNode createTree() {
        final BinaryTreeNode root = new BinaryTreeNode(1);

        // Children of the root
        final BinaryTreeNode n2 = new BinaryTreeNode(2);
        root.setLeft(n2);
        final BinaryTreeNode n3 = new BinaryTreeNode(3);
        root.setRight(n3);

        // Children of n2
        final BinaryTreeNode n10 = new BinaryTreeNode(10);
        n2.setLeft(n10);
        final BinaryTreeNode n11 = new BinaryTreeNode(11);
        n2.setRight(n11);

        return root;
    }
}

Solution

/**
 * Counts the downward paths in a binary tree whose node values add up to a target sum.
 *
 * <p>A path starts at any node and follows child links downward, ending at that node or at any
 * node below it (not necessarily a leaf). Node values may be zero or negative, so a traversal is
 * never cut short when the running sum reaches or passes the target.
 *
 * <p>Every node starts its own depth-first traversal of its subtree, so the running time is
 * O(n·h) for n nodes and tree height h. The implementation is recursive, so a very deep
 * (degenerate) tree may exhaust the call stack.
 */
public class PathsCalculator {
    /**
     * Returns the number of downward paths in the tree rooted at {@code node} whose values add up
     * to {@code sum}.
     *
     * @param node root of the tree; {@code null} denotes an empty tree
     * @param sum  the target sum
     * @return the number of matching paths; 0 for an empty tree
     */
    public static int calcPathsToSum(BinaryTreeNode node, int sum) {
        return calcPathsToSumRecursive(node, sum);
    }

    /** Visits every node of the tree and adds up the matching paths that start at each one. */
    private static int calcPathsToSumRecursive(BinaryTreeNode node, int expectedSum) {
        if (node == null) {
            return 0;
        }

        // Paths starting here, plus paths starting anywhere in either subtree.
        return getPaths(node, expectedSum)
                + calcPathsToSumRecursive(node.getLeft(), expectedSum)
                + calcPathsToSumRecursive(node.getRight(), expectedSum);
    }

    /** Counts the matching paths that start exactly at {@code node}. */
    private static int getPaths(BinaryTreeNode node, int expectedSum) {
        return getPathsRecursive(node, expectedSum, 0L);
    }

    /**
     * Counts the matching paths that begin at the node where this traversal started and end at
     * {@code node} or one of its descendants.
     *
     * @param node        current node; {@code null} contributes no paths
     * @param expectedSum the target sum
     * @param currentSum  sum of the values on the path above {@code node}; accumulated as a
     *                    {@code long} so that long paths of large values do not overflow
     * @return the number of matching paths ending in the subtree rooted at {@code node}
     */
    private static int getPathsRecursive(BinaryTreeNode node, int expectedSum, long currentSum) {
        if (node == null) {
            return 0;
        }
        currentSum += node.getData();

        // Count this path if it hits the target, but keep descending: zero or negative values
        // further down can bring a longer path back to the target (e.g. 1 -> 0 for target 1).
        final int pathsEndingHere = currentSum == expectedSum ? 1 : 0;

        // No pruning when currentSum > expectedSum: negative values below may lower it again.
        return pathsEndingHere
                + getPathsRecursive(node.getLeft(), expectedSum, currentSum)
                + getPathsRecursive(node.getRight(), expectedSum, currentSum);
    }
}

Comments