题目描述

对于序列A,它的逆序对数定义为满足i<j,且Ai>Aj的数对(i,j)的个数。给1到n的一个排列,按照某种顺序依次删除m个元素,你的任务是在每次删除一个元素之前统计整个序列的逆序对数。

输入输出格式

输入格式:

 

输入第一行包含两个整数n和m,即初始元素的个数和删除的元素个数。以下n行每行包含一个1到n之间的正整数,即初始排列。以下m行每行一个正整数,依次为每次删除的元素。

 

输出格式:

 

输出包含m行,依次为删除每个元素之前,逆序对的个数。

 

输入输出样例

输入样例#1: 复制

5 4
1
5
3
4
2
5
1
4
2

输出样例#1: 复制

5
2
2
1

样例解释
(1,5,3,4,2)(1,3,4,2)(3,4,2)(3,2)(3)。

说明

N<=100000 M<=50000

 

看了题解思路,乱搞居然搞出来了   (常数很大吧┭┮﹏┭┮)

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=1e5+10;
int Tree[maxn+10];
inline int lowbit(int x)
{
    return (x&-x);
}
void add(int x,int value)
{
    for(int i=x;i<maxn;i+=lowbit(i))
    {
        Tree[i]+=value;
    }
}
int get(int x)
{
    int sum=0;
    for(int i=x;i;i-=lowbit(i))
    {
        sum+=Tree[i];
    }
    return sum;
}
struct node{
int pos,val,del;
int ans;
};
node p[maxn];
bool cmpval(const node &a,const node &b)
{
    return a.val<b.val;
}
bool cmppos(const node &a,const node &b)
{
    return a.pos<b.pos;
}
bool cmppos2(const node &a,const node &b)
{
    return a.pos>b.pos;
}
bool cmpdel(const node &a,const node &b)
{
    return a.del>b.del;
}
void cdq(int l,int r)
{
    if(l==r)
        return ;
    int mid=(l+r)>>1;
    cdq(l,mid);
    cdq(mid+1,r);
    sort(p+l,p+mid+1,cmppos);
    sort(p+mid+1,p+r+1,cmppos);
    int i=l,j=mid+1;
    for(;j<=r;j++)
    {
        while(p[i].pos<p[j].pos&&i<=mid)
        {
            add(p[i].val,1);
            i++;
        }
        p[j].ans+=(get(maxn)-get(p[j].val-1));
    }
    i--;
    for(;i>=l;i--)
    {
        add(p[i].val,-1);
    }

    sort(p+l,p+mid+1,cmppos2);
    sort(p+mid+1,p+r+1,cmppos2);
    i=l,j=mid+1;
    for(;j<=r;j++)
    {
        while(p[i].pos>p[j].pos&&i<=mid)
        {
            add(p[i].val,1);
            i++;
        }
        p[j].ans+=(get(p[j].val-1));
    }
    i--;
    for(;i>=l;i--)
    {
        add(p[i].val,-1);
    }

}
int main()
{
    int n,m,x;
    scanf("%d%d",&n,&m);
    ll ans=0;
    for(int i=1;i<=n;i++)
    {
        p[i].pos=i;
        scanf("%d",&p[i].val);
        ans+=get(maxn)-get(p[i].val-1);
        add(p[i].val,1);

        p[i].del=n+1;
        p[i].ans=0;
    }
    //printf("%lld\n",ans);
    memset(Tree,0,sizeof(Tree));
    sort(p+1,p+1+n,cmpval);
    for(int i=1;i<=m;i++)
    {
        scanf("%d",&x);
        p[x].del=i;
    }
    sort(p+1,p+1+n,cmpdel);
    cdq(1,n);
    sort(p+1,p+1+n,cmpdel);
    int cnt=1;
    printf("%lld\n",ans);
    for(int i=n;i>=1;i--)
    {
        if(p[i].del!=n+1)
        {
            cnt++;
            ans-=p[i].ans;
            printf("%lld\n",ans);
        }
        if(cnt==m)
            break;
    }

    return 0;
}