题目:https://ac.nowcoder.com/acm/contest/83372/J
vp的时候就卡在如何求当前区间内,已经被吃掉的甜甜圈有几个。赛后学了学树状数组,我们可以让每个点为1代表当前位置有甜甜圈,然后初始化树状数组c[i],然后在每次选取最大值过程中,更新c[i]单点也就是nowi为-1,然后树状数组可以快速求出当前区间1的个数,就代表需要移动的个数,循环n1 + n2次即可,时间复杂度为nlogn。
#include <bits/stdc++.h>
#define int long long
using namespace std;
#define endl '\n'
typedef long long ll;
const int N = 2e5 + 10;
const int mod = 998244353;
int n,c[N];        //c为树状数组

int a[N];   //存储甜甜圈甜度
map<int,int> mp;    //存储甜甜圈位置
priority_queue<int,vector<int>,less<>> q;   //存储当前应当吃的甜甜圈,降序排列

int lowbit(int i){      //求c[x]的区间长度
    return (-i) & i;
}

void add(int i,int z){  //点更新,a[i]加上z
    for (; i <= n ; i += lowbit(i)) {    //更新所有后继,祖先
        c[i] += z;
    }
}

int sum(int i){     //求前缀和,a1+a2+a3
    int sum = 0;
    for (; i > 0; i -= lowbit(i)) { //累加所有前驱
        sum += c[i];
    }
    return sum;
}

int sum(int i,int j){       //求区间和
    return sum(j) - sum(i-1);
}

void solve() {
    int n1,n2;
    cin>>n1>>n2;
    n = n1 + n2;
    for (int i = n1; i >= 1; --i) {
        cin>>a[i];
        add(i,1);
        q.push(a[i]);
        mp[a[i]] = i;
    }
    for (int i = 1; i <= n2; ++i) {
        cin>>a[n1 + i];
        add(i + n1,1);
        q.push(a[i + n1]);
        mp[a[i + n1]] = n1 + i;
    }
    int nowx = q.top(); //当前需要的甜度
    int nowi = mp[q.top()]; //当前甜度对应数组的位置
    int ans = 0;    //记录移动次数
    int l = n1,r = n1 + 1;  //分别为左队列和右队列的队头位置
    for (int i = 1; i <= n; ++i) {
        nowi = mp[q.top()];
        q.pop();
        if(nowi == l || nowi == r){
            add(nowi,-1);
            if(nowi == l){
                l = nowi - 1;
            } else {
                r = nowi + 1;
            }
        }else if(nowi < l){
            ans += sum(l) - sum(nowi);  //移动的元素个数,其实就是nowi到l的1的个数
            l = nowi - 1;
            r = nowi + 1;
            add(nowi,-1);
        } else{
            ans += sum(nowi - 1) - sum(r - 1);  //同理
            r = nowi + 1;
            l = nowi - 1;
            add(nowi,-1);
        }
    }
    cout<<ans<<endl;
}

signed main() {
    int t = 1;
    //cin >> t;
    while (t--) {
        solve();
    }
    return 0;
}