题解

A

注意到给定的 并没有用,只需要记录每个 对应的最小的 即可。

#include <bits/stdc++.h>

using i64 = long long;

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n, q;
    std::cin >> n >> q;

    std::string s;
    std::cin >> s;

    i64 ans = 0;
    std::vector<int> a(n, n);

    while (q--) {
        int l, r;
        std::cin >> l >> r;
        l--;
        r--;

        ans -= n - a[l];
        a[l] = std::min(a[l], r);
        ans += n - a[l];

        std::cout << ans << "\n";
    }

    return 0;
}

B

首先特判全 ,答案为

否则,只要存在某一位 满足,存在一种方案选出奇数个 在任意 位上是 ,其他 在任意 位上是 ,即可让答案的 位全为

对于每个 ,记录 中最低位的 所在位。

表示考虑前 个数,选出了偶数个数的最低位的 的最高位的最小值,记 表示考虑前 个数,选出了奇数个数的最低位的 中的最高位的最小值。

转移为:

最终答案即为

复杂度为

Bonus: 本题存在严格 的做法

#include <bits/stdc++.h>

using u32 = unsigned int;
using u64 = unsigned long long;

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int t;
    std::cin >> t;

    while (t--) {
        int n;
        std::cin >> n;

        std::vector<u64> a(n);
        for (auto &a : a) {
            std::cin >> a;
        }

        std::vector f(2, std::vector(n, 60));
        for (int i = 0; i < n; i++) {
            if (a[i] == 0) {
                f[0][i] = 0;
            } else {
                f[0][i] = std::min(std::countr_zero(~a[i]), 60);
                f[1][i] = std::min(std::countr_zero(a[i]), 60);
            }
        }

        std::array<int, 2> dp{0, 60};
        for (int i = 0; i < n; i++) {
            std::array<int, 2> ndp{60, 60};
            for (int x = 0; x < 2; x++) {
                for (int y = 0; y < 2; y++) {
                    ndp[x ^ y] = std::min(ndp[x ^ y], std::max(dp[x], f[y][i]));
                }
            }
            dp = std::move(ndp);
        }

        u64 ans = 0;
        if (std::count(a.begin(), a.end(), 0) != n) {
            ans = ((1ULL << 61) - 1) >> dp[1] << dp[1];
        }
        std::cout << ans << "\n";
    }

    return 0;
}

C

给定的是个排列每个数只出现一次所以套路的枚举,对于每个 求出必须包含的区间 , 即能包含 所有数的最小区间,然后考虑 的位置 我们令这个位置为 , 分类讨论一下

  • 如果,则此时的 一定不合法,++即可

  • 如果 则当前有贡献的区间为左端点在右端点在的所有区间

    ​ 我们可以枚举左端点 然后算右端点的贡献。

    ​ 然后考虑怎么固定了一个端点后怎么快速算出另一个端点所有的贡献。我们假设现在固定的左端点为

    ​ 则此时答案为

    ​ 这个式子可以通过错位相接算出通项公式,令

    ② ​ ①

    ​ 复杂度为快速幂的复杂度

  • 如果 则当前有贡献的区间为左端点在右端点在的所有区间

    ​ 我们可以枚举右端点 然后算左端点的贡献。

注意单独算一下的情况

因为一个点如果在时被枚举了一次,那么 时,必须包含的区间必定含有这个点 。

所以每个点最多被枚举一次,所以复杂度为

