前言

我看到题解区清一色的离散化+树状数组,所以我就写了一点不一样的,其实逆序对问题用归并排序也是一个很好的解法

题目解释

题目描述中的图片说明个子区间指的是n个一个数的子区间,n-1个连续两个数的子区间,n-2个连续3个数的子区间,……

题目的坑点

本题的数据规模很大(n<=1000000(一百万)),所以,不仅是时间比较紧,答案也超过了longlong(一定要注意,我因为这个浪费了好多时间),我比较懒,用了int128,也可以自己手写高精度

解题思路

归并排序!!!

为什么能用归并排序解决逆序对问题?

归并排序的基本原理我们就不说了,不会的可以去网上搜索,网上有很多。
我们知道归并排序在合并的过程中,会将左右区间的左端点进行比较,然后将小的那个拿到临时储存的数组中去,在这个过程中,如果我们拿的是右区间的数,那么一定有:被拿的那个右区间的数与所有左区间剩余的数构成逆序对!,我们统计这个过程的所有贡献,就能得到最后的答案

每组逆序对贡献的计算:

假设某一个逆序对两个数位于i,j,那么这两个逆序对一共被(i)(n-j+1)个子区间包含,所以产生的贡献为i(n-j+1)

此思路代码实现:

struct P{
    int val;
    int pos;
}a[mx];
P b[mx];
void mergesort(int l,int r){
    if(l==r)return;
    int mid=(l+r)>>1;
    mergesort(l,mid);
    mergesort(mid+1,r);
    int i=l,j=mid+1,k=l;
    while(i<=mid&&j<=r){
        if(a[i].val<=a[j].val){
            b[k]=a[i];
            k++;
            i++;
        }
        else{
            b[k]=a[j];
            for(int o=i;o<=mid;o++)ans+=a[o].pos*(n-a[j].pos+1);
            k++;
            j++;
        }
    }
    while(i<=mid)b[k++]=a[i++];
    while(j<=r)b[k++]=a[j++];
    for(int o=l;o<=r;o++)a[o]=b[o];
}

该思路的时间复杂度

正常的归并排序是O(nlongn),但是,因为我们在计算贡献的时候从i循环到mid,所以有可能达到O(n^2logn),这思路显然需要优化

优化方法

关于那个循环,我们发现,图片说明 这部分可以用预处理过的前缀和代替,即图片说明 ,预处理是O(n)的,但是合并的复杂度也是O(n),而且预处理与合并是并列的,所以,最终的复杂度为O(nlogn)完美解决

最终代码

struct P{
    int val;
    int pos;
}a[mx];
P b[mx];
__int128 sum[mx];//前缀和数组 
void mergesort(int l,int r){
    if(l==r)return;
    int mid=(l+r)>>1;
    mergesort(l,mid);
    mergesort(mid+1,r);
    int i=l,j=mid+1,k=l;
    sum[l-1]=0;
    for(int o=l;o<=mid;o++){
        sum[o]=sum[o-1]+a[o].pos;
    }
    while(i<=mid&&j<=r){
        if(a[i].val<=a[j].val){
            b[k]=a[i];
            k++;
            i++;
        }
        else{
            b[k]=a[j];
            ans+=(sum[mid]-sum[i-1])*(n-a[j].pos+1);
            k++;
            j++;
        }
    }
    while(i<=mid)b[k++]=a[i++];
    while(j<=r)b[k++]=a[j++];
    for(int o=l;o<=r;o++)a[o]=b[o];
}

完整代码

#include<cstdio>
using namespace std;
inline int Read(){
    int x=0;
    char c=getchar();
    while(c>'9'||c<'0')c=getchar();
    while(c>='0'&&c<='9')x=x*10+c-'0',c=getchar();
    return x;
}
void print(__int128 x){//手写输出int128 
    if(x>9)print(x/10);
    putchar(x%10+'0');
}
const int mx=1000001;
__int128 ans=0;//我竟然因为这个原因浪费了那么久时间
int n;
struct P{
    int val;
    int pos;
}a[mx];
P b[mx];
__int128 sum[mx];//前缀和数组 
void mergesort(int l,int r){
    if(l==r)return;
    int mid=(l+r)>>1;
    mergesort(l,mid);
    mergesort(mid+1,r);
    int i=l,j=mid+1,k=l;
    sum[l-1]=0;
    for(int o=l;o<=mid;o++){
        sum[o]=sum[o-1]+a[o].pos;
    }
    while(i<=mid&&j<=r){
        if(a[i].val<=a[j].val){
            b[k]=a[i];
            k++;
            i++;
        }
        else{
            b[k]=a[j];
            ans+=(sum[mid]-sum[i-1])*(n-a[j].pos+1);
            k++;
            j++;
        }
    }
    while(i<=mid)b[k++]=a[i++];
    while(j<=r)b[k++]=a[j++];
    for(int o=l;o<=r;o++)a[o]=b[o];
}
int main(){
    n=Read();
    for(int i=1;i<=n;i++)a[i].val=Read(),a[i].pos=i;
    mergesort(1,n);
    print(ans);
    return 0;
}