前言
我看到题解区清一色的离散化+树状数组,所以我就写了一点不一样的,其实逆序对问题用归并排序也是一个很好的解法
题目解释
题目描述中的个子区间指的是n个一个数的子区间,n-1个连续两个数的子区间,n-2个连续3个数的子区间,……
题目的坑点
本题的数据规模很大(n<=1000000(一百万)),所以,不仅是时间比较紧,答案也超过了longlong(一定要注意,我因为这个浪费了好多时间),我比较懒,用了int128,也可以自己手写高精度
解题思路
归并排序!!!
为什么能用归并排序解决逆序对问题?
归并排序的基本原理我们就不说了,不会的可以去网上搜索,网上有很多。
我们知道归并排序在合并的过程中,会将左右区间的左端点进行比较,然后将小的那个拿到临时储存的数组中去,在这个过程中,如果我们拿的是右区间的数,那么一定有:被拿的那个右区间的数与所有左区间剩余的数构成逆序对!,我们统计这个过程的所有贡献,就能得到最后的答案
每组逆序对贡献的计算:
假设某一个逆序对两个数位于i,j,那么这两个逆序对一共被(i)(n-j+1)个子区间包含,所以产生的贡献为i(n-j+1)
此思路代码实现:
struct P{
int val;
int pos;
}a[mx];
P b[mx];
void mergesort(int l,int r){
if(l==r)return;
int mid=(l+r)>>1;
mergesort(l,mid);
mergesort(mid+1,r);
int i=l,j=mid+1,k=l;
while(i<=mid&&j<=r){
if(a[i].val<=a[j].val){
b[k]=a[i];
k++;
i++;
}
else{
b[k]=a[j];
for(int o=i;o<=mid;o++)ans+=a[o].pos*(n-a[j].pos+1);
k++;
j++;
}
}
while(i<=mid)b[k++]=a[i++];
while(j<=r)b[k++]=a[j++];
for(int o=l;o<=r;o++)a[o]=b[o];
}该思路的时间复杂度
正常的归并排序是O(nlongn),但是,因为我们在计算贡献的时候从i循环到mid,所以有可能达到O(n^2logn),这思路显然需要优化
优化方法
关于那个循环,我们发现, 这部分可以用预处理过的前缀和代替,即
,预处理是O(n)的,但是合并的复杂度也是O(n),而且预处理与合并是并列的,所以,最终的复杂度为O(nlogn)完美解决
最终代码
struct P{
int val;
int pos;
}a[mx];
P b[mx];
__int128 sum[mx];//前缀和数组
void mergesort(int l,int r){
if(l==r)return;
int mid=(l+r)>>1;
mergesort(l,mid);
mergesort(mid+1,r);
int i=l,j=mid+1,k=l;
sum[l-1]=0;
for(int o=l;o<=mid;o++){
sum[o]=sum[o-1]+a[o].pos;
}
while(i<=mid&&j<=r){
if(a[i].val<=a[j].val){
b[k]=a[i];
k++;
i++;
}
else{
b[k]=a[j];
ans+=(sum[mid]-sum[i-1])*(n-a[j].pos+1);
k++;
j++;
}
}
while(i<=mid)b[k++]=a[i++];
while(j<=r)b[k++]=a[j++];
for(int o=l;o<=r;o++)a[o]=b[o];
}完整代码
#include<cstdio>
using namespace std;
inline int Read(){
int x=0;
char c=getchar();
while(c>'9'||c<'0')c=getchar();
while(c>='0'&&c<='9')x=x*10+c-'0',c=getchar();
return x;
}
void print(__int128 x){//手写输出int128
if(x>9)print(x/10);
putchar(x%10+'0');
}
const int mx=1000001;
__int128 ans=0;//我竟然因为这个原因浪费了那么久时间
int n;
struct P{
int val;
int pos;
}a[mx];
P b[mx];
__int128 sum[mx];//前缀和数组
void mergesort(int l,int r){
if(l==r)return;
int mid=(l+r)>>1;
mergesort(l,mid);
mergesort(mid+1,r);
int i=l,j=mid+1,k=l;
sum[l-1]=0;
for(int o=l;o<=mid;o++){
sum[o]=sum[o-1]+a[o].pos;
}
while(i<=mid&&j<=r){
if(a[i].val<=a[j].val){
b[k]=a[i];
k++;
i++;
}
else{
b[k]=a[j];
ans+=(sum[mid]-sum[i-1])*(n-a[j].pos+1);
k++;
j++;
}
}
while(i<=mid)b[k++]=a[i++];
while(j<=r)b[k++]=a[j++];
for(int o=l;o<=r;o++)a[o]=b[o];
}
int main(){
n=Read();
for(int i=1;i<=n;i++)a[i].val=Read(),a[i].pos=i;
mergesort(1,n);
print(ans);
return 0;
} 


京公网安备 11010502036488号