题目

356. 次小生成树

算法标签: K r u s k a l Kruskal Kruskal, M S T MST MST, 倍增优化, l c a lca lca

思路

因为要求的是严格次小生成树, 假设最小生成树的和为 s s s, 遍历每一个非树边 e e e, 那么最后答案就是 min ⁡ ( s + e . w − v ) \min (s + e.w - v) min(s+e.wv), v v v是当前非树边的路径上最长的边
如果最长边是最小生成树的边, 使用次长边, 因此需要处理两个数组 d 1 d1 d1, d 2 d2 d2分别代表最长边和次长边, 算法时间复杂度 O ( E log ⁡ E ) O(E\log E) O(ElogE)

代码

#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>

using namespace std;

typedef long long LL;
const int N = 1e5 + 10, M = 6e5 + 10, K = 18;
const int INF = 0x3f3f3f3f;

int n, m;
struct Edge {
   
    int u, v, w;
    bool is_tr = false;
    bool operator< (const Edge &e) const {
   
        return w < e.w;
    }
} edges[M];
int head[N], ed[N << 1], ne[N << 1], w[N << 1], idx;
int fa[N][K], depth[N];
int d1[N][K], d2[N][K];
int p[N];

int find(int x) {
   
    if (p[x] != x) p[x] = find(p[x]);
    return p[x];
}

void add(int u, int v, int val) {
   
    ed[idx] = v, ne[idx] = head[u], w[idx] = val, head[u] = idx++;
}

LL kruskal() {
   
    sort(edges, edges + m);
    for (int i = 0; i <= n; ++i) p[i] = i;
    LL res = 0;
    for (int i = 0; i < m; ++i) {
   
        auto &[u, v, w, is_tr] = edges[i];
        int fa1 = find(u), fa2 = find(v);
        if (fa1 == fa2) continue;
        res += w;
        p[fa2] = fa1;
        is_tr = true;
        add(u, v, w), add(v, u, w);
    }
    return res;
}

void bfs() {
   
    int q[N], h = 0, t = -1;
    q[++t] = 1;
    memset(depth, 0x3f, sizeof depth);
    depth[0] = 0, depth[1] = 1;

    while (h <= t) {
   
        int u = q[h++];
        for (int i = head[u]; ~i; i = ne[i]) {
   
            int v = ed[i];
            if (depth[u] + 1 < depth[v]) {
   
                depth[v] = depth[u] + 1;
                q[++t] = v;

                fa[v][0] = u;
                d1[v][0] = w[i];
                d2[v][0] = -INF;
                for (int k = 1; k < K; ++k) {
   
                    int mid = fa[v][k - 1];
                    fa[v][k] = fa[mid][k - 1];

                    int arr[4] = {
   
                            d1[v][k - 1],
                            d2[v][k - 1],
                            d1[mid][k - 1],
                            d2[mid][k - 1]
                    };

                    int val1 = -INF, val2 = -INF;
                    for (int j = 0; j < 4; ++j) {
   
                        if (arr[j] > val1) val2 = val1, val1 = arr[j];
                        else if (arr[j] > val2 && arr[j] < val1) val2 = arr[j];
                    }
                    d1[v][k] = val1;
                    d2[v][k] = val2;
                }
            }
        }
    }
}

int lca(int u, int v, int val) {
   
    if (depth[u] < depth[v]) swap(u, v);
    vector<int> vec;

    for (int k = K - 1; k >= 0; --k) {
   
        if (depth[fa[u][k]] >= depth[v]) {
   
            vec.push_back(d1[u][k]);
            vec.push_back(d2[u][k]);
            u = fa[u][k];
        }
    }

    if (u != v) {
   
        for (int k = K - 1; k >= 0; --k) {
   
            if (fa[u][k] != fa[v][k]) {
   
                vec.push_back(d1[u][k]);
                vec.push_back(d2[u][k]);
                vec.push_back(d1[v][k]);
                vec.push_back(d2[v][k]);
                u = fa[u][k], v = fa[v][k];
            }
        }
        vec.push_back(d1[u][0]);
        vec.push_back(d2[u][0]);
        vec.push_back(d1[v][0]);
        vec.push_back(d2[v][0]);
    }

    int val1 = -INF, val2 = -INF;
    for (int t : vec) {
   
        if (t > val1) val2 = val1, val1 = t;
        else if (t > val2) val2 = t;
    }

    if (val1 < val) return val1;
    if (val1 == val) return val2;
    return -INF;
}

int main() {
   
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);

    memset(head, -1, sizeof head);
    cin >> n >> m;
    for (int i = 0; i < m; ++i) {
   
        int u, v, w;
        cin >> u >> v >> w;
        edges[i] = {
   u, v, w};
    }

    LL sum = kruskal();
    bfs();

    LL ans = 1e18;
    for (int i = 0; i < m; ++i) {
   
        auto &[u, v, w, is_tr] = edges[i];
        if (is_tr) continue;
        ans = min(ans, sum + w - lca(u, v, w));
    }

    cout << ans << "\n";
    return 0;
}

*警示后人

因为求的是严格次小生成树, 因此在计算的时候不是arr[j] > val2而是arr[j] > val2 && arr[j] < val1, 否则求出的就不是严格的了