#include <bits/stdc++.h>
#define int long long
using namespace std;
typedef long long ll;
int read() {
    int x = 0, f = 1;
    char c = getchar();
    while(c < '0' || c > '9')
        f = (c == '-') ? -1 : 1, c = getchar();
    while(c >= '0' && c <= '9')
        x = x * 10 + c - 48, c = getchar();
    return f * x;
}
const int N = 2e5 + 10;
constexpr int mod = 998244353;
int a[N];
int ksm(int a, int b) {
    int ans = 1;
    while(b) {
        if(b & 1)
            ans = ans * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return ans;
}
int vis[N];
int calc(int m, int l, int r) {
    int s = l * ksm(2, m * l) % mod;
    s = (s - r * ksm(2, (r + 1) * m) % mod + mod) % mod;
    int s1 = (ksm(2, (r + 1) * m) - ksm(2, (l + 1) * m) + mod) % mod;
    s = (s + s1 * ksm(ksm(2, m) - 1, mod - 2) % mod) % mod;
    int ans = m * s % mod * ksm((1 - ksm(2, m) + mod) % mod, mod - 2) % mod;
    return ans;
}
signed main() {
    int n = read();
    for(int i = 1; i <= n; i++)
        a[i] = read(), vis[a[i]] = i;
    int l = n, r = 0, mex = 1;
    int ans = n * (n + 1) % mod;
    ans = ans * ksm(2, n * (n + 1)) % mod;
    for(int i = 1; i < vis[1]; i++)
        ans = (ans + calc(1, 1, vis[1] - i)) % mod;
    for(int i = vis[1] + 1; i <= n; i++)
        ans = (ans + calc(1, 1, n - i + 1)) % mod;
    while(mex < n) {
        l = min(l, vis[mex]);
        r = max(r, vis[mex]);
        while(mex < n && vis[mex + 1] <= r && vis[mex + 1] >= l)
            mex++;
        if(mex == n)
            break;
        int p = vis[mex + 1];
        int L = 1, R = n;
        if(p < l) {
            L = p + 1;
            for(int i = L; i <= l; i++)
                ans = (ans + calc(mex + 1, r - i + 1, n - i + 1)) % mod;
        } else {
            R = p - 1;
            for(int i = r; i <= R; i++)
                ans = (ans + calc(mex + 1, i - l + 1, i)) % mod;
        }
        mex++;
    }
    cout << ans;
}

D

首先,注意到题目给定的 等价于对一个 进制数反复求数位和直到只剩一位。设原数为 ,最后剩下的一位数为 ,容易证明

因此存在一个简单的复杂度为 做法。首先考虑变换所有的 ,枚举当前位为第 位,前 位的和模 ,枚举当前位的数字为 ,转移方程为:

然后考虑所有的 ,只能最多变换一个位置。考虑将 变换为 ,等价于让

容易发现这是一个卷积形式,因此可以使用 优化至

特别需要注意的是,,因此还需要特殊处理答案为 的方案数。

#include <bits/stdc++.h>

using u32 = unsigned int;
using i64 = long long;
using u64 = unsigned long long;
using u128 = unsigned __int128;

template <class T> constexpr T power(T a, u64 n) {
    T res{1};
    for (; n != 0; n /= 2, a *= a) {
        if (n % 2 == 1) {
            res *= a;
        }
    }
    return res;
}

template <u32 P> constexpr u32 mulMod(u32 a, u32 b) {
    return static_cast<u64>(a) * b % P;
}

template <u64 P> constexpr u64 mulMod(u64 a, u64 b) {
    return static_cast<u128>(a) * b % P;
}

template <class U, U P> struct ModBase {
    U x;
    constexpr ModBase() : x{0} {}
    template <class T> constexpr ModBase(T x) : x{norm(x % T(mod()))} {}
    static constexpr U mod() {
        return P;
    }
    static constexpr U norm(U x) {
        if ((x >> (8 * sizeof(U) - 1) & 1) == 1) {
            x += mod();
        }
        if (x >= mod()) {
            x -= mod();
        }
        return x;
    }
    constexpr U val() const {
        return x;
    }
    constexpr ModBase operator-() const {
        return ModBase(mod() - x);
    }
    constexpr ModBase inv() const {
        return power(*this, mod() - 2);
    }
    constexpr ModBase &operator+=(const ModBase &rhs) & {
        x = norm(x + rhs.x);
        return *this;
    }
    constexpr ModBase &operator-=(const ModBase &rhs) & {
        x = norm(x - rhs.x);
        return *this;
    }
    constexpr ModBase &operator*=(const ModBase &rhs) & {
        x = mulMod<mod()>(x, rhs.x);
        return *this;
    }
    constexpr ModBase &operator/=(const ModBase &rhs) & {
        return *this *= rhs.inv();
    }
    friend constexpr ModBase operator+(ModBase lhs, const ModBase &rhs) {
        return lhs += rhs;
    }
    friend constexpr ModBase operator-(ModBase lhs, const ModBase &rhs) {
        return lhs -= rhs;
    }
    friend constexpr ModBase operator*(ModBase lhs, const ModBase &rhs) {
        return lhs *= rhs;
    }
    friend constexpr ModBase operator/(ModBase lhs, const ModBase &rhs) {
        return lhs /= rhs;
    }
    friend constexpr std::ostream &operator<<(std::ostream &os, const ModBase &rhs) {
        return os << rhs.val();
    }
    friend constexpr bool operator==(const ModBase &lhs, const ModBase &rhs) {
        return lhs.val() == rhs.val();
    }
    friend constexpr bool operator!=(const ModBase &lhs, const ModBase &rhs) {
        return lhs.val() != rhs.val();
    }
    friend constexpr bool operator<(const ModBase &lhs, const ModBase &rhs) {
        return lhs.val() < rhs.val();
    }
};

template <u32 P> using Mod32 = ModBase<u32, P>;
template <u64 P> using Mod64 = ModBase<u64, P>;

constexpr u32 P = 998244353;
using Z = Mod32<P>;

template <int V> constexpr Z CInv = Z(V).inv();

constexpr Z findPrimitiveRoot() {
    Z i = 2;
    int k = std::countr_zero(P - 1);
    while (true) {
        if (power(i, (P - 1) / 2) != 1) {
            break;
        }
        i += 1;
    }
    return power(i, (P - 1) >> k);
}

constexpr Z primitiveRoot = findPrimitiveRoot();

namespace NTT {
    std::vector<int> rev;
    std::vector<Z> roots{0, 1};
    void dft(std::vector<Z> &a) {
        int n = a.size();
        if (rev.size() != n) {
            int k = std::countr_zero(u32(n)) - 1;
            rev.resize(n);
            for (int i = 0; i < n; i++) {
                rev[i] = rev[i >> 1] >> 1 | (i & 1) << k;
            }
        }
        for (int i = 0; i < n; i++) {
            if (rev[i] < i) {
                std::swap(a[i], a[rev[i]]);
            }
        }
        if (roots.size() < n) {
            int k = std::countr_zero(roots.size());
            roots.resize(n);
            while ((1 << k) < n) {
                auto e = power(primitiveRoot, 1 << (std::countr_zero(P - 1) - k - 1));
                for (int i = 1 << (k - 1); i < (1 << k); i++) {
                    roots[2 * i] = roots[i];
                    roots[2 * i + 1] = roots[i] * e;
                }
                k++;
            }
        }
        for (int k = 1; k < n; k *= 2) {
            for (int i = 0; i < n; i += 2 * k) {
                for (int j = 0; j < k; j++) {
                    auto u = a[i + j];
                    auto v = a[i + j + k] * roots[k + j];
                    a[i + j] = u + v;
                    a[i + j + k] = u - v;
                }
            }
        }
    }
    void idft(std::vector<Z> &a) {
        int n = a.size();
        std::reverse(a.begin() + 1, a.end());
        dft(a);
        auto inv = Z(n).inv();
        for (int i = 0; i < n; i++) {
            a[i] *= inv;
        }
    }
} // namespace NTT

struct Poly : std::vector<Z> {
    using vector::vector;
    constexpr friend Poly operator*(Poly lhs, Poly rhs) {
        if (lhs.empty() || rhs.empty()) {
            return Poly();
        }
        int n = 1, m = lhs.size() + rhs.size() - 1;
        while (n < m) {
            n *= 2;
        }
        lhs.resize(n);
        rhs.resize(n);
        NTT::dft(lhs);
        NTT::dft(rhs);
        for (int i = 0; i < n; i++) {
            lhs[i] *= rhs[i];
        }
        NTT::idft(lhs);
        lhs.resize(m);
        return lhs;
    }
};

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n, k;
    std::cin >> n >> k;

    std::vector<int> a(n);
    for (auto &a : a) {
        std::cin >> a;
    }

    std::vector<std::vector<int>> s(n);
    for (int i = 0; i < n; i++) {
        int c;
        std::cin >> c;

        s[i].resize(c);
        for (auto &s : s[i]) {
            std::cin >> s;
        }
    }

    Poly f(k - 1);
    f[0] = 1;
    for (int i = 0; i < n; i++) {
        Poly g(k - 1);
        if (a[i] == -1) {
            for (auto x : s[i]) {
                g[x % (k - 1)] += 1;
            }
        } else {
            g[a[i] % (k - 1)] += 1;
        }
        f = f * g;
        for (int i = 0; i < k - 1; i++) {
            f[i] += f[i + k - 1];
        }
        f.resize(k - 1);
    }

    std::vector<Z> ans(k);
    for (int i = 0; i < k - 1; i++) {
        ans[i] = f[i];
    }

    for (int i = 0; i < n; i++) {
        if (a[i] != -1) {
            Poly g(k - 1);
            for (auto x : s[i]) {
                g[(x - (a[i] % (k - 1)) + k - 1) % (k - 1)] += 1;
            }
            g = f * g;
            for (int j = 0; j < k - 1; j++) {
                ans[j] += g[j] + g[j + k - 1];
            }
        }
    }

    auto zero = [&]() {
        int f = 1, g = 1;
        for (int i = 0; i < n; i++) {
            if (a[i] != -1) {
                f &= a[i] == 0;
            } else {
                g &= std::find(s[i].begin(), s[i].end(), 0) != s[i].end();
            }
        }
        if (f & g) {
            int cnt = 0;
            for (int i = 0; i < n; i++) {
                if (a[i] != -1) {
                    cnt += std::find(s[i].begin(), s[i].end(), 0) != s[i].end();
                }
            }
            return cnt + 1;
        } else if (g) {
            int cnt = 0;
            for (int i = 0; i < n; i++) {
                if (a[i] != -1) {
                    cnt += a[i] != 0;
                }
            }
            if (cnt > 1) {
                return 0;
            } else {
                for (int i = 0; i < n; i++) {
                    if (a[i] != -1 && a[i] != 0) {
                        return int(std::find(s[i].begin(), s[i].end(), 0) != s[i].end());
                    }
                }
                return 0;
            }
        } else {
            return 0;
        }
    };

    int cnt = zero();
    ans[k - 1] = ans[0] - cnt;
    ans[0] = cnt;

    for (int i = 0; i < k; i++) {
        std::cout << ans[i] << " \n"[i + 1 == k];
    }

    return 0;
}

