传送门

主席树题解传送门:https://blog.csdn.net/qq_42211531/article/details/90034520

这篇博客也算是填了以前的坑:学会了CDQ分治之后,用它重新做了这一道题。

以前做这一道题的时候,思路是对的,但是我不知道如何维护想要的信息。学会了CDQ分治之后,这就是一道三维偏序(删除时间、位置、数值)的题,只是需要求两种情况的三维偏序:删除某个数时,要把它与更早被删除的数之间已经减过一次的逆序对加回来,一种是位置在它前面且比它大的数,另一种是位置在它后面且比它小的数。

代码:

///#include<bits/stdc++.h>
///#include<unordered_map>
///#include<unordered_set>
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<string>
#include<cmath>
#include<queue>
#include<bitset>
#include<set>
#include<stack>
#include<map>
#include<list>
#include<new>
#include<vector>

// Fill a whole array with byte value b: MT(arr, 0) zeroes arr.
#define MT(a, b) memset(a,b,sizeof(a))
// Lowest set bit of x (the Fenwick-tree step). The argument is fully
// parenthesized so that expressions such as lowbit(i + j) expand correctly;
// the unparenthesized form (x&(-x)) would expand to (i+j&(-i+j)).
#define lowbit(x) ((x) & (-(x)))
using namespace std;
typedef long long ll;
const double pai = acos(-1.0);     // pi
const double E = 2.718281828459;   // Euler's number
const ll mod = 1e9 + 7;
const double esp = 1e-6;           // floating-point tolerance
const int INF = 0x3f3f3f3f;
const int maxn = 2e5 + 5;          // capacity of the global arrays (indices must stay below this)

// n = length of the array, m = number of removal queries.
int n, m;
// num[i] = value stored at position i; pos[v] = position of value v.
int num[maxn];
int pos[maxn];
// pre[i] = earlier elements larger than num[i]; last[i] = later elements smaller than num[i].
int pre[maxn];
int last[maxn];
// Fenwick-tree storage used by add()/sum().
int low[maxn];
// q[k] = value removed by query k; ans1/ans2 = per-query corrections from the two CDQ passes.
int q[maxn];
int ans1[maxn];
int ans2[maxn];

// One removal query: where the value sits, the value itself, and its query number.
struct node {
    int pos;    // position of the value in the original array
    int x;      // the removed value
    int index;  // 1-based query (removal-time) number
};

node op[maxn];    // queries for the first CDQ pass
node op1[maxn];   // copy of the queries for the second CDQ pass
node temp[maxn];  // merge buffer shared by both passes

/// Fenwick-tree point update: adds x to low[] at i, i + lowbit(i), ... while the
/// index stays <= n. i must be >= 1 (with i == 0 the index never advances).
void add(int i, int x) {
    for (; i <= n; i += lowbit(i))
        low[i] += x;
}

/// Fenwick-tree prefix query: returns the accumulated count over indices 1..i
/// (0 when i <= 0).
int sum(int i) {
    int total = 0;
    for (; i > 0; i -= lowbit(i))
        total += low[i];
    return total;
}

/// Merge order for cdq1: the node with the larger position comes first.
bool cmp1(node a, node b) {
    return b.pos < a.pos;
}

/// Merge order for cdq2: the node with the smaller position comes first.
bool cmp2(node a, node b) {
    return b.pos > a.pos;
}

