题目

P2656 采蘑菇

算法标签: 强连通分量, 缩点, 最长路, 最短路

思路

首先题目给出的是有向图，蘑菇长在边上：一条边上的蘑菇每被采摘一次，剩余数量就变为原来的"恢复系数"倍（向下取整）。同一个强连通分量内部的边可以反复经过，因此可以先用 Tarjan 缩点，把每个强连通分量内部所有边上的蘑菇全部采完，累加作为该分量的点权。缩点后得到一张有向无环图，再从起点所在的分量出发求最长路（点权 + 边权），其中的最大值就是能采到的最多蘑菇数量。

代码

#include <iostream>
#include <algorithm>
#include <cstring>
#include <queue>
#include <stack>

using namespace std;

// Capacities: N bounds vertices / components, M bounds the shared edge pool
// (original graph edges plus condensed graph edges).
constexpr int N = 80010;
constexpr int M = 400010;

int n, m, s;  // vertex count, edge count, start vertex

// Linked-list adjacency over one edge pool; head1 = original graph,
// head2 = condensed graph indexed by SCC id. A value of -1 ends a list.
int head1[N];
int head2[N];
int ed[M];   // target vertex of each edge
int ne[M];   // next edge in the same list
int idx;     // next free slot in the edge pool
int fct[M];  // recovery factor scaled by 10 (meaningful on original edges)
int w[M];    // mushrooms on the edge

// Tarjan bookkeeping.
int dfn[N];
int low[N];
int timestamp;
int stk[N];
int top;
bool in_stk[N];

// Component results.
int scc_cnt;
int id[N];    // SCC id of each vertex
int d[N];     // best mushroom total on a path ending at a component
int vals[N];  // mushrooms harvestable inside a component

void add(int head[], int u, int v, int val, int weight) {
   
    ed[idx] = v, ne[idx] = head[u], fct[idx] = val, w[idx] = weight, head[u] = idx++;
}

void tarjan(int u) {
   
    dfn[u] = low[u] = ++timestamp;
    stk[++top] = u;
    in_stk[u] = true;

    for (int i = head1[u]; ~i; i = ne[i]) {
   
        int v = ed[i];
        if (!dfn[v]) {
   
            tarjan(v);
            low[u] = min(low[u], low[v]);
        }
        else if (in_stk[v]) low[u] = min(low[u], dfn[v]);
    }

    if (dfn[u] == low[u]) {
   
        ++scc_cnt;
        int ver;
        do {
   
            ver = stk[top--];
            in_stk[ver] = false;
            id[ver] = scc_cnt;
        } while (ver != u);
    }
}

void calc() {
   
    for (int u = 1; u <= n; ++u) {
   
        for (int i = head1[u]; ~i; i = ne[i]) {
   
            int v = ed[i];
            if (id[u] == id[v]) {
   
                int k = w[i];
                int r = fct[i];
                while (k > 0) {
   
                    vals[id[u]] += k;
                    k = k * r / 10;
                }
            }
        }
    }

    for (int u = 1; u <= n; ++u) {
   
        for (int i = head1[u]; ~i; i = ne[i]) {
   
            int v = ed[i];
            int a = id[u], b = id[v];
            if (a != b) {
   
                add(head2, a, b, 0, w[i]);
            }
        }
    }
}


void spfa() {
   
    memset(d, -1, sizeof d);
    queue<int> q;
    q.push(id[s]);
    d[id[s]] = vals[id[s]];
    bool vis[N] = {
   0};
    vis[id[s]] = true;

    while (!q.empty()) {
   
        int u = q.front();
        q.pop();
        vis[u] = false;

        for (int i = head2[u]; ~i; i = ne[i]) {
   
            int v = ed[i];
            if (d[u] + w[i] + vals[v] > d[v]) {
   
                d[v] = d[u] + w[i] + vals[v];
                if (!vis[v]) {
   
                    q.push(v);
                    vis[v] = true;
                }
            }
        }
    }
}

// Reads "n m", then m edges "u v weight factor", then the start vertex s.
// Prints the maximum number of mushrooms collectable from s.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    memset(head1, -1, sizeof head1);
    memset(head2, -1, sizeof head2);

    cin >> n >> m;
    for (int i = 0; i < m; ++i) {
        int u, v, weight;
        double val;
        cin >> u >> v >> weight >> val;
        // The factor has one decimal digit; store it as an integer number of
        // tenths. Round instead of truncating: val is a binary approximation,
        // and if val * 10 lands just below an integer, truncation loses a tenth.
        add(head1, u, v, static_cast<int>(val * 10 + 0.5), weight);
    }
    cin >> s;

    for (int i = 1; i <= n; ++i) {
        if (!dfn[i]) tarjan(i);
    }

    calc();
    spfa();

    // Unreachable components hold -1, so starting at 0 ignores them.
    int ans = 0;
    for (int i = 1; i <= scc_cnt; ++i) ans = max(ans, d[i]);

    cout << ans << "\n";
    return 0;
}

数组模拟队列代码（用手写的数组队列代替 std::queue）

#include <iostream>
#include <algorithm>
#include <cstring>
#include <queue>

using namespace std;

// Capacities: N bounds vertices / components, M bounds the shared edge pool
// (original graph edges plus condensed graph edges).
constexpr int N = 80010;
constexpr int M = 400010;