E

存在一个简单的 做法,即设 为考虑以 为根的子树,序列的第一个节点的权值为 ,序列的最后一个节点为 的方案数, 转移为:

朴素 的复杂度为 ,无法通过。

可以使用线段树合并优化至

#include <bits/stdc++.h>

using u32 = unsigned int;
using i64 = long long;
using u64 = unsigned long long;
using u128 = unsigned __int128;

template <class T> constexpr T power(T a, u64 n) {
    T res{1};
    for (; n != 0; n /= 2, a *= a) {
        if (n % 2 == 1) {
            res *= a;
        }
    }
    return res;
}

template <u32 P> constexpr u32 mulMod(u32 a, u32 b) {
    return u64(a) * b % P;
}

template <u64 P> constexpr u64 mulMod(u64 a, u64 b) {
    return u128(a) * b % P;
}

template <class U, U P> struct ModBase {
    U x;
    constexpr ModBase() : x{0} {}
    template <class T> constexpr ModBase(T x) : x{norm(x % T(mod()))} {}
    static constexpr U mod() {
        return P;
    }
    static constexpr U norm(U x) {
        if ((x >> (8 * sizeof(U) - 1) & 1) == 1) {
            x += mod();
        }
        if (x >= mod()) {
            x -= mod();
        }
        return x;
    }
    constexpr U val() const {
        return x;
    }
    constexpr ModBase operator-() const {
        return ModBase(mod() - x);
    }
    constexpr ModBase inv() const {
        return power(*this, mod() - 2);
    }
    constexpr ModBase &operator+=(const ModBase &rhs) & {
        x = norm(x + rhs.x);
        return *this;
    }
    constexpr ModBase &operator-=(const ModBase &rhs) & {
        x = norm(x - rhs.x);
        return *this;
    }
    constexpr ModBase &operator*=(const ModBase &rhs) & {
        x = mulMod<mod()>(x, rhs.x);
        return *this;
    }
    constexpr ModBase &operator/=(const ModBase &rhs) & {
        return *this *= rhs.inv();
    }
    friend constexpr ModBase operator+(ModBase lhs, const ModBase &rhs) {
        return lhs += rhs;
    }
    friend constexpr ModBase operator-(ModBase lhs, const ModBase &rhs) {
        return lhs -= rhs;
    }
    friend constexpr ModBase operator*(ModBase lhs, const ModBase &rhs) {
        return lhs *= rhs;
    }
    friend constexpr ModBase operator/(ModBase lhs, const ModBase &rhs) {
        return lhs /= rhs;
    }
    friend constexpr std::ostream &operator<<(std::ostream &os, const ModBase &rhs) {
        return os << rhs.val();
    }
    friend constexpr bool operator==(const ModBase &lhs, const ModBase &rhs) {
        return lhs.val() == rhs.val();
    }
    friend constexpr bool operator!=(const ModBase &lhs, const ModBase &rhs) {
        return lhs.val() != rhs.val();
    }
    friend constexpr bool operator<(const ModBase &lhs, const ModBase &rhs) {
        return lhs.val() < rhs.val();
    }
};

