提示

A. 考虑单次操作中期望的变化量。

B. 博弈基础。

C. 你可能在找:树的直径。此外就是计数逻辑与代码能力了。

D. 第一个数和第 个数有何区别?连乘无法下手能否转化一下?

E. 考虑合法区间最小能到什么程度。再考虑如何进一步维护。

F. dp 状态不难想,提速的关键在于将一些有共性的状态一并转移。

A 和的期望

在每一步中,每一个数被选的概率相同,因此每一步期望的增量也相同。

// 点击上方选项卡查看代码
#include <bits/stdc++.h>

constexpr int P = 998'244'353;
int mod_pow(int a, int b, int p = P) {
    int res = 1;
    for (; b; b /= 2, a = 1ll * a * a % p) {
        if (b & 1) {
            res = 1ll * res * a % p;
        }
    }
    return res;
}

int main() {
    std::cin.tie(nullptr)->sync_with_stdio(false);

    int n;
    std::cin >> n;

    int sum = 0;
    for (int i = 0; i < n; i++) {
        int x;
        std::cin >> x;
        sum = (sum + x) % P;
    }
    int once = 1ll * sum * mod_pow(n, P - 2, P) % P;
    sum = 0;
    for (int i = 0; i < n; i++) {
        sum = (sum + once) % P;
        std::cout << sum << " \n"[i + 1 == n];
    }

    return 0;
}
P = 998244353

n = int(input())
invn = pow(n, P - 2, P)
val = sum(map(int, input().split()))
x = invn * val % P
print(*[x * (k + 1) % P for k in range(n)])

B 树上博弈

(自己分数) 就是要在一步之后 (对方分数),即选择对于对方最劣的方案,因此需要求子树最小值,这可以拓扑转移,复杂度

// 点击上方选项卡查看代码
#include <bits/stdc++.h>

using int64 = long long;
constexpr int N = (int) 2e5 + 10;
constexpr int64 inf = 1e18;

int a[N];
int64 val[N];
std::vector<int> edge[N];

int64 dfs(int u) {
    if (edge[u].empty()) {
        val[u] = a[u];
        return a[u];
    }
    int64 res = -inf;
    val[u] = -inf;
    for (int v : edge[u]) {
        dfs(v);
        val[u] = std::max(val[u], val[v]);
        res = std::max(res, val[v]);
    }
    res = a[u] - res;
    val[u] = std::max(val[u], res);
    return res;
}

int64 solve() {
    int n;
    std::cin >> n;

    for (int i = 2; i <= n; i++) {
        std::cin >> a[i];
    }

    for (int i = 2; i <= n; i++) {
        int p;
        std::cin >> p;
        edge[p].push_back(i);
    }

    int64 ans = - dfs(1);
    for (int i = 1; i <= n; i++) {
        edge[i].clear();
    }

    return ans;
}

int main() {
    std::cin.tie(nullptr)->sync_with_stdio(false);

    int T;
    std::cin >> T;

    while (T--) {
        std::cout << solve() << '\n';
    }

    return 0;
}
import sys
from types import GeneratorType


# ref: https://ac.nowcoder.com/acm/contest/view-submission?submissionId=60451158
# in case of `recursion limit exceed'
def bootstrap(f, stack=[]):
    def wrappedfunc(*args, **kwargs):
        if stack:
            return f(*args, **kwargs)
        else:
            to = f(*args, **kwargs)
            while True:
                if type(to) is GeneratorType:
                    stack.append(to)
                    to = next(to)
                else:
                    stack.pop()
                    if not stack:
                        break
                    to = stack[-1].send(to)
            return to
    return wrappedfunc


input = sys.stdin.readline

T = int(input())
for _ in range(T):
    n = int(input())
    a = [0] + list(map(int, input().split()))

    edge = [[] for k in range(n)]
    p = [0] + list(map(lambda x: int(x) - 1, input().split()))
    for k in range(1, n):
        edge[p[k]].append(k)

    inf = 10 ** 9
    val = [0] * n

    @bootstrap
    def dfs(u: int) -> int :
        if len(edge[u]) == 0:
            val[u] = a[u]
            yield val[u]
        val[u] = -inf
        for v in edge[u]:
            yield dfs(v)
            val[u] = max(val[u], val[v])
        tmp = a[u] - val[u]
        val[u] = max(val[u], tmp)
        yield tmp

    print(-dfs(0))

