Thursday, September 9, 2021

[Google Question][LeetCode] Sum of Distances in Tree

Problem: There is an undirected connected tree with n nodes labeled from 0 to n - 1 and n - 1 edges.

You are given the integer n and the array edges where edges[i] = [ai, bi] indicates that there is an edge between nodes ai and bi in the tree.

Return an array answer of length n where answer[i] is the sum of the distances between the ith node in the tree and all other nodes.

Example:

Input: n = 6, edges = [[0,1],[0,2],[2,3],[2,4],[2,5]]
Output: [8,12,6,10,10,10]
Explanation: Node 0 is connected to nodes 1 and 2, and node 2 is connected to nodes 3, 4 and 5.
We can see that dist(0,1) + dist(0,2) + dist(0,3) + dist(0,4) + dist(0,5)
equals 1 + 1 + 2 + 2 + 2 = 8.
Hence, answer[0] = 8, and so on.

Input: n = 1, edges = []
Output: [0]

Input: n = 2, edges = [[1,0]]
Output: [1,1]


Approach: A straightforward solution is to run a BFS with every node as the root and add up the distances. That works, but each BFS costs O(n), so the whole thing takes O(n^2) time. Let's try to optimize it.
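For reference, the O(n^2) baseline could look roughly like the sketch below. It is not the author's code; the class and method names (`BruteForceDistances`, `SumOfDistances`) are hypothetical. It is mainly useful as a correctness check for the optimized version on small inputs.

```csharp
using System.Collections.Generic;

// Hypothetical baseline (not the author's code): one BFS per node.
// Each BFS is O(n) on a tree, so the total is O(n^2).
public static class BruteForceDistances
{
    public static int[] SumOfDistances(int n, int[][] edges)
    {
        // Build an adjacency list for the undirected tree.
        var adj = new List<int>[n];
        for (int i = 0; i < n; ++i)
        {
            adj[i] = new List<int>();
        }
        foreach (int[] e in edges)
        {
            adj[e[0]].Add(e[1]);
            adj[e[1]].Add(e[0]);
        }

        var answer = new int[n];
        for (int root = 0; root < n; ++root)
        {
            // depth[v] == -1 means v has not been reached yet in this BFS.
            var depth = new int[n];
            for (int i = 0; i < n; ++i)
            {
                depth[i] = -1;
            }
            depth[root] = 0;

            var queue = new Queue<int>();
            queue.Enqueue(root);
            while (queue.Count > 0)
            {
                int u = queue.Dequeue();
                answer[root] += depth[u]; // distance from root to u
                foreach (int v in adj[u])
                {
                    if (depth[v] == -1)
                    {
                        depth[v] = depth[u] + 1;
                        queue.Enqueue(v);
                    }
                }
            }
        }
        return answer;
    }
}
```

On the first example (n = 6) this sketch should reproduce [8,12,6,10,10,10], at the cost of n separate traversals.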

Let's say our tree is:

        0
      /   \
     1     2
    /     / \
   5     3   4

Suppose we have already calculated the sum of distances for node 0, which is 1 + 1 + 2 + 2 + 2 = 8. Now let's calculate the sum for node 1. Re-rooting the same tree at 1 gives:

       1
     /   \
    5     0   [sum of distances = 8]
           \
            2
           / \
          3   4

We already know that the sum of distances from 0 to every other node is 8. We can divide this value into two parts:

  1. distances from 0 to the nodes of subtree(1), i.e. node 1 and its descendants
  2. distances from 0 to the nodes of subtree(2), i.e. node 2 and its descendants
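For the tree drawn above, where subtree(1) = {1, 5} and subtree(2) = {2, 3, 4}, the split works out as:

```latex
\begin{aligned}
\text{dist}[0] &= \underbrace{d(0,1) + d(0,5)}_{\text{subtree}(1)}
               + \underbrace{d(0,2) + d(0,3) + d(0,4)}_{\text{subtree}(2)} \\
               &= (1 + 2) + (1 + 2 + 2) = 3 + 5 = 8
\end{aligned}
```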

Now, when we compute the same sum for node 1, it is going to be:

distance[1] = distance[0] - (number of nodes in subtree(1)) + (number of nodes in subtree(2)) + 1 (for node 0 itself)

            = distance[0] - (number of nodes in subtree(1)) + (total nodes in tree - number of nodes in subtree(1))
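Plugging in the numbers for the tree above (subtree(1) = {1, 5} has 2 nodes, subtree(2) has 3 nodes, n = 6), both forms agree with a direct count from node 1:

```latex
\begin{aligned}
\text{dist}[1] &= 8 - 2 + (3 + 1) = 10 \\
               &= 8 - 2 + (6 - 2) = 10 \\
\text{check: } & d(1,5) + d(1,0) + d(1,2) + d(1,3) + d(1,4) = 1 + 1 + 2 + 3 + 3 = 10
\end{aligned}
```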

Why? Every node in subtree(1) gets 1 closer, because its path now starts at 1 instead of at 1's parent 0. Every other node, that is node 0 and all of subtree(2), gets 1 farther, because its path now starts at 1 instead of 0, something like:

1 - 0 - 2 - {3, 4}

There is no other way to reach node 0 and subtree(2) from 1 than through the edge 1 - 0, so each of those nodes adds exactly 1 (the length of that edge) to its distance. Once this is clear, we get the generic formula:

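distance[node] = distance[parentOfNode] - NumOfChildren[node] + (NumNodes - NumOfChildren[node])

where NumOfChildren[node] is the number of nodes in the subtree rooted at node (node itself included) and NumNodes is the total number of nodes n.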
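The formula itself is a one-liner. The sketch below wraps it in a hypothetical helper (`RerootFormula.ChildSum` is a name made up for illustration) so it can be checked against the tree drawn above.

```csharp
// Hypothetical helper: the re-rooting formula from the text.
// parentSum:   sum of distances from the parent node.
// subtreeSize: number of nodes in the subtree rooted at the child (child included).
// totalNodes:  number of nodes in the whole tree.
public static class RerootFormula
{
    public static int ChildSum(int parentSum, int subtreeSize, int totalNodes)
    {
        // Nodes inside the child's subtree get 1 closer; all other nodes get 1 farther.
        return parentSum - subtreeSize + (totalNodes - subtreeSize);
    }
}
```

Starting from distance[0] = 8 with n = 6: node 1 gets ChildSum(8, 2, 6) = 10, node 2 gets ChildSum(8, 3, 6) = 8, nodes 3 and 4 get ChildSum(8, 1, 6) = 12, and node 5 gets ChildSum(10, 1, 6) = 14. Each value matches counting the distances by hand.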

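We can use a post-order traversal from root 0 to calculate the number of nodes in each subtree together with the initial distance: the sum of distances from a node to its own subtree is the sum, over its children, of (child's sum + NumOfChildren[child]), because every node in a child's subtree is exactly one edge farther from the parent than from the child. After this pass only distance[0] is a final answer. A pre-order traversal then applies the formula above from the root downwards to calculate the final distances of all other nodes. Have a look at the implementation for more details.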
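The same two passes can also be written without recursion, which may matter on a very deep tree (for example a long path) where deep recursion could risk a stack overflow depending on the runtime's stack size. The sketch below is a swapped-in iterative variant, not the author's code; `IterativeSumOfDistances.Compute` is a hypothetical name. A BFS from node 0 records an order in which every parent comes before its children: walking that order backwards gives a valid post-order, walking it forwards gives a valid pre-order.

```csharp
using System.Collections.Generic;

// Hypothetical iterative variant of the post-order / pre-order approach.
public static class IterativeSumOfDistances
{
    public static int[] Compute(int n, int[][] edges)
    {
        var adj = new List<int>[n];
        for (int i = 0; i < n; ++i) adj[i] = new List<int>();
        foreach (int[] e in edges)
        {
            adj[e[0]].Add(e[1]);
            adj[e[1]].Add(e[0]);
        }

        // BFS from root 0: record each node's parent and a top-down order.
        var parent = new int[n];
        var order = new int[n];
        var visited = new bool[n];
        int head = 0, tail = 0;
        order[tail++] = 0;
        visited[0] = true;
        parent[0] = -1;
        while (head < tail)
        {
            int u = order[head++];
            foreach (int v in adj[u])
            {
                if (!visited[v])
                {
                    visited[v] = true;
                    parent[v] = u;
                    order[tail++] = v;
                }
            }
        }

        // Pass 1 (children before parents): subtree sizes, and dist[u] = sum of
        // distances from u to the nodes of its own subtree.
        var size = new int[n];
        var dist = new int[n];
        for (int i = n - 1; i >= 0; --i)
        {
            int u = order[i];
            size[u] += 1; // count u itself; its children were already added
            if (parent[u] >= 0)
            {
                size[parent[u]] += size[u];
                dist[parent[u]] += dist[u] + size[u];
            }
        }

        // Pass 2 (parents before children): dist[0] is final, re-root downwards.
        for (int i = 1; i < n; ++i)
        {
            int u = order[i];
            dist[u] = dist[parent[u]] - size[u] + (n - size[u]);
        }
        return dist;
    }
}
```

For the first example this should return [8,12,6,10,10,10], the same as the recursive implementation.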

That's all!


Implementation in C#:

using System.Collections.Generic;

public class Graph
{
    public int[] SumOfDistancesInTree(int n, int[][] edges)
    {
        if (n == 1)
        {
            return new int[n];
        }
        this.InitializeMembers(n);
        this.CreateGraph(edges);
        // Post-order pass: subtree sizes and the final answer for root 0.
        this.GetChildrenCountAndInitialResult(0, -1);
        // Pre-order pass: derive every other node's answer from its parent's.
        this.GetDistances(0, -1);
        return this.dist;
    }

    private void GetDistances(int node, int parent)
    {
        foreach (int child in this.graph[node])
        {
            // The graph is undirected, so instead of a visited set we skip the node we came from.
            if (child == parent)
            {
                continue;
            }
            this.dist[child] = this.dist[node] - this.numOfChildren[child] + (this.numOfNodes - this.numOfChildren[child]);
            this.GetDistances(child, node);
        }
    }

    private void GetChildrenCountAndInitialResult(int node, int parent)
    {
        foreach (int child in this.graph[node])
        {
            // Skip the edge back to the parent (undirected graph).
            if (child == parent)
            {
                continue;
            }
            this.GetChildrenCountAndInitialResult(child, node);
            this.numOfChildren[node] += this.numOfChildren[child];
            // Every node in the child's subtree is one edge farther from node than from child.
            this.dist[node] += (this.dist[child] + this.numOfChildren[child]);
        }
        // Count the node itself, so numOfChildren[node] is the full subtree size.
        ++this.numOfChildren[node];
    }

    private void CreateGraph(int[][] edges)
    {
        foreach (int[] edge in edges)
        {
            this.graph[edge[0]].Add(edge[1]);
            this.graph[edge[1]].Add(edge[0]);
        }
    }

    private void InitializeMembers(int n)
    {
        this.numOfNodes = n;
        this.numOfChildren = new int[n];
        this.dist = new int[n];
        this.graph = new HashSet<int>[n];
        for (int i = 0; i < n; ++i)
        {
            this.graph[i] = new HashSet<int>();
        }
    }

    // Adjacency sets of the undirected tree.
    private HashSet<int>[] graph;
    // numOfChildren[i]: number of nodes in the subtree rooted at i, including i.
    private int[] numOfChildren;
    private int[] dist;
    private int numOfNodes;
}


Complexity: O(n) time and O(n) space, since each of the two traversals visits every node and edge once.