template <u32 P> using Mod32 = ModBase<u32, P>;
template <u64 P> using Mod64 = ModBase<u64, P>;

constexpr u32 P = 998244353;
using Z = Mod32<P>;

template <int V> constexpr Z inv = Z(V).inv();

template <class Info, class Tag> struct NodeBase {
    inline static int n;
    NodeBase *left, *right;
    Info info;
    Tag tag;
    NodeBase() : left{}, right{}, info{}, tag{} {}
};

struct Tag {
    Z x;
    Tag(Z x = 1) : x{x} {}
    void apply(const Tag &v) {
        x *= v.x;
    }
};

struct Info {
    Z x;
    Info(Z x = 0) : x{x} {}
    void apply(const Tag &v) {
        x *= v.x;
    }
};

Info operator+(const Info &lhs, const Info &rhs) {
    Info res;
    res.x = lhs.x + rhs.x;
    return res;
}

using Node = NodeBase<Info, Tag>;

void apply(Node *&p, const Tag &v) {
    p->info.apply(v);
    p->tag.apply(v);
}
void push(Node *&p) {
    if (p->left) {
        apply(p->left, p->tag);
    }
    if (p->right) {
        apply(p->right, p->tag);
    }
    p->tag = Tag();
}
void pull(Node *&p) {
    if (p->left && p->right) {
        p->info = p->left->info + p->right->info;
    } else if (p->left) {
        p->info = p->left->info;
    } else if (p->right) {
        p->info = p->right->info;
    } else {
        p->info = Info();
    }
}
void modify(Node *&p, int l, int r, int x, const Info &v) {
    if (!p) {
        p = new Node();
    }
    if (r - l == 1) {
        p->info = v;
        return;
    }
    push(p);
    int m = (l + r) / 2;
    if (x < m) {
        modify(p->left, l, m, x, v);
    } else {
        modify(p->right, m, r, x, v);
    }
    pull(p);
}
void modify(Node *&p, int x, const Info &v) {
    modify(p, 0, Node::n, x, v);
}
void rangeApply(Node *&p, int l, int r, int x, int y, const Tag &v) {
    if (l >= y || r <= x || !p) {
        return;
    }
    if (l >= x && r <= y) {
        apply(p, v);
        return;
    }
    push(p);
    int m = (l + r) / 2;
    rangeApply(p->left, l, m, x, y, v);
    rangeApply(p->right, m, r, x, y, v);
}
void rangeApply(Node *&p, int l, int r, const Tag &v) {
    rangeApply(p, 0, Node::n, l, r, v);
}
void merge(Node *&p, Node *q, int l, int r) {
    if (!q) {
        return;
    }
    if (!p) {
        p = q;
        return;
    }
    if (r - l == 1) {
        p->info = p->info + q->info;
        return;
    }
    push(p);
    push(q);
    int m = (l + r) / 2;
    merge(p->left, q->left, l, m);
    merge(p->right, q->right, m, r);
    pull(p);
}
void merge(Node *&p, Node *q) {
    merge(p, q, 0, Node::n);
}
Info query(Node *&p, int l, int r, int x) {
    if (!p) {
        return Info();
    }
    if (r - l == 1) {
        return p->info;
    }
    push(p);
    int m = (l + r) / 2;
    if (x < m) {
        return query(p->left, l, m, x);
    } else {
        return query(p->right, m, r, x);
    }
}
Info query(Node *&p, int x) {
    return query(p, 0, Node::n, x);
}
Info rangeQuery(Node *&p, int l, int r, int x, int y) {
    if (l >= y || r <= x || !p) {
        return Info();
    }
    if (l >= x && r <= y) {
        return p->info;
    }
    push(p);
    int m = (l + r) / 2;
    return rangeQuery(p->left, l, m, x, y) + rangeQuery(p->right, m, r, x, y);
}
Info rangeQuery(Node *&p, int l = 0, int r = Node::n) {
    return rangeQuery(p, 0, Node::n, l, r);
}

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n;
    std::cin >> n;

    std::vector<int> a(n);
    for (auto &a : a) {
        std::cin >> a;
    }

    auto v = a;
    std::sort(v.begin(), v.end());
    v.erase(std::unique(v.begin(), v.end()), v.end());

    Node::n = v.size();
    for (auto &a : a) {
        a = std::lower_bound(v.begin(), v.end(), a) - v.begin();
    }

    std::vector<int> lson(n), rson(n);
    for (int i = 0; i < n; i++) {
        std::cin >> lson[i] >> rson[i];
        lson[i]--;
        rson[i]--;
    }

    Z ans = 0;
    auto dfs = [&](auto &&self, int u) -> Node * {
        if (u == -1) {
            return nullptr;
        }
        Node *cur = new Node();
        Node *left = self(self, lson[u]);
        Node *right = self(self, rson[u]);
        if (lson[u] != -1 && rson[u] != -1 && a[rson[u]] <= a[u]) {
            rangeApply(left, 0, Node::n, {rangeQuery(right, a[lson[u]], Node::n).x + 1});
        }
        if (lson[u] != -1 && a[lson[u]] <= a[u]) {
            merge(cur, left);
        }
        if (rson[u] != -1 && a[rson[u]] <= a[u]) {
            merge(cur, right);
        }
        modify(cur, a[u], {query(cur, a[u]).x + 1});
        ans += rangeQuery(cur).x;
        return cur;
    };
    dfs(dfs, 0);

    std::cout << ans << "\n";

    return 0;
}