C 树的联结

考虑对于两种点对分别求解:

  • 同一棵树内的点对:单轮 dfs 统计。
  • 分别在两棵树上的点对:考虑求出每个点所在树上的最远点的距离,而由于树上任意一点的最远点集一定含有树上任一直径中的至少一点(尽管直径、最远点可能不唯一,但最远点的距离唯一),因此求得树的一条直径就能统计答案。
// 点击上方选项卡查看代码
#include <bits/stdc++.h>

int main() {
    std::cin.tie(nullptr)->sync_with_stdio(false);

    int n, m;
    std::cin >> n >> m;

    std::vector<std::vector<int>> edge(n + m);

    for (int i = 1; i <= n + m - 2; i++) {
        int u, v;
        std::cin >> u >> v;
        u--; v--;
        edge[u].push_back(v);
        edge[v].push_back(u);
    }

    constexpr int inf = (~0U) >> 2;

    auto get = [&] (int s) -> std::pair<int, std::vector<int>> {
        std::vector<int> dis(n + m, inf);
        std::queue<int> que;
        que.push(s);
        dis[s] = 0;
        int u;
        while (not que.empty()) {
            u = que.front();
            que.pop();
            for (int v : edge[u]) {
                if (dis[v] == inf) {
                    dis[v] = dis[u] + 1;
                    que.push(v);
                }
            }
        }
        return std::make_pair(u, dis);
    };

    int s1 = get(0).first;
    auto [t1, d1] = get(s1);
    auto [_1, dd1] = get(t1);

    int s2 = get(n).first;
    auto [t2, d2] = get(s2);
    auto [_2, dd2] = get(t2);

    long long sum1 = 0, sum2 = 0;
    for (int i = 0; i < n; i++) {
        d1[i] = std::max(d1[i], dd1[i]);
        sum1 += d1[i];
    }
    for (int i = n; i < n + m; i++) {
        d2[i] = std::max(d2[i], dd2[i]);
        sum2 += d2[i];
    }

    long long ans = sum1 * m + sum2 * n + 1LL * n * m;

    std::function<std::pair<long long, int>(int, int)> dfs = [&] (int u, int p) -> std::pair<long long, int> {
        std::pair<long long, int> res(0LL, 0);
        auto &[val, tot] = res;
        std::vector<int> ch;
        for (int v : edge[u]) {
            if (v == p) {
                continue;
            }
            ch.push_back(v);
            auto [val1, tot1] = dfs(v, u);
            ans += val1 + tot1; // u <--> son(vi)
            ans += val * tot1 + val1 * tot + 2LL * tot * tot1; // son(v1..i-1) <--> son(vi)
            val += val1;
            tot += tot1;
        }
        val += tot;
        tot += 1;
        return res;
    };

    dfs(0, -1);
    dfs(n, -1);

    std::cout << ans << '\n';

    return 0;
}
import sys
from types import GeneratorType


# ref: https://ac.nowcoder.com/acm/contest/view-submission?submissionId=60451158
# in case of `recursion limit exceed'
def bootstrap(f, stack=[]):
    def wrappedfunc(*args, **kwargs):
        if stack:
            return f(*args, **kwargs)
        else:
            to = f(*args, **kwargs)
            while True:
                if type(to) is GeneratorType:
                    stack.append(to)
                    to = next(to)
                else:
                    stack.pop()
                    if not stack:
                        break
                    to = stack[-1].send(to)
            return to
    return wrappedfunc

input = sys.stdin.readline

n, m = map(int, input().split())
edge = [[] for k in range(n + m)]

for k in range(n + m - 2):
    u, v = map(lambda x: int(x) - 1, input().split())
    edge[u].append(v)
    edge[v].append(u)

inf = 10 ** 9

def get(s: int):
    dis = [inf] * (n + m)
    que = [0] * (n + m)
    head, tail = 0, 0
    
    dis[s] = 0
    que[tail] = s
    tail += 1

    while head < tail:
        u = que[head]
        head += 1
        for v in edge[u]:
            if dis[v] == inf:
                dis[v] = dis[u] + 1
                que[tail] = v
                tail += 1

    return u, dis

""" first tree """
s1, _ = get(0)
t1, d1 = get(s1)
_, dd1 = get(t1)

