前言
我看到题解区清一色的离散化+树状数组,所以我就写了一点不一样的,其实逆序对问题用归并排序也是一个很好的解法
题目解释
题目描述中的个子区间指的是n个一个数的子区间,n-1个连续两个数的子区间,n-2个连续3个数的子区间,……
题目的坑点
本题的数据规模很大(n<=1000000(一百万)),所以,不仅是时间比较紧,答案也超过了longlong(一定要注意,我因为这个浪费了好多时间),我比较懒,用了int128,也可以自己手写高精度
解题思路
归并排序!!!
为什么能用归并排序解决逆序对问题?
归并排序的基本原理我们就不说了,不会的可以去网上搜索,网上有很多。
我们知道归并排序在合并的过程中,会将左右区间的左端点进行比较,然后将小的那个拿到临时储存的数组中去,在这个过程中,如果我们拿的是右区间的数,那么一定有:被拿的那个右区间的数与所有左区间剩余的数构成逆序对!,我们统计这个过程的所有贡献,就能得到最后的答案
每组逆序对贡献的计算:
假设某一个逆序对两个数位于i,j,那么这两个逆序对一共被(i)(n-j+1)个子区间包含,所以产生的贡献为i(n-j+1)
此思路代码实现:
struct P{ int val; int pos; }a[mx]; P b[mx]; void mergesort(int l,int r){ if(l==r)return; int mid=(l+r)>>1; mergesort(l,mid); mergesort(mid+1,r); int i=l,j=mid+1,k=l; while(i<=mid&&j<=r){ if(a[i].val<=a[j].val){ b[k]=a[i]; k++; i++; } else{ b[k]=a[j]; for(int o=i;o<=mid;o++)ans+=a[o].pos*(n-a[j].pos+1); k++; j++; } } while(i<=mid)b[k++]=a[i++]; while(j<=r)b[k++]=a[j++]; for(int o=l;o<=r;o++)a[o]=b[o]; }
该思路的时间复杂度
正常的归并排序是O(nlongn),但是,因为我们在计算贡献的时候从i循环到mid,所以有可能达到O(n^2logn),这思路显然需要优化
优化方法
关于那个循环,我们发现, 这部分可以用预处理过的前缀和代替,即 ,预处理是O(n)的,但是合并的复杂度也是O(n),而且预处理与合并是并列的,所以,最终的复杂度为O(nlogn)完美解决
最终代码
struct P{ int val; int pos; }a[mx]; P b[mx]; __int128 sum[mx];//前缀和数组 void mergesort(int l,int r){ if(l==r)return; int mid=(l+r)>>1; mergesort(l,mid); mergesort(mid+1,r); int i=l,j=mid+1,k=l; sum[l-1]=0; for(int o=l;o<=mid;o++){ sum[o]=sum[o-1]+a[o].pos; } while(i<=mid&&j<=r){ if(a[i].val<=a[j].val){ b[k]=a[i]; k++; i++; } else{ b[k]=a[j]; ans+=(sum[mid]-sum[i-1])*(n-a[j].pos+1); k++; j++; } } while(i<=mid)b[k++]=a[i++]; while(j<=r)b[k++]=a[j++]; for(int o=l;o<=r;o++)a[o]=b[o]; }
完整代码
#include<cstdio> using namespace std; inline int Read(){ int x=0; char c=getchar(); while(c>'9'||c<'0')c=getchar(); while(c>='0'&&c<='9')x=x*10+c-'0',c=getchar(); return x; } void print(__int128 x){//手写输出int128 if(x>9)print(x/10); putchar(x%10+'0'); } const int mx=1000001; __int128 ans=0;//我竟然因为这个原因浪费了那么久时间 int n; struct P{ int val; int pos; }a[mx]; P b[mx]; __int128 sum[mx];//前缀和数组 void mergesort(int l,int r){ if(l==r)return; int mid=(l+r)>>1; mergesort(l,mid); mergesort(mid+1,r); int i=l,j=mid+1,k=l; sum[l-1]=0; for(int o=l;o<=mid;o++){ sum[o]=sum[o-1]+a[o].pos; } while(i<=mid&&j<=r){ if(a[i].val<=a[j].val){ b[k]=a[i]; k++; i++; } else{ b[k]=a[j]; ans+=(sum[mid]-sum[i-1])*(n-a[j].pos+1); k++; j++; } } while(i<=mid)b[k++]=a[i++]; while(j<=r)b[k++]=a[j++]; for(int o=l;o<=r;o++)a[o]=b[o]; } int main(){ n=Read(); for(int i=1;i<=n;i++)a[i].val=Read(),a[i].pos=i; mergesort(1,n); print(ans); return 0; }