题意
给定\(n\)个数,每次交换两个数,输出交换后的逆序数。
分析
- 交换两个数只会影响到对应区间内的逆序数,具体为减少区间\([l+1,r-1]\)中比\(a[r]\)大的数的个数,增加比\(a[r]\)大的数的个数,减少比大的数的个数,\(a[l]\)增加比\(a[l]\)小的数的个数。
- 转化为单点修改+查询区间值域个数,树套树。
- 思路不难想,写完调了一年,注意几个点
- 外层bit大小是的是序列长度n,不是离散化后的值域ns。
- 数据不保证\(l<=r\)。
- 注意相同元素。
- 最后要判断\(a[l]\)和\(a[r]\)的大小关系,除去相等。
代码
#include <bits/stdc++.h>
using namespace std;
const int N=2e4+50;
int n,ns,m,a[N],l,r,tr[N*20],x[N],y[N],c1,c2;
struct Orz{
vector<int> a;
void init(){
a.clear();
}
int siz(){
return a.size();
}
void add(int x){
a.push_back(x);
}
void work(){
sort(a.begin(),a.end());
a.erase(unique(a.begin(),a.end()),a.end());
}
int idx(int v){
return lower_bound(a.begin(),a.end(),v)-a.begin()+1;
}
int val(int i){
return a[i-1];
}
}orz;
struct HJT{
#define mid (l+r)/2
int tot,sum[N*200],ls[N*200],rs[N*200];
void update(int &x,int l,int r,int v,int add){
if(!x){
x=++tot;
}
sum[x]+=add;
if(l<r){
if(v<=mid){
update(ls[x],l,mid,v,add);
}else{
update(rs[x],mid+1,r,v,add);
}
}
}
int query(int l,int r,int k){
if(k==0){
return 0;
}
if(r<=k){
int ans=0;
for(int i=1;i<=c1;i++){
ans-=sum[x[i]];
}
for(int i=1;i<=c2;i++){
ans+=sum[y[i]];
}
return ans;
}
if(k<=mid){
for(int i=1;i<=c1;i++){
x[i]=ls[x[i]];
}
for(int i=1;i<=c2;i++){
y[i]=ls[y[i]];
}
return query(l,mid,k);
}else{
int ans=0;
for(int i=1;i<=c1;i++){
ans-=sum[ls[x[i]]];
}
for(int i=1;i<=c2;i++){
ans+=sum[ls[y[i]]];
}
for(int i=1;i<=c1;i++){
x[i]=rs[x[i]];
}
for(int i=1;i<=c2;i++){
y[i]=rs[y[i]];
}
return ans+query(mid+1,r,k);
}
}
}ac;
struct BIT{
int lowbit(int x){
return x&(-x);
}
void modify(int i,int x){
int k=a[i];
while(i<=n){
ac.update(tr[i],1,ns,k,x);
i+=lowbit(i);
}
}
int query(int l,int r,int xi,int yi){
if(xi>yi){
return 0;
}
c1=c2=0;
for(int i=l-1;i;i-=lowbit(i)){
x[++c1]=tr[i];
}
for(int i=r;i;i-=lowbit(i)){
y[++c2]=tr[i];
}
int R=ac.query(1,ns,yi);
c1=c2=0;
for(int i=l-1;i;i-=lowbit(i)){
x[++c1]=tr[i];
}
for(int i=r;i;i-=lowbit(i)){
y[++c2]=tr[i];
}
int L=ac.query(1,ns,xi-1);
return R-L;
}
}bit;
int main(){
// freopen("in.txt","r",stdin);
scanf("%d",&n);
orz.init();
for(int i=1;i<=n;i++){
scanf("%d",&a[i]);
orz.add(a[i]);
}
orz.work();
ns=orz.siz();
int ans=0;
for(int i=1;i<=n;i++){
a[i]=orz.idx(a[i]);
bit.modify(i,1);
ans+=bit.query(1,i,a[i]+1,ns);
}
printf("%d\n",ans);
scanf("%d",&m);
for(int i=1;i<=m;i++){
scanf("%d%d",&l,&r);
if(l>r){
swap(l,r);
}
if(l==r){
printf("%d\n",ans);
continue;
}
if(r-l>=2){
int ta=bit.query(l+1,r-1,a[r]+1,ns);
int tb=bit.query(l+1,r-1,a[l]+1,ns);
int tc=bit.query(l+1,r-1,1,a[r]-1);
int td=bit.query(l+1,r-1,1,a[l]-1);
ans-=ta;
ans+=tc;
ans+=tb;
ans-=td;
}
if(a[l]<a[r]){
ans++;
}else if(a[l]>a[r]){
ans--;
}
bit.modify(l,-1);
bit.modify(r,-1);
swap(a[l],a[r]);
bit.modify(l,1);
bit.modify(r,1);
printf("%d\n",ans);
}
return 0;
}