""" second tree """
s2, _ = get(n)
t2, d2 = get(s2)
_, dd2 = get(t2)

tr1_sum, tr2_sum = 0, 0
for k in range(n):
    tr1_sum += max(d1[k], dd1[k])
for k in range(n, n + m):
    tr2_sum += max(d2[k], dd2[k])

ans = tr1_sum * m + tr2_sum * n + n * m

@bootstrap
def dfs(u: int, p: int):
    val, tot = 0, 0
    for v in edge[u]:
        if v == p: continue
        val1, tot1 = yield dfs(v, u)
        global ans
        ans += val1 + tot1 + val * tot1 + val1 * tot + 2 * tot * tot1
        val += val1
        tot += tot1
    val += tot
    tot += 1
    yield val, tot

dfs(0, -1)
dfs(n, -1)

print(ans)

D 权值和 plus

不难发现两个数组的合法性是独立的,所以可以使用相同方法分别求解。

又由于每个位置上的数地位一致,于是考虑求出最后一个位置上取每个值的方案数,由于数组元素乘积一定,只需求出前 个数乘积为 的方案数。

时容斥求解即可。当 时,组成该乘积的所有数对 取余均不为 ,考虑对乘积取离散对数,将求积同余式化为求和同余式,这可以用多项式快速幂求解,复杂度

// 点击上方选项卡查看代码
#include <bits/stdc++.h>

using poly = std::vector<int>;

constexpr int P = 998'244'353;

int mod_pow(int a, int b, int p = P) {
    int res = 1;
    for (; b; b /= 2, a = 1ll * a * a % p) {
        if (b & 1) {
            res = 1ll * res * a % p;
        }
    }
    return res;
}

poly dft(poly f, int len) {
    f.resize(len);
    for (int i = 0, j = 0; i < len; i++) {
        if (i < j) std::swap(f[i], f[j]);
        for (int k = len >> 1; (j ^= k) < k; k >>= 1) {}
    }
    for (int h = 1; h < len; h *= 2) {
        constexpr int g = 3;
        int wn = mod_pow(g, (P - 1) / (2 * h));
        for (int L = 0; L < len; L += 2 * h) {
            for (int k = L, w = 1; k < L + h; k++, w = 1ll * w * wn % P) {
                int x = f[k], y = 1ll * f[k + h] * w % P;
                f[k] = (x + y) % P;
                f[k + h] = (x + P - y) % P;
            }
        }
    }
    return f;
}

poly idft(poly f, int len) {
    f = dft(f, len);
    std::reverse(f.begin() + 1, f.end());
    int invlen = P - (P - 1) / len;
    for (int i = 0; i < len; i++) {
        f[i] = 1ll * f[i] * invlen % P;
    }
    return f;
}

poly& operator*= (poly& f, poly g) {
    int n = f.size() - 1, m = g.size() - 1;
    int len = 1;
    while (len <= n + m) len *= 2;
    f = dft(f, len);
    g = dft(g, len);
    for (int i = 0; i < len; i++) {
        f[i] = 1ll * f[i] * g[i] % P;
    }
    f = idft(f, len);
    f.resize(n + m + 1);
    return f;
}

poly operator* (poly f, const poly &g) {
    f *= g;
    return f;
}

bool is_prime(int p) {
    if (p <= 1) return false;
    for (int x = 2; x * x <= p; x++) {
        if (p % x == 0) {
            return false;
        }
    }
    return true;
}

int primitive_root(int m) {
    if (m == 2) {
        return 1;
    }

    std::vector<int> factor;
    int t = m - 1;
    for (int x = 2; x * x <= t; x++) {
        if (t % x == 0) {
            factor.push_back(x);
            while (t % x == 0) t /= x;
        }
    }
    if (t > 1) {
        factor.push_back(t);
    }

    int g = 1;
    bool ok;
    do {
        g++;
        ok = true;
        for (int p : factor) {
            if (mod_pow(g, (m - 1) / p, m) == 1) {
                ok = false;
                break;
            }
        }
    } while (not ok);

    return g;
}

