ACM模版

描述

题解

这个题做法五花八门的,有的人用 +线 过的,有的大佬是用 过的,反正方法太多了,我也很懵逼……而我,用的是 LCA + 主席树(可持久化线段树)搞的,这个解法真是个野路子,虽然 AC 了,但是花了我最后的两个多小时调试……

代码

#include <algorithm>
#include <cstdio>
#include <iostream>
#include <vector>

using namespace std;

// 64-bit integer used for the value sums.
typedef long long ll;

// Capacity of the per-vertex arrays (vertices are numbered from 1).
const int MAXN = 1e5 + 7;

// A[1..n] is read as the per-vertex values; main() then sorts it and compacts
// it in place so that A[1..m] holds the distinct values in ascending order.
int A[MAXN];
// B[i] keeps the original, unsorted value of vertex i.
int B[MAXN];
// pre[u]: parent of u as recorded by dfs() (the root is given parent 0).
int pre[MAXN];
// id[u]: visit number handed out by dfs(); used as the index into T.
int id[MAXN];
// d[u]: depth of u, computed as d[parent] + 1 with d[0] = 0.
int d[MAXN];
// ls[o] / rs[o]: indices of the left / right child of segment-tree node o.
int ls[MAXN * 20];
int rs[MAXN * 20];
// T[k]: root node of the tree version stored under index k (T[0] is the empty
// tree built by builde()).
int T[MAXN * 20];
// anc[i][0] is pre[i]; anc[i][j] is filled as anc[anc[i][j - 1]][j - 1].
int anc[MAXN][25];
// cnt[o]: sum of the A[] values inserted below node o.
ll cnt[MAXN * 20];
// vi[u]: neighbours of vertex u (each input edge is stored in both directions).
vector<int> vi[MAXN];

// ql, qr: index range (into A) read by query().
// tol: next unused segment-tree node slot.
// tot: next visit number to be assigned by dfs().
int ql, qr, tol, tot;

void init(int n)
{
    tot = 1;
    tol = 0;
    for (int i = 1; i <= n; i++)
    {
        vi[i].clear();
    }
}

// Binary search for the value `pos` in the ascending, 1-indexed range a[1..R].
// Returns an index holding `pos`, or 0 when the value is absent.
int low(int *a, int pos, int R)
{
    int lo = 1;
    int hi = R;

    while (lo <= hi)
    {
        const int mid = (lo + hi) >> 1;
        const int cur = a[mid];

        if (cur < pos)
        {
            lo = mid + 1;
            continue;
        }
        if (cur > pos)
        {
            hi = mid - 1;
            continue;
        }
        return mid;
    }

    // 0 is never a valid position because the range starts at index 1.
    return 0;
}

// Allocates an all-zero segment tree over the index range [l, r].
// The index of the node created for this range is written to `o`; child
// indices are stored in ls[] / rs[] and every node's cnt[] is set to 0.
void builde(int l, int r, int &o)
{
    const int node = tol++;
    o = node;
    cnt[node] = 0;

    if (l != r)
    {
        const int mid = (l + r) >> 1;
        builde(l, mid, ls[node]);
        builde(mid + 1, r, rs[node]);
    }
}

// Creates a new tree version from the version rooted at `last` by inserting
// the value A[pos] at index `pos`.
// A fresh node is allocated for every range on the path to `pos`; each one
// starts as a copy of its counterpart in `last` (sharing the untouched child)
// and carries the old sum plus A[pos]. The new node's index is written to `o`.
void update(int last, int l, int r, int pos, int &o)
{
    const int node = tol++;
    o = node;

    ls[node] = ls[last];
    rs[node] = rs[last];
    cnt[node] = cnt[last] + A[pos];

    if (l == r)
    {
        return ;
    }

    const int mid = (l + r) >> 1;
    if (pos <= mid)
    {
        update(ls[last], l, mid, pos, ls[node]);
    }
    else
    {
        update(rs[last], mid + 1, r, pos, rs[node]);
    }
}

// Returns the sum stored for indices [ql, qr] (file-level globals) in the tree
// version rooted at `newer` minus the same sum in the version rooted at
// `older`; [l, r] is the index range covered by both nodes.
//
// An empty request (ql > qr) or a node that does not intersect the request
// yields 0. Without this guard a request such as ql = 1, qr = 0 would descend
// to the leaf [1, 1] and then recurse forever on identical bounds.
//
// The parameters are deliberately not named `ll`, so the `ll` typedef stays
// usable inside the body.
ll query(int older, int newer, int l, int r)
{
    if (ql > qr || qr < l || r < ql)
    {
        return 0;
    }

    if (ql <= l && qr >= r)
    {
        return cnt[newer] - cnt[older];
    }

    const int mid = (l + r) >> 1;
    ll sum = 0;
    if (ql <= mid)
    {
        sum += query(ls[older], ls[newer], l, mid);
    }
    if (qr > mid)
    {
        sum += query(rs[older], rs[newer], mid + 1, r);
    }

    return sum;
}