F

首先注意到,如果选择超过 个人的话,答案一定不优。

考虑二分答案,记当前二分的答案为 ,可以将式子转化为:

存在一个显然的建图,即每个点拆成 个时刻,第 个时刻向第 连一条价值为 ,代价为 的边,然后对于给定的 条边,其中第 条边由 的第 个时刻向 的第 个时刻连一条价值为 ,代价为 的边。建图后计算所有 的最长路即可。

特别需要注意的是,当 时,可能会存在正权环,因此需要使用 算法。可以证明在本题的数据范围下,当答案为 时一定不存在正权环,因此二分上界设为 即可。

此外,由于不保证所有的点都可以走到 ,可能会存在一个走不到 的正权环,因此需要提前建出反图,把走不到 的点排除。

但是朴素建图存在 个点, 条边,复杂度为 ,无法通过。

注意到在朴素建图中,实际上最多只有 个点是有意义的,即对于每条边, 的出发时刻和 的到达时刻,以及所有起点的 时刻。

因此可以将复杂度优化至

另外本题时限较为紧张,如果不重新建图,而是在每次二分的时候枚举所有的边重新计算边权的话,最坏情况下存在 次对 的非常数取模运算,在实现时需要注意。

#include <bits/stdc++.h>

using u32 = unsigned int;
using i64 = long long;
using u64 = unsigned long long;
using f64 = double;

