题目

https://loj.ac/problem/10143

题解

  • 一眼就能看出这题就是找Splay的前驱和后继
  • 通过这题可以体会到为什么要预先插入一个无穷小的点和一个无穷大的点：有了这两个哨兵，每个新插入的点都一定存在前驱和后继，查询时不会走到空节点而出错。

代码

#pragma GCC optimize(2)
#pragma GCC optimize(3,"Ofast","inline")
#include<bits/stdc++.h>
using namespace std;
using ll = long long;          // type alias instead of a macro: obeys scope, can't clash with identifiers
constexpr int N = 100007;      // node-pool capacity (1e5 + 7), written as an exact integer
constexpr int inf = 0x3f3f3f;  // NOTE(review): unused in this file; ~4.1e6, not the usual 0x3f3f3f3f
int rt;   // index of the splay-tree root (0 = empty tree)
int tot;  // number of node slots allocated so far
// One slot of the splay-tree node pool; index 0 serves as the null node.
struct node {
    int fa;     // parent index (0 = no parent)
    int ch[2];  // child indices: ch[0] = left, ch[1] = right (0 = none)
    int val;    // key stored in this node
    int tag;    // tag field (not read anywhere in this file)
    int sz;     // total multiplicity of the subtree rooted here
    int cnt;    // multiplicity of val
};
node s[N];  // node pool with static storage, so every field starts at 0

// Splay tree over the global node pool `s` (root index `rt`, last used slot `tot`).
// Index 0 is the null node and must keep sz == 0 and ch[] == {0, 0}.
// Each node holds a distinct key `val` with multiplicity `cnt`; `sz` is the total
// multiplicity of its subtree. `tag` is only ever reset (by Clear), never read.
struct Splay {
    // Returns 1 if x is its parent's right child, 0 otherwise.
    int get(int x) { return s[s[x].fa].ch[1] == x; }

    // Resets every field of node x. cnt is reset as well, so a recycled slot
    // cannot feed a stale multiplicity into maintain().
    void Clear(int x) {
        s[x].fa = s[x].ch[0] = s[x].ch[1] = s[x].sz = s[x].val = s[x].tag = s[x].cnt = 0;
    }

    // Recomputes sz of x from its (already correct) children.
    void maintain(int x) {
        s[x].sz = s[s[x].ch[0]].sz + s[s[x].ch[1]].sz + s[x].cnt;
    }

    // Rotates x above its parent y. The name keeps its original spelling for callers.
    void Rorate(int x) {
        int y = s[x].fa, z = s[y].fa, chk = get(x);
        int moved = s[x].ch[chk ^ 1];  // x's inner subtree, re-hung under y

        s[y].ch[chk] = moved;
        if (moved) s[moved].fa = y;  // never write into the null node

        s[y].fa = x;
        s[x].ch[chk ^ 1] = y;

        s[x].fa = z;
        if (z) s[z].ch[s[z].ch[1] == y] = x;

        maintain(y);  // y is now below x, so update it first
        maintain(x);
    }

    // Splays x upward until its parent is y. When y == 0, x becomes the new root rt.
    // Same-direction steps rotate the parent first (zig-zig); otherwise x (zig-zag).
    void splay(int x, int y) {
        for (int f = s[x].fa; f != y; Rorate(x), f = s[x].fa) {
            if (s[f].fa != y) Rorate(get(x) == get(f) ? f : x);
        }
        if (y == 0) rt = x;
    }

    // Inserts one occurrence of k. If k is present its cnt is bumped; otherwise a new
    // leaf is attached. Either way the node holding k ends up splayed to the root.
    void ins(int k) {
        if (!rt) {
            s[++tot].val = k;
            s[tot].cnt++;
            rt = tot;
            maintain(rt);
            return;
        }
        int now = rt, f = 0;
        while (true) {
            if (s[now].val == k) {
                s[now].cnt++;
                maintain(now);
                if (f) maintain(f);  // f == 0 when k matched at the root
                splay(now, 0);
                break;
            }
            f = now;
            now = s[now].ch[s[now].val < k];
            if (now == 0) {
                s[++tot].val = k;
                s[tot].cnt++;
                s[tot].fa = f;
                s[f].ch[s[f].val < k] = tot;
                maintain(tot);
                maintain(f);
                splay(tot, 0);
                break;
            }
        }
    }

    // Index of the in-order predecessor of the root (the maximum of its left subtree).
    // Returns 0 (the null node) if the root has no left child.
    int getPre() {
        int now = s[rt].ch[0];
        while (s[now].ch[1]) now = s[now].ch[1];
        return now;
    }

    // Index of the in-order successor of the root (the minimum of its right subtree).
    // Returns 0 (the null node) if the root has no right child.
    int getNxt() {
        int now = s[rt].ch[1];
        while (s[now].ch[0]) now = s[now].ch[0];
        return now;
    }
} st;
int main(){
    int n,x;
    scanf("%d",&n);
    st.ins(1000001);
    st.ins(-1000001);
    scanf("%d",&x);
    st.ins(x),n-=1;
    ll ans = x;
    while(n--){
        scanf("%d",&x);
        st.ins(x);
        if(s[rt].cnt > 1) continue;
        int aa = s[st.getPre()].val,bb = s[st.getNxt()].val,tmp = 1000001;
        if(aa != -1000001){
            tmp = abs(x - aa);
        }
        if(bb != 1000001){
            tmp = min(tmp,abs(x - bb));
        }
        if(tmp != 1000001) ans += tmp;
    }
    printf("%lld\n",ans);
    return 0;
}