题目地址

题目链接

题解

二分答案,那么大于答案的路径都需要有一条公共边,maxlen-val>=二分出来的x。val是边权。

考虑树剖,对每条大于答案的路径都+1(线段树里),枚举边,如果(线段树中的)值==大于答案的边数,那么对他们取max。

复杂度\(O((nlognlogn+m)logn)\)(可能不是特别准确因为没写树剖,不过是三个log的没错)

卡常?不,想想树剖的这两个log我们拿来干啥,给路径的链+1。有没有什么复杂度更低的方法?

考虑树上差分。

对每条大于答案的路径差分标记,dfs一遍统计答案,对被标记的次数等于 大于答案的路径条数 的边的边权取max,并对所有路径取max。

check判断max路径-max边是否大于二分出的x即可

复杂度\(O((n+m)logn)\)

#include <bits/stdc++.h>
using namespace std;

#define in(x) (x = read())
inline int read() {
    int x = 0, f = 1; char c = getchar();
    while(c < '0' || c > '9') (c == '-') && (f = -1), c = getchar();
    while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    return x * f;
}

#define N 300010
#define ll long long
int n = read(), m = read();
int head[N], cnt;
int d[N], val[N];
int dep[N], top[N], fa[N], siz[N];
struct edge {
    int to, nxt, v;
}e[N<<1];
struct Node {
    int x, y, lca;
    int len;
}q[N];

void ins(int u, int v, int w) {
    e[++cnt] = (edge) {v, head[u], w};
    head[u] = cnt;
}

void dfs1(int u) {
    siz[u] = 1;
    for(int i = head[u]; i; i = e[i].nxt) {
        if(e[i].to == fa[u]) continue;
        fa[e[i].to] = u;
        dep[e[i].to] = dep[u] + 1;
        val[e[i].to] = e[i].v;
        d[e[i].to] = d[u] + e[i].v;
        dfs1(e[i].to);
        siz[u] += siz[e[i].to];
    }
}

void dfs2(int u, int topf) {
    top[u] = topf;
    int k = 0;
    for(int i = head[u]; i; i = e[i].nxt) {
        if(e[i].to == fa[u]) continue;
        if(siz[e[i].to] > siz[k]) k = e[i].to;
    }
    if(!k) return;
    dfs2(k, topf);
    for(int i = head[u]; i; i = e[i].nxt) {
        if(e[i].to == fa[u] || e[i].to == k) continue;
        dfs2(e[i].to, e[i].to);
    }
}

int lca(int x, int y) {
    while(top[x] != top[y]) {
        if(dep[top[x]] < dep[top[y]]) swap(x, y);
        x = fa[top[x]];
    }
    if(dep[x] > dep[y]) swap(x, y);
    return x;
}

int res = 0, tot = 0, mx = 0;
void dfs(int u) {
    for(int i = head[u]; i; i = e[i].nxt) {
        if(e[i].to == fa[u]) continue;
        dfs(e[i].to);
        d[u] += d[e[i].to];
    }
    if(d[u] == tot) res = max(res, val[u]);
}

bool check(ll x) {
    memset(d, 0, sizeof(d));
    res = 0; tot = 0; mx = 0;
    for(int i = 1; i <= m; ++i) {
        if(q[i].len <= x) continue;
        d[q[i].x]++; d[q[i].y]++; d[q[i].lca] -= 2;
        ++tot;
        mx = max(q[i].len, mx);
    }
    dfs(1);
    if(mx - res > x) return 0;
    return 1;
}

int main(){
    for(int i = 1, u, v, w; i < n; ++i) {
        in(u), in(v), in(w);
        ins(u, v, w), ins(v, u, w);
    }
    dfs1(1); dfs2(1, 1);
    int Mx = 0;
    for(int i = 1; i <= m; ++i) {
        in(q[i].x); in(q[i].y);
        q[i].lca = lca(q[i].x, q[i].y);
        q[i].len = d[q[i].x] - d[q[i].lca] + d[q[i].y] - d[q[i].lca];
        Mx = max(Mx, q[i].len);
    }
    int l = 1, r = Mx, ans = Mx;
    while(l <= r) {
        int mid = (l + r) >> 1;
        if(check(mid)) ans = mid, r = mid - 1;
        else l = mid + 1;
    }
    printf("%d\n", ans);
}