constexpr f64 inf = 1E18;

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout << std::fixed << std::setprecision(7);

    int n, m, h, k, a, b;
    std::cin >> n >> m >> h >> k >> a >> b;

    std::vector<int> s(k);
    for (auto &s : s) {
        std::cin >> s;
        s--;
    }

    std::vector<std::array<int, 5>> e(m);
    for (auto &[u, v, s, t, w] : e) {
        std::cin >> u >> v >> s >> t >> w;
        u--;
        v--;
    }

    std::vector<std::vector<int>> g(n);
    for (auto [u, v, s, t, w] : e) {
        g[v].push_back(u);
    }

    std::vector<int> ok(n);
    auto dfs = [&](auto &&self, int u) -> void {
        if (ok[u]) {
            return;
        }
        ok[u] = 1;
        for (auto v : g[u]) {
            self(self, v);
        }
    };
    dfs(dfs, n - 1);

    std::vector<std::vector<int>> c(n);
    for (auto [u, v, s, t, w] : e) {
        if (!ok[u] || !ok[v]) {
            continue;
        }
        c[u].push_back(s);
        c[v].push_back((s + t) % h);
    }

    int tot = 0;
    for (int i = 0; i < n; i++) {
        if (!ok[i]) {
            continue;
        }
        c[i].push_back(0);
        std::sort(c[i].begin(), c[i].end());
        c[i].erase(std::unique(c[i].begin(), c[i].end()), c[i].end());
        tot += c[i].size();
    }

    int cur = 0;
    std::vector<int> id(n);
    std::vector<std::vector<std::tuple<int, f64, f64>>> adj(tot);
    for (int i = 0; i < n; i++) {
        id[i] = cur;
        for (int j = 0; j < c[i].size(); j++) {
            int k = (j + 1) % c[i].size();
            int x = (c[i][k] - c[i][j] + h) % h;
            adj[id[i] + j].emplace_back(id[i] + k, 0, 1.0 * a * x);
        }
        cur += c[i].size();
    }

    for (auto [u, v, s, t, w] : e) {
        if (!ok[u] || !ok[v]) {
            continue;
        }
        int i = std::lower_bound(c[u].begin(), c[u].end(), s) - c[u].begin();
        int j = std::lower_bound(c[v].begin(), c[v].end(), (s + t) % h) - c[v].begin();
        adj[id[u] + i].emplace_back(id[v] + j, w, 1.0 * b * t);
    }

    auto check = [&](f64 cur) {
        std::vector<f64> dis(tot, -inf);
        for (auto s : s) {
            dis[id[s]] = 0;
        }
        bool done = false;
        for (int i = 0; i < tot; i++) {
            done = true;
            for (int u = 0; u < tot; u++) {
                for (auto [v, num, den] : adj[u]) {
                    if (dis[v] < dis[u] + num - cur * den) {
                        dis[v] = dis[u] + num - cur * den;
                        done = false;
                    }
                }
            }
            if (done) {
                break;
            }
        }
        if (!done) {
            return true;
        }
        for (int i = 0; i < c[n - 1].size(); i++) {
            if (dis[id[n - 1] + i] >= 0) {
                return true;
            }
        }
        return false;
    };

    f64 lo = 0, hi = 1E9;
    for (int i = 0; i < 64; i++) {
        auto m = (lo + hi) / 2;
        if (check(m)) {
            lo = m;
        } else {
            hi = m;
        }
    }

    std::cout << lo << "\n";

    return 0;
}