/// First CDQ pass over op[l..r] (initially in removal-time order).
/// For every query k in the range it adds to ans1[op[k].index] the number of
/// queries removed EARLIER whose position is smaller and whose value is larger.
/// Those inversions were already subtracted when the earlier element was
/// removed, so main() adds them back.
/// On return op[l..r] is sorted by pos in descending order (via temp[]).
/// Expects the Fenwick tree low[] to be all zero on entry; it is left all zero.
void cdq1(int l, int r) {
    // l > r only for an empty range (e.g. m == 0); with a plain l == r test
    // cdq1(1, 0) would recurse forever.
    if (l >= r)
        return;
    int mid = (l + r) >> 1;
    cdq1(l, mid), cdq1(mid + 1, r);
    int ql = l, qr = mid + 1;
    // Insert every left-half (earlier-removed) value; each one is taken out again
    // when it is merged, so while a right-half node is merged the tree holds only
    // left-half values whose pos is not greater than that node's pos.
    for (int i = l; i <= mid; i++)
        add(op[i].x, 1);
    for (int i = l; i <= r; i++) {
        // Check qr > r first so op[qr] is never read past the end of the range.
        if (qr > r || (ql <= mid && cmp1(op[ql], op[qr]))) {
            add(op[ql].x, -1);
            temp[i] = op[ql++];
        } else {
            // Remaining left-half values strictly greater than op[qr].x.
            ans1[op[qr].index] += sum(n) - sum(op[qr].x);
            temp[i] = op[qr++];
        }
    }
    for (int i = l; i <= r; i++)
        op[i] = temp[i];
}

/// Second CDQ pass over op1[l..r] (initially in removal-time order).
/// For every query k in the range it adds to ans2[op1[k].index] the number of
/// queries removed EARLIER whose position is larger and whose value is smaller.
/// Those inversions were already subtracted when the earlier element was
/// removed, so main() adds them back.
/// On return op1[l..r] is sorted by pos in ascending order (via temp[]).
/// Expects the Fenwick tree low[] to be all zero on entry; it is left all zero.
void cdq2(int l, int r) {
    // l > r only for an empty range (e.g. m == 0); with a plain l == r test
    // cdq2(1, 0) would recurse forever.
    if (l >= r)
        return;
    int mid = (l + r) >> 1;
    cdq2(l, mid), cdq2(mid + 1, r);
    int ql = l, qr = mid + 1;
    // Insert every left-half (earlier-removed) value; each one is taken out again
    // when it is merged, so while a right-half node is merged the tree holds only
    // left-half values whose pos is not smaller than that node's pos.
    for (int i = l; i <= mid; i++)
        add(op1[i].x, 1);
    for (int i = l; i <= r; i++) {
        // Check qr > r first so op1[qr] is never read past the end of the range.
        if (qr > r || (ql <= mid && cmp2(op1[ql], op1[qr]))) {
            add(op1[ql].x, -1);
            temp[i] = op1[ql++];
        } else {
            // Remaining left-half values not greater than op1[qr].x.
            ans2[op1[qr].index] += sum(op1[qr].x);
            temp[i] = op1[qr++];
        }
    }
    for (int i = l; i <= r; i++)
        op1[i] = temp[i];
}

/// Reads n, m, the array and the m removed values. Before each removal it prints
/// the current inversion count, then updates that count for the removal.
int main() {
    int x;
    ll ans = 0;
    scanf("%d %d", &n, &m);
    // With the Fenwick tree holding the values read so far, count for each
    // position the earlier larger elements (pre) and derive the later smaller
    // ones (last). The derivation assumes num is a permutation of 1..n.
    for (int i = 1; i <= n; ++i) {
        scanf("%d", &num[i]);
        const int v = num[i];
        pos[v] = i;
        const int smallerBefore = sum(v - 1);
        pre[i] = (i - 1) - smallerBefore;
        last[i] = v + pre[i] - i;
        ans += pre[i];
        add(v, 1);
    }
    // Each query becomes a node {position, value, removal order}; the two CDQ
    // passes each get their own copy because both reorder their array.
    for (int k = 1; k <= m; ++k) {
        scanf("%d", &x);
        q[k] = x;
        op[k] = node{pos[x], x, k};
        op1[k] = op[k];
    }
    // Each pass expects an empty Fenwick tree.
    memset(low, 0, sizeof(low));
    cdq1(1, m);
    memset(low, 0, sizeof(low));
    cdq2(1, m);
    for (int k = 1; k <= m; ++k) {
        printf("%lld\n", ans);
        const int p = pos[q[k]];
        // Subtract every original inversion involving this element, then add
        // back the ones shared with earlier-removed elements (already subtracted).
        ans += ans1[k] + ans2[k];
        ans -= pre[p] + last[p];
    }
    return 0;
}