int main() {
    std::cin.tie(nullptr)->sync_with_stdio(false);

    constexpr int N = 1e5;
    int p, q, n, m, k;
    std::cin >> p >> q >> n >> m >> k;

    std::vector<int> S(q);
    for (int &x : S) {
        std::cin >> x;
    }
    std::sort(S.begin(), S.end());

    std::vector<int> cnt(p);
    std::vector<std::vector<int>> list(p);
    for (int i = 0; i < q; i++) {
        cnt[S[i] % p] ++;
        list[S[i] % p].push_back(i);
    }

    int g = primitive_root(p);
    std::vector<int> x2e(p), e2x(p);
    for (int x = 1, e = 0; e < p - 1; x = 1ll * x * g % p, e++) {
        assert(x2e[x] == 0);
        e2x[e] = x;
        x2e[x] = e;
    }

    poly a(p - 1, 0);
    for (int i = 1; i < p; i++) {
        a[x2e[i]] = cnt[i];
    }

    auto calc = [&] (poly f, const poly &b) {
        f *= b;
        for (int i = p - 1; i < (int) f.size(); i++) {
            f[i % (p - 1)] = (f[i % (p - 1)] + f[i]) % P;
        }
        f.resize(p - 1);
        return f;
    };

    poly f = { 1 };
    for (int b = k - 1; b; b /= 2, a = calc(a, a)) {
        if (b & 1) {
            f = calc(f, a);
        }
    }

    auto get_res = [&] (int val) -> std::vector<int> {
        std::vector<int> res(q);
        if (val == 0) {
            int tot = mod_pow(q - cnt[0], k - 1, P);
            for (int i : list[0]) {
                res[i] = (res[i] + tot) % P;
            }
            tot = (mod_pow(q, k - 1, P) - tot + P) % P;
            for (int i = 0; i < q; i++) {
                res[i] = (res[i] + tot) % P;
            }
        } else {
            int target = x2e[val] + p - 1;
            for (int e = 0; e < p - 1; e++) {
                int mul = e2x[(target - e) % (p - 1)];
                for (int i : list[mul]) {
                    res[i] += f[e];
                }
            }
        }
        return res;
    };

    auto resn = get_res(n);
    auto resm = get_res(m);
    
    int psn = 0, psm = 0;
    int ans = 0;
    for (int i = q - 1; i >= 0; i--) {
        int res = (1ll * resn[i] * psm + 1ll * resm[i] * psn + 1ll * resn[i] * resm[i]) % P;
        ans = (ans + 1ll * S[i] * res) % P;
        psn = (psn + resn[i]) % P;
        psm = (psm + resm[i]) % P;
    }
    ans = 1ll * ans * k % P;
    std::cout << ans << '\n';

    return 0;
}
import sys

input = sys.stdin.readline

P = 998244353
G = 3