int n, m, s;  // vertex count, edge count, start vertex

// Linked-list adjacency over one edge pool; head1 = original graph,
// head2 = condensed graph indexed by SCC id. A value of -1 ends a list.
int head1[N];
int head2[N];
int ed[M];   // target vertex of each edge
int ne[M];   // next edge in the same list
int fct[M];  // recovery factor scaled by 10 (meaningful on original edges)
int w[M];    // mushrooms on the edge
int idx;     // next free slot in the edge pool

// Tarjan bookkeeping.
int dfn[N];
int low[N];
int timestamp;
int stk[N];
int top;
bool in_stk[N];

// Component results.
int scc_cnt;
int id[N];    // SCC id of each vertex
int d[N];     // best mushroom total on a path ending at a component
int vals[N];  // mushrooms harvestable inside a component

// Append an edge u -> v (factor `factor`, mushrooms `weight`) to the pool and
// link it at the front of the list `head[u]`.
void add(int head[], int u, int v, int factor, int weight) {
    const int slot = idx++;
    ed[slot] = v;
    fct[slot] = factor;
    w[slot] = weight;
    ne[slot] = head[u];
    head[u] = slot;
}

void tarjan(int u) {
   
    dfn[u] = low[u] = ++timestamp;
    stk[++top] = u;
    in_stk[u] = true;

    for (int i = head1[u]; ~i; i = ne[i]) {
   
        int v = ed[i];
        if (!dfn[v]) {
   
            tarjan(v);
            low[u]=  min(low[u], low[v]);
        }
        else if (in_stk[v]) low[u] = min(low[u], dfn[v]);
    }

    if (dfn[u] == low[u]) {
   
        ++scc_cnt;
        int v;
        do {
   
            v = stk[top--];
            in_stk[v] = false;
            id[v] = scc_cnt;
        }
        while (v != u);
    }
}

void calc() {
   
    for (int u = 1; u <= n; ++u) {
   
        for (int i = head1[u]; ~i; i = ne[i]) {
   
            int v = ed[i];
            int scc1 = id[u], scc2 = id[v];
            if (scc1 == scc2) {
   
                int k = w[i];
                while (k > 0) {
   
                    vals[scc1] += k;
                    k = k * fct[i] / 10;
                }
            }
        }
    }

    for (int u = 1; u <= n; ++u) {
   
        for (int i = head1[u]; ~i; i = ne[i]) {
   
            int v = ed[i];
            int scc1 = id[u], scc2 = id[v];
            if (scc1 != scc2) add(head2, scc1, scc2, 0, w[i]);
        }
    }
}

// Longest path (most mushrooms) over the condensed graph head2, starting from
// the component of s, via SPFA with a hand-written array queue.
// A path's value = component weights vals[] + edge weights w[] along it.
// d[c] ends as the best value for component c; unreachable components keep -1.
void spfa() {
    // Start from -1 so even a reachable component whose best total is 0 gets
    // relaxed (and its out-edges propagated); with the zero-initialised
    // global d, "0 > 0" would silently block such paths.
    memset(d, -1, sizeof d);

    int q[N];
    bool vis[N] = {0};
    // Circular queue occupying q[hh .. tt) modulo N. A component can be
    // pushed again after it is popped, so the total number of pushes is not
    // bounded by N and a linear buffer would overflow. vis[] keeps each
    // component in the queue at most once. The queue therefore never holds
    // more than scc_cnt items, which stays below N as long as n respects the
    // N - 10 capacity.
    int hh = 0, tt = 0;

    int root = id[s];
    d[root] = vals[root];
    q[tt++] = root;
    vis[root] = true;

    while (hh != tt) {
        int u = q[hh++];
        if (hh == N) hh = 0;
        vis[u] = false;
        for (int i = head2[u]; ~i; i = ne[i]) {
            int v = ed[i];
            if (d[u] + w[i] + vals[v] > d[v]) {
                d[v] = d[u] + w[i] + vals[v];
                if (!vis[v]) {
                    q[tt++] = v;
                    if (tt == N) tt = 0;
                    vis[v] = true;  // mark as queued (was wrongly `false`)
                }
            }
        }
    }
}

// Reads "n m", then m edges "u v weight factor", then the start vertex s.
// Prints the maximum number of mushrooms collectable from s.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    memset(head1, -1, sizeof head1);
    memset(head2, -1, sizeof head2);

    cin >> n >> m;
    for (int i = 0; i < m; ++i) {
        // Named so they no longer shadow the global edge arrays w[] and fct[].
        int u, v, weight;
        double factor;
        cin >> u >> v >> weight >> factor;
        // The factor has one decimal digit; store it as an integer number of
        // tenths. Round instead of truncating: factor is a binary
        // approximation, and if factor * 10 lands just below an integer,
        // truncation loses a tenth.
        add(head1, u, v, static_cast<int>(factor * 10 + 0.5), weight);
    }
    cin >> s;

    for (int i = 1; i <= n; ++i) {
        if (!dfn[i]) tarjan(i);
    }

    calc();
    spfa();

    // Unreachable components are never better than 0, so start from 0.
    int ans = 0;
    for (int i = 1; i <= scc_cnt; ++i) ans = max(ans, d[i]);

    cout << ans << "\n";
    return 0;
}