第一眼看到这道题,只会n2的暴力,后来去膜拜完大佬的代码。终于会了树状数组加归并求本题的方法。
首先,我们用ans代表最终的答案,显然ans最大值n*(n-1)/2。既然不知道怎么求ans,那么我们换一种思路,可以考虑求不满足条件的数对,从ans最大值中减去。
什么样的数对会不满足条件呢?对于数对(i,j),如果ai>aj,且bi>bj,且ci>cj。那么这个数列求不满足条件,因为j比i绝对优秀。
等等,是不是有点像大家都喜欢的求逆序对呢?一开始,我们把他们按照a升序排列,这样就不用了考虑a的影响了。然后我们按照求逆序对的方法。
在对b进行归并排序的过程中,统计有多少对数对不满足条件,从ans中减去。
这个时候,神奇的数状数组就可以派上用场了。
把1 - n进行二分。
在下面的代码中初始i=x(左端点),j=mid+1。其中ai一定小于aj,如果bi<bj,那么我们统计,在已处理i中,比赛c比当前ci优秀的数量+1,这样,下面的add()函数,只对i(x - mid)进行处理,而不处理j(mid+1 - y),因为我们处理的是在区间合并过程中产生的不满足条件的数。对于(x - mid)和(mid+1 - y)两个单独的区间已经在归并过程处理了。
于是当我们发现bj<bi,统计一次不符合条件的数对。到此时为止(当前bi未考虑),我们统计的这个区间(x~i),所有的i和j都有ai<aj,bi<bj。有多少ci<cj呢?别急,支持单点修改,方便快捷的树状树状都帮我们记好了。代码中的query函数对应查询有多少i比赛c比cj优秀,这样,我们减掉query(cj)即可。
每次一个区间操作完,都要把树状数组清零,不对其他区间的处理产生影响。
下面结合代码做一些简单的注释。
希望对你有所帮助。

#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
#define lowbit(x) (x&-x)
#define ll long long
#define N 220000
using namespace std;
int n;
ll ans;
int c[N];
struct node{
    int a,b,c;
}q[N],p[N];
bool operator < (node x,node y){
    return x.a<y.a;
}//按a升序排序
inline void r(int &x){
    x=0;
    char ch=getchar();
    while(ch<'0'||ch>'9') ch=getchar();
    while(ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
}//快读
inline void add(int x,int val){
    while(x<=n){
        c[x]+=val;
        x+=lowbit(x);
    }
}//这里是树状数组的单点修改
inline int query(int x){
    int ans=0;
    while(x>0){
        ans+=c[x];
        x-=lowbit(x);
    }
    return ans;
}//查询
inline void f(int x,int y){//二分,归并。
    if(x==y) return;
    int mid=(x+y)>>1;
    int k=x,i=x,j=mid+1;
    f(x,mid);
    f(mid+1,y);
    while(i<=mid&&j<=y){
        if(q[i].b<q[j].b) add(q[i].c,1),p[k++]=q[i++];//当前i的a和b一定比后面j都优秀,把没他优秀的所有c都+1操作。用来给后面查询。
        else ans-=query(q[j].c),p[k++]=q[j++];//这个i的c不如这个j优秀,减掉c比他优秀的数量,换句话说就是减掉比j绝对优秀的i的数量。
    }
    while(i<=mid){
        add(q[i].c,1);//同上i的操作
        p[k++]=q[i++];
    }
    while(j<=y){
        ans-=query(q[j].c);//同上j的操作
        p[k++]=q[j++];
    }
    for(i=x;i<=mid;i++) add(q[i].c,-1);//清零
    for(i=x;i<=y;i++) q[i]=p[i];

}
int main()
{
    int i,j;
    r(n);
    for(i=1;i<=n;i++)
        r(q[i].a),r(q[i].b),r(q[i].c);
    sort(q+1,q+1+n);
    ans=(ll)n*(n-1)/2;//ans赋最大值。
    f(1,n);
    printf("%lld",ans);
}