题意

给定\(n\)个数,每次交换两个数,输出交换后的逆序数。

分析

  • 交换两个数只会影响到对应区间内的逆序数,具体为减少区间\([l+1,r-1]\)中比\(a[r]\)大的数的个数,增加比\(a[r]\)大的数的个数,减少比大的数的个数,\(a[l]\)增加比\(a[l]\)小的数的个数。
  • 转化为单点修改+查询区间值域个数,树套树。
  • 思路不难想,写完调了一年,注意几个点
    • 外层bit大小是的是序列长度n,不是离散化后的值域ns。
    • 数据不保证\(l<=r\)
    • 注意相同元素。
    • 最后要判断\(a[l]\)\(a[r]\)的大小关系,除去相等。

代码

#include <bits/stdc++.h>
using namespace std;
const int N=2e4+50;
int n,ns,m,a[N],l,r,tr[N*20],x[N],y[N],c1,c2;
struct Orz{
    vector<int> a;
    void init(){
        a.clear();
    }
    int siz(){
        return a.size();
    }
    void add(int x){
        a.push_back(x);
    }
    void work(){
        sort(a.begin(),a.end());
        a.erase(unique(a.begin(),a.end()),a.end());
    }
    int idx(int v){
        return lower_bound(a.begin(),a.end(),v)-a.begin()+1;
    }
    int val(int i){
        return a[i-1];
    }
}orz;
struct HJT{
#define mid (l+r)/2
    int tot,sum[N*200],ls[N*200],rs[N*200];
    void update(int &x,int l,int r,int v,int add){
        if(!x){
            x=++tot;
        }
        sum[x]+=add;
        if(l<r){
            if(v<=mid){
                update(ls[x],l,mid,v,add);
            }else{
                update(rs[x],mid+1,r,v,add);
            }
        }
    }
    int query(int l,int r,int k){
        if(k==0){
            return 0;
        }
        if(r<=k){
            int ans=0;
            for(int i=1;i<=c1;i++){
                ans-=sum[x[i]];
            }
            for(int i=1;i<=c2;i++){
                ans+=sum[y[i]];
            }
            return ans;
        }
        if(k<=mid){
            for(int i=1;i<=c1;i++){
                x[i]=ls[x[i]];
            }
            for(int i=1;i<=c2;i++){
                y[i]=ls[y[i]];
            }
            return query(l,mid,k);
        }else{
            int ans=0;
            for(int i=1;i<=c1;i++){
                ans-=sum[ls[x[i]]];
            }
            for(int i=1;i<=c2;i++){
                ans+=sum[ls[y[i]]];
            }
            for(int i=1;i<=c1;i++){
                x[i]=rs[x[i]];
            }
            for(int i=1;i<=c2;i++){
                y[i]=rs[y[i]];
            }
            return ans+query(mid+1,r,k);
        }
    }
}ac;
struct BIT{
    int lowbit(int x){
        return x&(-x);
    }
    void modify(int i,int x){
        int k=a[i];
        while(i<=n){
            ac.update(tr[i],1,ns,k,x);
            i+=lowbit(i);
        }
    }
    int query(int l,int r,int xi,int yi){
        if(xi>yi){
            return 0;
        }
        c1=c2=0;
        for(int i=l-1;i;i-=lowbit(i)){
            x[++c1]=tr[i];
        }
        for(int i=r;i;i-=lowbit(i)){
            y[++c2]=tr[i];
        }
        int R=ac.query(1,ns,yi);
        c1=c2=0;
        for(int i=l-1;i;i-=lowbit(i)){
            x[++c1]=tr[i];
        }
        for(int i=r;i;i-=lowbit(i)){
            y[++c2]=tr[i];
        }
        int L=ac.query(1,ns,xi-1);
        return R-L;
    }
}bit;
int main(){
    // freopen("in.txt","r",stdin);
    scanf("%d",&n);
    orz.init();
    for(int i=1;i<=n;i++){
        scanf("%d",&a[i]);
        orz.add(a[i]);
    }
    orz.work();
    ns=orz.siz();
    int ans=0;
    for(int i=1;i<=n;i++){
        a[i]=orz.idx(a[i]);
        bit.modify(i,1);
        ans+=bit.query(1,i,a[i]+1,ns);
    }
    printf("%d\n",ans);
    scanf("%d",&m);
    for(int i=1;i<=m;i++){
        scanf("%d%d",&l,&r);
        if(l>r){
            swap(l,r);
        }
        if(l==r){
            printf("%d\n",ans);
            continue;
        }
        if(r-l>=2){
            int ta=bit.query(l+1,r-1,a[r]+1,ns);
            int tb=bit.query(l+1,r-1,a[l]+1,ns);
            int tc=bit.query(l+1,r-1,1,a[r]-1);
            int td=bit.query(l+1,r-1,1,a[l]-1);
            ans-=ta;
            ans+=tc;
            ans+=tb;
            ans-=td;
        }
        if(a[l]<a[r]){
            ans++;
        }else if(a[l]>a[r]){
            ans--;
        }
        bit.modify(l,-1);
        bit.modify(r,-1);
        swap(a[l],a[r]);
        bit.modify(l,1);
        bit.modify(r,1);
        printf("%d\n",ans);
    }
    return 0;
}