Leetcode 834:树中距离之和(超详细的解法!!!)

给定一个无向、连通的树。树中有 N 个标记为 0...N-1 的节点以及 N-1 条边 。

i 条边连接节点 edges[i][0]edges[i][1]

返回一个表示节点 i 与其他所有节点距离之和的列表 ans

示例 1:

输入: N = 6, edges = [[0,1],[0,2],[2,3],[2,4],[2,5]]
输出: [8,12,6,10,10,10]
解释: 
如下为给定的树的示意图:
  0
 / \
1   2
   /|\
  3 4 5

我们可以计算出 dist(0,1) + dist(0,2) + dist(0,3) + dist(0,4) + dist(0,5) 
也就是 1 + 1 + 2 + 2 + 2 = 8。 因此,answer[0] = 8,以此类推。

说明: 1 <= N <= 10000

解题思路

这个问题很难,首相可以想到的一个解法是求得每个节点到其他节点的路径和,这个算法的时间复杂度是O(n),而要求所有点,那么时间复杂度就是O(n^2),显然这个问题这样做是不行的,因为题目中的N最大可以取到10000

这个问题的相关Solution给的解法非常不错,我们这里就借花献佛,直接拿过来。

Leetcode 834:树中距离之和(超详细的解法!!!)

对于上面这个图,我们将xy这条边分隔开,我们定义以左边x为根的所有节点距离和是[email protected],右边以y为根的所有节点距离和是[email protected],而x到右边y的所有点的距离和定义为[email protected],那么此时我们以x为根的结果就应该是[email protected][email protected],而[email protected][email protected]+count(y)(其中count(y)表示y的节点个数),那么

同理,此时如果我们将y作为根的话,也可以得到一个类似的结论

那么,我们就很容易知道

  • res(x)res(y)=count(y)count(x)res(x)-res(y)=count(y)-count(x)

这是一个非常有用的结论。

接着的问题就是怎么计算节点的数量和节点到它所有孩子的距离。节点数量很好计算就是将所有孩子的包含的节点数加起来然后再加上自身就行了,我们定义count(node)表示节点数,那么

  • count(node)=count(childs)+1count(node) = count(childs) + 1

接着考虑当前节点到其所有子节点的距离和。这个也非常简单,实际上和上面是一样的,我们只需要将所有孩子到其所有子节点的和加起来,然后再加上当前节点的子节点个数就行了,我们定义sum(node)表示距离和,那么

  • sum(node)=sum(childs)+count(childs)sum(node)=sum(childs)+count(childs)

那么我们只需要通过一次深度遍历就可以求得所有节点的count,但是所有的sum我们是无法求得的,因为我们只可以求和当前节点到其所有子节点的和,这个时候我们就需要联系这个节点到这个节点的子节点以外的所有节点的距离,这就要用到我们前面那个结论。

Leetcode 834:树中距离之和(超详细的解法!!!)
我们可以知道
  • res(child)res(parent)=count(parent)count(child)res(child)-res(parent)=count(parent)-count(child)

我们知道count(parent)=N - count(child),所以

  • res(parent)res(child)=2count(child)Nres(parent)-res(child) = 2*count(child)-N

所以我们在通过一次深度遍历来更新sum的结果即可,真的非常酷!!!

class Solution:
    def sumOfDistancesInTree(self, N, edges):
        """
        :type N: int
        :type edges: List[List[int]]
        :rtype: List[int]
        """
        graph = collections.defaultdict(set)
        for u, v in edges:
            graph[u].add(v)
            graph[v].add(u)

        count = [1] * N
        res = [0] * N
        def dfs1(node = 0, parent = None):
            for child in graph[node]:
                if child != parent:
                    dfs1(child, node)
                    count[node] += count[child]
                    res[node] += res[child] + count[child]

        def dfs2(node = 0, parent = None):
            for child in graph[node]:
                if child != parent:
                    res[child] = res[node] - count[child] + N - count[child]
                    dfs2(child, node)

        dfs1()
        dfs2()
        return res

reference:

https://leetcode.com/problems/sum-of-distances-in-tree/discuss/130583/C%2B%2BJavaPython-Pre-order-and-Post-order-DFS-O(N)

https://leetcode.com/problems/sum-of-distances-in-tree/solution/

我将该问题的其他语言版本添加到了我的GitHub Leetcode

如有问题,希望大家指出!!!