3558.给边赋权值的方案数I

目标

给你一棵 n 个节点的无向树,节点从 1 到 n 编号,树以节点 1 为根。树由一个长度为 n - 1 的二维整数数组 edges 表示,其中 edges[i] = [ui, vi] 表示在节点 ui 和 vi 之间有一条边。

一开始,所有边的权重为 0。你可以将每条边的权重设为 1 或 2。

两个节点 u 和 v 之间路径的 代价 是连接它们路径上所有边的权重之和。

选择任意一个 深度最大 的节点 x。返回从节点 1 到 x 的路径中,边权重之和为 奇数 的赋值方式数量。

由于答案可能很大,返回它对 10^9 + 7 取模的结果。

注意: 忽略从节点 1 到节点 x 的路径外的所有边。

示例 1:

输入: edges = [[1,2]]
输出: 1
解释:
从节点 1 到节点 2 的路径有一条边(1 → 2)。
将该边赋权为 1 会使代价为奇数,赋权为 2 则为偶数。因此,合法的赋值方式有 1 种。

示例 2:

输入: edges = [[1,2],[1,3],[3,4],[3,5]]
输出: 2
解释:
最大深度为 2,节点 4 和节点 5 都在该深度,可以选择任意一个。
例如,从节点 1 到节点 4 的路径包括两条边(1 → 3 和 3 → 4)。
将两条边赋权为 (1,2) 或 (2,1) 会使代价为奇数,因此合法赋值方式有 2 种。

提示:

  • 2 <= n <= 10^5
  • edges.length == n - 1
  • edges[i] == [ui, vi]
  • 1 <= ui, vi <= n
  • edges 表示一棵合法的树。

思路

有一颗含有 n 个节点的树,编号为 1 ~ n,根节点编号是 1edges[i] = [ui, vi],表示节点 uivi 之间有一条边。定义节点 uv 之间的代价为它们之间路径上边的权重之和。可以为每一条边赋予权重 12,求使得根节点到最远叶子节点代价为奇数的赋权方案数。

定义到达当前节点代价为奇数或偶数的方案数为 dp[i][1]dp[i][0],状态转移方程为 dp[i][0] = dp[prev][0] + dp[prev][1]dp[i][1] = dp[prev][0] + dp[prev][1]

观察发现 dp[i][0] == dp[i][1],因此代价为奇数的方案数为 dp[i][1] = 2 * dp[prev][1]。假设树的深度为 d,方案数为 2^(d - 1)

求出树的深度 d,从中选取奇数条边赋值为 1,方案数为 2^(d - 1)

代码


/**
 * @date 2026-06-11 10:16
 */
public class AssignEdgeWeights3558 {

    public int assignEdgeWeights(int[][] edges) {
        int n = edges.length + 1;
        List<Integer>[] g = new ArrayList[n + 1];
        Arrays.setAll(g, i -> new ArrayList<>());
        for (int[] edge : edges) {
            g[edge[0]].add(edge[1]);
            g[edge[1]].add(edge[0]);
        }
        int d = dfs(0, 1, g);
        return pow(2, d - 1, 1000000007);
    }

    public int dfs(int fa, int cur, List<Integer>[] g) {
        int res = 0;
        for (Integer next : g[cur]) {
            if (next == fa) {
                continue;
            }
            res = Math.max(res, dfs(cur, next, g) + 1);
        }
        return res;
    }

    public int pow(int base, int exp, int mod) {
        long res = 1L;
        while (exp > 0) {
            if ((exp & 1) == 1) {
                res = res * base % mod;
            }
            base = (int) ((long) base * base % mod);
            exp >>= 1;
        }
        return (int) res;
    }

}

性能