G

对于修改操作,令交换的两个数为

很容易发现只有 这一段区间内可能会影响答案。

证明:

​ 对于 的部分, 一定包含 两个数,所以他们交换顺序并不会改变答案

​ 对于 的部分,由于 的答案只和 的位置有关,所以他们交换顺序并不会改变答案

所以对于一次修改,我们只需要重新算一下交换后 这一段的贡献即可

同理求出 必须包含的区间 ,然后考虑 的位置 我们令这个位置为 , 分类讨论一下

  • 如果,则此时的 一定不合法,++即可

  • 如果 则当前有贡献的区间为左端点在右端点在的所有区间

    考虑怎么算答案,容斥一下,我们令

    此时的答案

    然后考虑怎么计算

    为 长度为 时的答案

    很容易发现

    这是一个等差乘等差乘等比,类似于 做两次错位相减即可,然后预处理的幂次方和逆元可以O(1)求出答案

  • 如果 则当前有贡献的区间为左端点在右端点在的所有区间

注意单独算一下的情况

​ 复杂度为

#include <bits/stdc++.h>
#define int long long
using namespace std;
typedef long long ll;
int read() {
    int x = 0, f = 1;
    char c = getchar();
    while(c < '0' || c > '9')
        f = (c == '-') ? -1 : 1, c = getchar();
    while(c >= '0' && c <= '9')
        x = x * 10 + c - 48, c = getchar();
    return f * x;
}
const int N = 5e5 + 10;
const int mod = 998244353;
const int phi = 998244352;
const int K = 1 << 16;
int a[N];
int ksm(int a, int b) {
    int ans = 1;
    while(b) {
        if(b & 1)
            ans = ans * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return ans;
}
int vis[N], pw[3 * N], jc[3 * N];
int Inv[N * 3];
void init() {
    pw[0] = 1;
    for(int i = 0; i < 3 * N; i++)
        pw[i + 1] = 2 * pw[i] % mod;
    int x = 1;
    for(int i = 1; i <= K; i++)
        x = x * 2 % mod;
    jc[0] = 1;
    for(int i = 1; i < 3 * N; i++)
        jc[i] = jc[i - 1] * x % mod;
    for(int i = 0; i < 3 * N; i++)
        Inv[i] = ksm((pw[i] - 1 + mod) % mod, mod - 2);
}
int pow2(int x) {
    return jc[x >> 16] * pw[x & (K - 1)] % mod;
}
int calc(int x, int mex) {
    if(x <= 0)
        return 0;
    int inv = Inv[mex];
    int k = pow2((x + 1) * mex % phi);
    int p1 = k * (2 - x + mod) % mod;
    int p2 = pw[2 * mex] * (x - 2 + mod) % mod;
    int p3 = 2 * (k - pw[3 * mex] + mod) * inv % mod;
    int s2 = (p1 - p2 + p3 + mod) * inv % mod;
    p1 = k * x % mod;
    p2 = pw[mex] * x % mod;
    p3 = s2;
    int s = (p1 - p2 - p3 + 2 * mod) * inv % mod;
    return s;
}
int sum[N];
int L[N], R[N];
signed main() {
    int n = read();
    for(int i = 1; i <= n; i++)
        a[i] = read(), vis[a[i]] = i;
    init();
    int l = n, r = 0, mex = 1;
    int ans = n * (n + 1) % mod;
    ans = ans * ksm(2, n * (n + 1)) % mod;
    sum[0] = (calc(vis[1] - 1, 1) + calc(n - vis[1], 1)) % mod;
    while(mex < n) {
        l = min(l, vis[mex]);
        r = max(r, vis[mex]);
        L[mex] = l;
        R[mex] = r;
        sum[mex] = sum[mex - 1];
        while(vis[mex + 1] <= r && vis[mex + 1] >= l) {
            mex++;
            sum[mex] = sum[mex - 1];
            L[mex] = l;
            R[mex] = r;
        }
        if(mex == n)
            break;
        int p = vis[mex + 1];
        int L = 1, R = n;
        if(p < l)
            L = p + 1;
        else
            R = p - 1;
        int js = calc(R - L + 1, mex + 1);
        js = (js - calc(r - L, mex + 1) - calc(R - l, mex + 1) + 2 * mod) % mod;
        js = (js + calc(r - l - 1, mex + 1)) % mod;
        ans = (ans + (mex + 1) * js) % mod;
        sum[mex] = (sum[mex] + (mex + 1) * js) % mod;
        mex++;
    }
    ans = (ans + sum[0]) % mod;
    int q = read();
    while(q--) {
        int a = read(), b = read();
        int A = a, B = b;
        if(a == b) {
            cout << ans << "\n";
            continue;
        }
        if(a > b)
            swap(a, b);
        swap(vis[a], vis[b]);
        if(b == n)
            b--;
        int Ans = ans;
        int mex = a - 1, l = L[mex], r = R[mex];
        if(a == 1) {
            Ans = (Ans - sum[0] + mod) % mod;
            Ans = (Ans + (calc(vis[1] - 1, 1) + calc(n - vis[1], 1)) % mod) % mod;
            a++;
            mex = 1, l = vis[1], r = vis[1];
        }
        Ans = (Ans - (sum[b] - sum[mex - 1] + mod) % mod + mod) % mod;
        while(mex <= b) {
            l = min(l, vis[mex]);
            r = max(r, vis[mex]);
            while(vis[mex + 1] <= r && vis[mex + 1] >= l) {
                mex++;
                if(mex == b + 1)
                    break;
            }
            if(mex == b + 1)
                break;
            int p = vis[mex + 1];
            int L = 1, R = n;
            if(p < l)
                L = p + 1;
            else
                R = p - 1;
            int js = calc(R - L + 1, mex + 1);
            js = (js - calc(r - L, mex + 1) - calc(R - l, mex + 1) + 2 * mod) % mod;
            js = (js + calc(r - l - 1, mex + 1)) % mod;
            Ans = (Ans + (mex + 1) * js) % mod;
            mex++;
        }
        swap(vis[A], vis[B]);
        cout << Ans << "\n";
    }
}