def dft(f, n):
    k = 0
    for i in range(n):
        if i < k:
            f[i], f[k] = f[k], f[i]
        j = n
        while True:
            j = j >> 1
            k ^= j
            if k >= j:
                break
    h = 1
    while h < n:
        wn = pow(G, (P - 1) // (2 * h), P)
        for L in range(0, n, 2 * h):
            w = 1
            for k in range(L, L + h):
                x, y = f[k], f[k + h] * w % P
                w = w * wn % P
                f[k], f[k + h] = (x + y) % P, (x - y + P) % P
        h *= 2

def idft(f, n):
    dft(f, n)
    for k in range(1, n // 2):
        f[k], f[n - k] = f[n - k], f[k]
    invn = P - (P - 1) // n
    for k in range(n):
        f[k] = f[k] * invn % P

def mul(f, g):
    g = g.copy()

    n = 1
    while n <= len(f) + len(g) - 2:
        n *= 2
    f += [0] * (n - len(f))
    g += [0] * (n - len(g))

    dft(f, n)
    dft(g, n)
    for k in range(n):
        f[k] = f[k] * g[k] % P
    idft(f, n)
    f = f[0 : n]

def is_prime(x):
    for p in range(2, x):
        if p * p > x:
            break
        if x % p == 0:
            return False
    return x >= 2

def primitive_root(m):
    assert is_prime(m)
    factor = []
    t = m - 1
    for p in range(2, t):
        if p * p > t:
            break
        if t % p == 0:
            factor.append(p)
            while t % p == 0:
                t //= p
    if t > 1:
        factor.append(t)

    g = 1
    while True:
        for p in factor:
            if pow(g, (m - 1) // p, m) == 1:
                break
        else:
            break
        g += 1

    return g

p, q, n, m, k = map(int, input().split())
S = list(sorted(map(int, input().split())))
vec = [[] for k in range(p)]
for i in range(q):
    vec[S[i] % p].append(i)

g = primitive_root(p)
x2e = [0] * p
e2x = [0] * p

x = 1
for e in range(p - 1):
    assert x2e[x] == 0
    e2x[e] = x
    x2e[x] = e
    x = x * g % p

a = [0] * (p - 1)
for i in range(1, p):
    a[x2e[i]] = len(vec[i])

def calc(f, g):
    mul(f, g)
    for i in range(p - 1, len(f)):
        f[i % (p - 1)] = (f[i % (p - 1)] + f[i]) % P
    while len(f) > p - 1:
        f.pop()

f = [1]
b = k - 1
while b > 0:
    if (b & 1) == 1:
        calc(f, a)
    b //= 2
    calc(a, a)

def get_res(val):
    res = [0] * q
    if val == 0:
        tot = pow(q - len(vec[0]), k - 1, P)
        for i in vec[0]:
            res[i] = tot
        tot = (pow(q, k - 1, P) - tot + P) % P
        for i in range(q):
            res[i] = (res[i] + tot) % P
    else:
        target = x2e[val] + p - 1
        for e in range(p - 1):
            mul = e2x[(target - e) % (p - 1)]
            for i in vec[mul]:
                res[i] = f[e]
    return res

resn = get_res(n)
resm = get_res(m)
psn, psm = 0, 0
ans = 0

for i in range(q)[::-1]:
    res = (resn[i] * psm + resm[i] * psn + resn[i] * resm[i]) % P
    ans = (ans + S[i] * res) % P
    psn = (psn + resn[i]) % P
    psm = (psm + resm[i]) % P

ans = ans * k % P
print(ans)

E 寻找中位数 plus

考虑将 的元素赋权值 (排序后在中位数左侧),将 的元素赋权值 (排序后在中位数右侧),而 的元素可以为 ,那么一个区间合法其实就是,首先要有一个 ,然后把这个 拿出来,剩下的数的权值和为

手玩一些例子可以得出结论:若 合法,那么必然存在包含于 的合法区间 满足,除了 两个位置的元素可能为 外,其余元素均不等于

事实上,对于合法区间 ,若存在 使得 ,则区间 与区间 必有至少一个满足权值和非负。

由于两种情况可以对称地处理,不妨假设区间 权值和非负,那么必然存在 满足区间 的权值和为 ,令 就得到含于合法区间 的一个合法子区间

显然只要 中存在等于 的元素,就可以不断重复上述操作,最终得到结论中所描述的合法子区间。

这启示我们维护相邻的 之间的信息,于是考虑维护如下区间信息:

  • 合法子区间的存在性;
  • (如果包含 )第一个 左侧元素的权值和与权值后缀和的最大值;
  • (如果包含 )最后一个 右侧元素的权值和与权值前缀和的最大值。

相邻两个区间的信息能在 时间合并,可用线段树维护,只涉及单点修改和区间查询操作,复杂度

// 点击上方选项卡查看代码
#include <bits/stdc++.h>

constexpr int inf = (~0u) >> 2;
constexpr int N = (int) 2e5 + 10;

struct node {
    int lsum = 0, lmax = -inf;
    int rsum = 0, rmax = -inf;
    bool has_k = 0, valid = 0;
    node() {}
    node(int x) {
        lmax = rmax = (x >= 0 ? +1 : -1);
        lsum = rsum = (x == 0 ? 0 : lmax);
        has_k = (x == 0);
        valid = 0;
    }
} tr[N * 4], empty_node;

#define lc (u * 2 + 1)
#define rc (u * 2 + 2)
#define mid (l + (r - l) / 2)

node operator+ (const node &L, const node &R) {
    node M;

    M.has_k = L.has_k or R.has_k;
    M.valid = L.valid or R.valid;

    if (L.has_k) {
        M.lsum = L.lsum;
        M.lmax = L.lmax;
        M.valid |= (L.rsum + R.lmax >= 0);
    } else {
        M.lsum = L.lsum + R.lsum;
        M.lmax = std::max(L.lmax, L.lsum + R.lmax);
    }

    if (R.has_k) {
        M.rsum = R.rsum;
        M.rmax = R.rmax;
        M.valid |= (R.lsum + L.rmax >= 0);
    } else {
        M.rsum = R.rsum + L.rsum;
        M.rmax = std::max(R.rmax, R.rsum + L.rmax);
    }

    return M;
}

void build(int u, int l, int r, const std::vector<int> &a) {
    if (r - l == 1) {
        tr[u] = node(a[l]);
    } else {
        build(lc, l, mid, a);
        build(rc, mid, r, a);
        tr[u] = tr[lc] + tr[rc];
    }
}

void modify(int u, int l, int r, int pos, int val) {
    if (r - l == 1) {
        assert(l == pos);
        tr[u] = node(val);
    } else {
        if (pos < mid) {
            modify(lc, l, mid, pos, val);
        } else {
            modify(rc, mid, r, pos, val);
        }
        tr[u] = tr[lc] + tr[rc];
    }
}

node query(int u, int l, int r, int lo, int hi) {
    if (lo <= l and r <= hi) {
        return tr[u];
    } else if (hi <= l or r <= lo) {
        return empty_node;
    } else {
        return query(lc, l, mid, lo, hi) + query(rc, mid, r, lo, hi);
    }
}

int main() {
    std::cin.tie(nullptr)->sync_with_stdio(false);

    int n, k, q;
    std::cin >> n >> k >> q;

    std::vector<int> a(n);
    for (int i = 0; i < n; i++) {
        std::cin >> a[i];
        a[i] -= k;
    }

    build(0, 0, n, a);

    while (q--) {
        int op, l, r;
        std::cin >> op >> l >> r;
        l--;
        if (op == 1) {
            modify(0, 0, n, l, r - k);
        } else {
            std::cout << (query(0, 0, n, l, r).valid ? "YES\n" : "NO\n");
        }
    }

    return 0;
}
import sys

def input():
    return sys.stdin.readline().strip()

inf = 10 ** 9

class Node:
    def __init__(self, x: int):
        if x >= 0:
            self.lmax = +1
        else:
            self.lmax = -1
        if x == 0:
            self.lsum = 0
        else:
            self.lsum = self.lmax
        self.rmax, self.rsum = self.lmax, self.lsum
        self.has_k = (x == 0)
        self.valid = False

    def __add__(L, R):
        M = Node(0)
        M.lmax, M.rmax = -inf, -inf
        M.has_k = L.has_k or R.has_k
        M.valid = L.valid or R.valid

        if L.has_k:
            M.lsum, M.lmax = L.lsum, L.lmax
            M.valid = M.valid or (L.rsum + R.lmax >= 0)
        else:
            M.lsum = L.lsum + R.lsum
            M.lmax = max(L.lmax, L.lsum + R.lmax)

        if R.has_k:
            M.rsum, M.rmax = R.rsum, R.rmax
            M.valid = M.valid or (R.lsum + L.rmax >= 0)
        else:
            M.rsum = R.rsum + L.rsum
            M.rmax = max(R.rmax, R.rsum + L.rmax)
        
        return M

n, k, q = map(int, input().split())
a = list(map(lambda x: int(x) - k, input().split()))

empty_node = Node(-1)
empty_node.lsum, empty_node.rsum = 0, 0
empty_node.lmax, empty_node.rmax = -inf, -inf
empty_node.has_k, empty_node.valid = False, False

tr = [empty_node] * (4 * n)

def build(u, l, r):
    if r - l == 1:
        tr[u] = Node(a[l])
    else:
        mid = l + (r - l) // 2
        build(u * 2 + 1, l, mid)
        build(u * 2 + 2, mid, r)
        tr[u] = tr[u * 2 + 1] + tr[u * 2 + 2]

def modify(u, l, r, pos, val):
    if r - l == 1:
        tr[u] = Node(val)
    else:
        mid = l + (r - l) // 2
        if pos < mid:
            modify(u * 2 + 1, l, mid, pos, val)
        else:
            modify(u * 2 + 2, mid, r, pos, val)
        tr[u] = tr[u * 2 + 1] + tr[u * 2 + 2]

def query(u, l, r, lo, hi):
    if lo <= l and r <= hi:
        return tr[u]
    elif hi <= l or r <= lo:
        return empty_node
    else:
        mid = l + (r - l) // 2
        return query(u * 2 + 1, l, mid, lo, hi) + query(u * 2 + 2, mid, r, lo, hi)

build(0, 0, n)

for _ in range(q):
    op, op1, op2 = map(int, input().split())
    if op == 1:
        i, x = op1 - 1, op2 - k
        modify(0, 0, n, i, x)
    else:
        l, r = op1 - 1, op2
        if query(0, 0, n, l, r).valid:
            print("YES")
        else:
            print("NO")

F 十六度空间

首先预处理组合数和任二点间可以经过其他点的最短路与方案数,后者复杂度为

再容斥 预处理任二点间不能经过其它点的最短路与方案数,每个起点的复杂度为 ,总复杂度

记所给 个点构成的集合为 ,并定义 中所有曾到达过的点构成的集合为 且最后一个点是 的最短路长度与最短路方案数。

对于 个指定的起点(如 ),将其对应初始状态的值(如 )进行初始化。

在从 转移时(此时 ),从 出发到 的过程中不能经过 中的点(否则不符合状态定义),但可以经过 中的点(所以转移之前必须做处理)。

考虑对已达点集合为 的所有 个状态统一处理并转移:

  • 先处理经过 中的点的方案数,这可以用 解决,复杂度
  • 再枚举从 的所有边进行转移,从某个状态 出发,可能的起点是 中的点,可能的终点是不在 中的点,转移数为

故总的转移复杂度为

本题还有 的做法,由于常数很小,跑得也飞快。

// 点击上方选项卡查看代码
#include <bits/stdc++.h>

constexpr int P = 998244353;
constexpr int inf = 1e9;
constexpr int M = 16;
constexpr int N = 2 * M * (int) 1e5 + 10;

int fact[N], ifact[N], inv[N];

void init() {
    fact[0] = ifact[0] = 1;
    fact[1] = ifact[1] = inv[1] = 1;
    for (int i = 2; i < N; i++) {
        fact[i] = 1ll * fact[i - 1] * i % P;
        inv[i] = P - 1ll * P / i * inv[P % i] % P;
        ifact[i] = 1ll * ifact[i - 1] * inv[i] % P;
    }
}

using info = std::pair<int, int>;
info operator + (const info &a, const info &b) {
    if (a.first != b.first) {
        return std::min(a, b);
    } else {
        return info(a.first, (a.second + b.second) % P);
    }
}

info &operator += (info &a, const info &b) {
    return a = a + b;
}

int main() {
    std::cin.tie(nullptr)->sync_with_stdio(false);
    init();

    int n, m, k;
    std::cin >> n >> m >> k;

    using point = std::vector<int>;
    std::vector<point> v(m, point(n));
    for (auto &p : v) {
        for (auto &x : p) {
            std::cin >> x;
        }
    }

    std::vector edge(m, std::vector<info>(m, info(0, 1)));
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < m; j++) {
            auto &[sum, res] = edge[i][j];
            for (int d = 0; d < n; d++) {
                int x = std::abs(v[i][d] - v[j][d]);
                res = 1ll * res * ifact[x] % P;
                sum += x;
            }
            res = 1ll * res * fact[sum] % P;
        }
    }

    auto calc = [&] (int from, int s) -> std::vector<info> {
        int left = s;
        std::vector<info> dp(m, { inf, 0 });
        while (left != 0) {
            int cur = -1;
            for (int i = 0; i < m; i++) {
                if ((((left >> i) & 1) and (cur == -1 or edge[from][i] < edge[from][cur]))) {
                    cur = i;
                }
            }
            left -= 1 << cur;
            dp[cur] += edge[from][cur];
            for (int i = 0; i < m; i++) {
                if (((left >> i) & 1)) {
                    dp[i] += info(dp[cur].first + edge[cur][i].first, 1ll * (P - dp[cur].second) * edge[cur][i].second % P);
                }
            }
        }
        return dp;
    };

    std::vector cg(m, std::vector<info>(m));
    for (int i = 0; i < m; i++) {
        cg[i] = calc(i, (1 << m) - 1 - (1 << i));
    }

    std::array<info, M> init;
    init.fill(info(inf, 0));
    std::vector<std::array<info, M>> dp(1 << m, init);
    for (int i = 0; i < k; i++) {
        int q;
        std::cin >> q;
        q--;
        dp[1 << q][q] = { 0, 1 };
    }

    for (int st = 1; st < (1 << m); st ++) {
        int left = st;
        while (left > 0) {
            int mi = -1;
            for (int i = 0; i < m; i++) {
                if (((left >> i) & 1) and (mi == -1 or dp[st][i] < dp[st][mi])) {
                    mi = i;
                }
            }
            left -= 1 << mi;
            for (int to = 0; to < m; to++) {
                if (not ((st >> to) & 1)) continue;
                int t = st | (1 << to);
                dp[t][to] += info(dp[st][mi].first + cg[mi][to].first, 1ll * dp[st][mi].second * cg[mi][to].second % P);
            }
        }

        for (int i = 0; i < m; i++) {
            if (!(st >> i & 1))continue;
            for (int to = 0; to < m; to++) {
                if ((st >> to) & 1) continue;
                int t = st | (1 << to);
                dp[t][to] += info(dp[st][i].first + cg[i][to].first, 1ll * dp[st][i].second * cg[i][to].second % P);
            }
        }
    }

    info res = { inf, 0 };
    for (int i = 0; i < m; i++) {
        res += dp[(1 << m) - 1][i];
    }
    std::cout << res.first << '\n';
    std::cout << res.second << '\n';

    return 0;
}
import sys

def input():
    return sys.stdin.readline().strip()

P = 998244353
inf = 10 ** 9

def init(n):
    global fact, ifact
    fact = [1] * (n + 1)
    for i in range(2, n + 1):
        fact[i] = fact[i - 1] * i % P
    ifact = fact.copy()
    ifact[n] = pow(fact[n], P - 1 - 1, P)
    for i in range(1, n + 1)[::-1]:
        ifact[i - 1] = ifact[i] * i % P

def add_info(a, b):
    if a[0] != b[0]:
        return min(a, b)
    else:
        return (a[0], (a[1] + b[1]) % P)

n, m, k = map(int, input().split())
v = [list(map(int, input().split())) for i in range(m)]
init(n * 2 * 10 ** 5)

edge = [[(0, 1)] * m for i in range(m)]
for i in range(m):
    for j in range(m):
        tot, res = 0, 1
        for d in range(n):
            x = abs(v[i][d] - v[j][d])
            tot += x
            res = res * ifact[x] % P
        res = res * fact[tot] % P
        edge[i][j] = (tot, res)

def calc(s, state):
    left = state
    dp = [(inf, 0)] * m
    while left > 0:
        t = -1
        for i in range(m):
            if ((left >> i) & 1) == 1 and (t == -1 or edge[s][i] < edge[s][t]):
                t = i

        left -= 1 << t
        dp[t] = add_info(dp[t], edge[s][t])
        for i in range(m):
            if ((left >> i) & 1) == 1:
                tmp = (dp[t][0] + edge[t][i][0], (P - dp[t][1]) * edge[t][i][1] % P)
                dp[i] = add_info(dp[i], tmp)
    return dp

cg = [calc(i, (1 << m) - 1 - (1 << i)) for i in range(m)]

dp = [[(inf, 0)] * m for st in range(1 << m)]
q = map(lambda x: int(x) - 1, input().split())
for i in q:
    dp[1 << i][i] = (0, 1)

for st in range(1, 1 << m):
    left = st
    while left > 0:
        s = -1
        for i in range(m):
            if (((left >> i) & 1) == 1 and (s == -1 or dp[st][i] < dp[st][s])):
                s = i

        left -= 1 << s
        for t in range(m):
            if ((st >> t) & 1) == 1:
                st1 = st | (1 << t)
                tmp = (dp[st][s][0] + cg[s][t][0], dp[st][s][1] * cg[s][t][1] % P)
                dp[st1][t] = add_info(dp[st1][t], tmp)

    for s in range(m):
        if ((st >> s) & 1) == 1:
            for t in range(m):
                if ((st >> t) & 1) != 1:
                    st1 = st | (1 << t)
                    tmp = (dp[st][s][0] + cg[s][t][0], dp[st][s][1] * cg[s][t][1] % P)
                    dp[st1][t] = add_info(dp[st1][t], tmp)

res = (inf, 0)
for i in range(m):
    res = add_info(res, dp[(1 << m) - 1][i])

print(*res, sep='\n')