Path Sum III

Problem

You are given a binary tree in which each node contains an integer value.

Find the number of paths that sum to a given value.

The path does not need to start or end at the root or a leaf, but it must go downwards (traveling only from parent nodes to child nodes).

The tree has no more than 1,000 nodes and the values are in the range -1,000,000 to 1,000,000.

Example:

root = [10,5,-3,3,2,null,11,3,-2,null,1], sum = 8

Return 3. The paths that sum to 8 are:

  1. 5 -> 3
  2. 5 -> 2 -> 1
  3. -3 -> 11
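To make the path definition concrete, here is a brute-force sketch. It is not the approach used below. It treats every node as a possible start, walks down every branch, and counts each time the running sum equals the target. It runs in O(n·h) time, which is O(n^2) in the worst case, so it is mainly useful for checking a faster solution. The class name PathSumBruteForce, its nested TreeNode, and sampleTree() are illustrative names assumed here.

```java
// Brute-force baseline (illustrative): every node is tried as a path start.
final class PathSumBruteForce {
    // Minimal tree node assumed for this sketch.
    static final class TreeNode {
        final int val;
        TreeNode left, right;
        TreeNode(int val) { this.val = val; }
    }

    // Counts downward paths (any start node, any end node) whose values sum to target.
    static int pathSum(TreeNode root, int target) {
        if (root == null) return 0;
        return countFrom(root, 0L, target)
                + pathSum(root.left, target)
                + pathSum(root.right, target);
    }

    // Counts paths that begin where this walk started and end at node or below it.
    private static int countFrom(TreeNode node, long sumSoFar, int target) {
        if (node == null) return 0;
        long sum = sumSoFar + node.val;
        int hits = (sum == target) ? 1 : 0;
        return hits + countFrom(node.left, sum, target) + countFrom(node.right, sum, target);
    }

    // The example tree [10,5,-3,3,2,null,11,3,-2,null,1], wired by hand.
    static TreeNode sampleTree() {
        TreeNode root = new TreeNode(10);
        root.left = new TreeNode(5);
        root.right = new TreeNode(-3);
        root.left.left = new TreeNode(3);
        root.left.right = new TreeNode(2);
        root.right.right = new TreeNode(11);
        root.left.left.left = new TreeNode(3);
        root.left.left.right = new TreeNode(-2);
        root.left.right.right = new TreeNode(1);
        return root;
    }

    public static void main(String[] args) {
        System.out.println(pathSum(sampleTree(), 8)); // prints 3
    }
}
```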

Thoughts

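The idea is similar to Two Sum. We use a HashMap whose key is a prefix sum (the sum of the values from the root down to a node) and whose value is how many times that prefix sum occurs on the current root-to-node path. Whenever we reach a node, we check whether prefix sum - target exists in the map; if it does, each occurrence marks the start of a downward path that ends at this node and sums to target, so we add its count to the result. After both subtrees of a node have been explored, we decrement the count of that node's prefix sum so it does not leak into other branches.
For instance, suppose one path has the values 1, 2, -1, -1, 2. The prefix sums are then 1, 3, 2, 1, 3. With a target sum of 2, when we reach the last node the prefix sum is 3, and 3 - 2 = 1 has already appeared twice, so there are two paths ending there that sum to 2: {2, -1, -1, 2} and {2}.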
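The same counting can be tried on a single path stored as an array. This sketch (the class and method names are illustrative) records, for each position, how many runs ending there sum to the target, storing prefix sums as long as a defensive choice. On 1, 2, -1, -1, 2 with target 2 it reports 2 at the last position, matching the example; over the whole path there are 4 such runs, since {1, 2, -1} and the first {2} also qualify.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

// Prefix-sum + HashMap counting on one path (an array), as in the 1, 2, -1, -1, 2 example.
final class PrefixSumOnPath {
    // For each index j, returns how many contiguous runs ending at j sum to target.
    static int[] pathsEndingAt(int[] path, int target) {
        Map<Long, Integer> seen = new HashMap<>();
        seen.put(0L, 1); // the empty prefix, so a run may start at index 0
        int[] endingHere = new int[path.length];
        long prefix = 0;
        for (int j = 0; j < path.length; j++) {
            prefix += path[j];
            // Every earlier prefix equal to prefix - target starts a matching run.
            endingHere[j] = seen.getOrDefault(prefix - target, 0);
            seen.merge(prefix, 1, Integer::sum);
        }
        return endingHere;
    }

    public static void main(String[] args) {
        int[] counts = pathsEndingAt(new int[] {1, 2, -1, -1, 2}, 2);
        System.out.println(Arrays.toString(counts)); // prints [0, 1, 1, 0, 2]
    }
}
```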

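I used a global variable count, but it could be avoided by having the helper return the number of paths found in its subtree and summing those results bottom up. The time complexity is O(n), since each node is visited once and each HashMap operation takes expected O(1) time; the extra space is O(n) in the worst case for the recursion stack and the map. This is my first post in Discuss, so I'm open to any improvement or criticism. 🙂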
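A sketch of the bottom-up variant described above: instead of a global count, each call returns the number of matching paths that end in its subtree, and the parent adds its own hits to the results from its left and right children. The class PathSumBottomUp, its nested TreeNode, and sampleTree() are assumed names. Prefix sums are stored as long, which the stated value range does not require but keeps the sketch safe for larger inputs.

```java
import java.util.HashMap;
import java.util.Map;

// Bottom-up variant (illustrative): no global state, each call returns its own count.
final class PathSumBottomUp {
    // Minimal tree node assumed for this sketch.
    static final class TreeNode {
        final int val;
        TreeNode left, right;
        TreeNode(int val) { this.val = val; }
    }

    static int pathSum(TreeNode root, int target) {
        Map<Long, Integer> prefixCounts = new HashMap<>();
        prefixCounts.put(0L, 1); // lets paths that start at the root be counted
        return count(root, 0L, target, prefixCounts);
    }

    // Returns the number of matching paths that end somewhere in node's subtree.
    private static int count(TreeNode node, long prefix, int target, Map<Long, Integer> prefixCounts) {
        if (node == null) return 0;
        prefix += node.val;
        int result = prefixCounts.getOrDefault(prefix - target, 0);
        prefixCounts.merge(prefix, 1, Integer::sum);
        result += count(node.left, prefix, target, prefixCounts);
        result += count(node.right, prefix, target, prefixCounts);
        prefixCounts.merge(prefix, -1, Integer::sum); // backtrack before returning
        return result;
    }

    // The example tree [10,5,-3,3,2,null,11,3,-2,null,1], wired by hand.
    static TreeNode sampleTree() {
        TreeNode root = new TreeNode(10);
        root.left = new TreeNode(5);
        root.right = new TreeNode(-3);
        root.left.left = new TreeNode(3);
        root.left.right = new TreeNode(2);
        root.right.right = new TreeNode(11);
        root.left.left.left = new TreeNode(3);
        root.left.left.right = new TreeNode(-2);
        root.left.right.right = new TreeNode(1);
        return root;
    }

    public static void main(String[] args) {
        System.out.println(pathSum(sampleTree(), 8)); // prints 3
    }
}
```

Because the count travels through return values, calling pathSum repeatedly gives independent answers without any reset step.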

Code

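package Algorithms;

import java.util.HashMap;
import java.util.Map;

public class Solution437 {
    // Minimal tree node, declared here so the file compiles on its own.
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int val) { this.val = val; }
    }

    // Number of matching paths found so far; reset at the start of every pathSum call.
    static int count = 0;

    public static int pathSum(TreeNode root, int sum) {
        count = 0;
        // prefix sum -> how many times it occurs on the current root-to-node path.
        // The entry (0, 1) lets paths that start at the root be counted.
        Map<Integer, Integer> map = new HashMap<>();
        map.put(0, 1);
        helper(root, 0, sum, map);
        return count;
    }

    public static void helper(TreeNode root, int prefixSum, int target, Map<Integer, Integer> map) {
        if (root == null) return;
        prefixSum += root.val;

        // Each earlier prefix equal to prefixSum - target starts a path that
        // ends at this node and sums to target.
        count += map.getOrDefault(prefixSum - target, 0);

        // Record this prefix sum while the subtrees below this node are explored.
        map.put(prefixSum, map.getOrDefault(prefixSum, 0) + 1);
        helper(root.left, prefixSum, target, map);
        helper(root.right, prefixSum, target, map);
        // Backtrack: this node is no longer on the current path.
        map.put(prefixSum, map.get(prefixSum) - 1);
    }

    public static void main(String[] args) {
        TreeNode t1 = new TreeNode(10);
        TreeNode t2 = new TreeNode(5);
        TreeNode t3 = new TreeNode(-3);
        TreeNode t4 = new TreeNode(3);
        TreeNode t5 = new TreeNode(2);
        TreeNode t6 = new TreeNode(11);
        TreeNode t7 = new TreeNode(3);
        TreeNode t8 = new TreeNode(-2);
        TreeNode t9 = new TreeNode(1);
        t1.left = t2;
        t1.right = t3;
        t2.left = t4;
        t2.right = t5;
        t3.right = t6;
        t4.left = t7;
        t4.right = t8;
        t5.right = t9;
        System.out.println(pathSum(t1, 8)); // prints 3
    }
}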
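The main method wires up the example tree by hand. As an assumed convenience that is not part of the original solution, the level-order notation from the example, root = [10,5,-3,3,2,null,11,3,-2,null,1], can be turned into a tree with a queue: values are handed out as left and right children in order, and null marks a missing child. The names LevelOrderBuilder, build, and preorder are hypothetical.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;

// Hypothetical helper: builds a tree from level-order values where null means "no child".
final class LevelOrderBuilder {
    static final class TreeNode {
        final int val;
        TreeNode left, right;
        TreeNode(int val) { this.val = val; }
    }

    static TreeNode build(Integer[] values) {
        if (values.length == 0 || values[0] == null) return null;
        TreeNode root = new TreeNode(values[0]);
        Queue<TreeNode> queue = new ArrayDeque<>(); // holds only non-null nodes
        queue.add(root);
        int i = 1;
        while (!queue.isEmpty() && i < values.length) {
            TreeNode parent = queue.poll();
            if (values[i] != null) {
                parent.left = new TreeNode(values[i]);
                queue.add(parent.left);
            }
            i++;
            if (i < values.length && values[i] != null) {
                parent.right = new TreeNode(values[i]);
                queue.add(parent.right);
            }
            i++;
        }
        return root;
    }

    // Preorder values, used here only to check the tree's shape.
    static List<Integer> preorder(TreeNode root) {
        List<Integer> out = new ArrayList<>();
        collect(root, out);
        return out;
    }

    private static void collect(TreeNode node, List<Integer> out) {
        if (node == null) return;
        out.add(node.val);
        collect(node.left, out);
        collect(node.right, out);
    }

    public static void main(String[] args) {
        TreeNode root = build(new Integer[] {10, 5, -3, 3, 2, null, 11, 3, -2, null, 1});
        System.out.println(preorder(root)); // prints [10, 5, 3, 3, -2, 2, 1, -3, 11]
    }
}
```

The preorder output matches the tree built by hand in main, so either construction can feed pathSum.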