// Lowest common ancestor of p and q, using the depth array d[] and the jump
// table anc[][] filled by preprocess().
int LCA(int p, int q)
{
    // Keep p as the deeper (or equally deep) vertex.
    if (d[q] > d[p])
    {
        swap(p, q);
    }

    // Largest exponent whose power of two does not exceed d[p] (0 at minimum).
    int top = 0;
    while ((1 << (top + 1)) <= d[p])
    {
        ++top;
    }

    // Lift p until both vertices are at the same depth.
    for (int i = top; i >= 0; --i)
    {
        if (d[p] - (1 << i) >= d[q])
        {
            p = anc[p][i];
        }
    }
    if (p == q)
    {
        return p;
    }

    // Lift both vertices together while their jump targets still differ.
    for (int i = top; i >= 0; --i)
    {
        const int up = anc[p][i];
        if (up == -1 || up == anc[q][i])
        {
            continue;
        }
        p = up;
        q = anc[q][i];
    }

    return anc[p][0];
}

// Fills the jump table for vertices 1..n from the parent array pre[]:
// anc[i][0] is the parent of i and anc[i][j] is obtained by taking two jumps
// of level j - 1. Entries are first reset to -1, which the fill loop (and
// LCA()) treat as "no entry".
//
// The reset loop uses the same bound as the fill loop, (1 << j) <= n, so that
// every level the fill loop looks at has been reset for the current test case,
// including the top level when n is an exact power of two.
void preprocess(int n)
{
    for (int i = 1; i <= n; i++)
    {
        anc[i][0] = pre[i];
        for (int j = 1; (1 << j) <= n; j++)
        {
            anc[i][j] = -1;
        }
    }

    for (int j = 1; (1 << j) <= n; j++)
    {
        for (int i = 1; i <= n; i++)
        {
            const int half = anc[i][j - 1];
            if (half != -1)
            {
                anc[i][j] = anc[half][j - 1];
            }
        }
    }
}

void dfs(int u, int p, int m)
{
    d[u] = d[p] + 1;
    pre[u] = p;
    id[u] = tot++;
    int pos = low(A, B[u], m);
    update(T[id[p]], 1, m, pos, T[id[u]]);
    for (int i = 0; i < vi[u].size(); i++)
    {
        int v = vi[u][i];
        if (v == p)
        {
            continue;
        }

        dfs(v, u, m);
    }
}

// Searches the ascending, 1-indexed range a[1..R] for the first index whose
// value is >= pos. The result is clamped to R: when every value is smaller
// than pos the function still returns R, so callers must check a[result].
int low1(int *a, int pos, int R)
{
    int lo = 1;
    for (int hi = R; lo < hi; )
    {
        const int mid = (lo + hi) >> 1;
        if (a[mid] < pos)
        {
            lo = mid + 1;
        }
        else
        {
            hi = mid;
        }
    }
    return lo;
}

int n, q;

int main()
{
    while (~scanf("%d%d", &n, &q))
    {
        init(n);

        for (int i = 1; i <= n; i++)
        {
            scanf("%d", &A[i]);
            B[i] = A[i];
        }
        sort(A + 1, A + 1 + n);

        int m = 1;
        for (int i = 2; i <= n; i++)
        {
            if (A[i] != A[i - 1])
            {
                A[++m] = A[i];
            }
        }

        int u, v;
        for (int i = 1; i < n; i++)
        {
            scanf("%d%d", &u, &v);
            vi[u].push_back(v);
            vi[v].push_back(u);
        }

        builde(1, m, T[0]);
        d[0] = 0;
        dfs(1, 0, m);
        preprocess(n);

        ll ans;
        int s, t, a, b;
        for (int i = 1; i <= q; i++)
        {
            scanf("%d%d%d%d", &s, &t, &a, &b);

            ql = low1(A, a, m);
            qr = low1(A, b, m);
            if (A[qr] > b)
            {
                qr--;
            }
            int o = LCA(s, t);
            ans = 0;
            ans += query(T[id[pre[o]]], T[id[s]], 1, m);
            ans += query(T[id[o]], T[id[t]], 1, m);

            if (i == 1)
            {
                printf("%lld", ans);
            }
            else
            {
                printf(" %lld", ans);
            }
        }
        putchar(10);
    }

    return 0;
}