LC 2368. 受限条件下可到达节点的数目

时间:2024-03-05 19:54:21

2368. 受限条件下可到达节点的数目

难度 : 中等

题目大意:

现有一棵由 n 个节点组成的无向树,节点编号从 0n - 1 ,共有 n - 1 条边。

给你一个二维整数数组 edges ,长度为 n - 1 ,其中 edges[i] = [ai, bi] 表示树中节点 aibi 之间存在一条边。另给你一个整数数组 restricted 表示 受限 节点。

在不访问受限节点的前提下,返回你可以从节点 0 到达的 最多 节点数目*。*

注意,节点 0 会标记为受限节点。

提示:

  • 2 <= n <= 10^5
  • edges.length == n - 1
  • edges[i].length == 2
  • 0 <= ai, bi < n
  • ai != bi
  • edges 表示一棵有效的树
  • 1 <= restricted.length < n
  • 1 <= restricted[i] < n
  • restricted 中的所有值 互不相同

示例 1:

img

输入:n = 7, edges = [[0,1],[1,2],[3,1],[4,0],[0,5],[5,6]], restricted = [4,5]
输出:4
解释:上图所示正是这棵树。
在不访问受限节点的前提下,只有节点 [0,1,2,3] 可以从节点 0 到达。

分析

题目很清晰,我们只需要遍历到受限制的节点的时候不忘下面搜就行,所以我们可以先预处理一下,将所有受限制的节点标记成true,然后我们建图dfs一遍记数就行,另一种思路就是在建图的时候,如果没有限制在建边,否则就是直接跳过,这样下面就不用判断是不是被标记过了

DFS

class Solution {
public:
    int reachableNodes(int n, vector<vector<int>>& edges, vector<int>& restricted) {
        int st[n + 1];
        memset(st, 0, sizeof st);
        for (int x : restricted) st[x] = true;
        vector<vector<int>> g(n);
        for (int i = 0; i < edges.size(); i ++) {
            int a = edges[i][0], b = edges[i][1];
            if (st[a] || st[b]) continue;
            g[a].push_back(b);
            g[b].push_back(a);
        }
        int res = 1;
        function<void(int, int)> dfs = [&](int u, int fa) -> void {
            for (int x : g[u]) {
                if (x == fa) continue;
                res ++;
                dfs(x, u);
            }
        };
        dfs(0, -1);
        return res;
    }
};

时间复杂度 : O(n)

分析

我们可以将没有限制的节点全部放到一个集合里面,也就是并查集,用一个sz数组来维护每个集合的大小

并查集

class Solution {
public:
    int reachableNodes(int n, vector<vector<int>>& edges, vector<int>& restricted) {
        unordered_map<int, bool> st;
        for (int x : restricted) st[x] = true;
        vector<int> p(n), sz(n, 1);
        iota(p.begin(), p.end(), 0); // 初始化
        
        function<int(int)> find = [&](int x) -> int {
            if (p[x] != x) p[x] = find(p[x]);
            return p[x];
        };  
        
        for (int i = 0; i < edges.size(); i ++) {
            int a = edges[i][0], b = edges[i][1];
            if (st[a] || st[b]) continue;
            a = find(a), b = find(b);
            if (a != b) {
                p[a] = b;
                sz[b] += sz[a];
            }
        }
        return sz[find(0)];
    }
};

时间复杂度: O ( n ) O(n) O(n)

结束了