归并排序做法

什么是归并排序呢?用一张图来说明:

图片说明
(本图引用自浙江大学数据结构MOOC

归并排序可以理解为:将两个有序的序列合并成一个有序的序列。
我们递归地执行,直到区间分割到单个元素,然后再递归回去,去执行有序序列的合并,就完成了归并排序。

当出现a[x] > a[y]的情况时,出现逆序

#include <bits/stdc++.h>
#define sc(x) scanf("%d", &x)
const int N = 1e5 + 7;
typedef long long ll;
int n, a[N], b[N];
ll ans;
void mergeSort(int L, int R) {
    if (L == R) return;
    int mid = (L + R) >> 1;
    mergeSort(L, mid);
    mergeSort(mid + 1, R);
    // 此时两个区间内已经有序
    int x = L, y = mid + 1;  // 指向两个有序子列的下标
    for (int i = 1; i <= R - L + 1; ++i) {
        if (x <= mid && y <= R) {  // 两个有序子列合并成一个有序列
            if (a[x] > a[y]) {
                ans += mid - x + 1;  // 统计答案
                // 注意这里用mid来减而非它正确的位置
                b[i] = a[y++];
            } else
                b[i] = a[x++];
        } else {
            if (x <= mid) b[i] = a[x++];
            if (y <= R) b[i] = a[y++];
        }
    }
    for (int i = R - L + 1; i; --i) a[R--] = b[i];  // 覆盖回去
}
int main() {
    sc(n);
    for (int i = 1; i <= n; i++) sc(a[i]);
    mergeSort(1, n);
    printf("%lld\n", ans);
    return 0;
}

树状数组做法

考虑维护一个树状数组。
在每插入一个值之前,查询此时有多少个比它大的值。

#include <bits/stdc++.h>
#define rep(i, l, r) for (int i = (l); i <= (r); ++i)
using namespace std;
typedef long long ll;
const int N = 1e5 + 7;
#define lowbit(x) ((x) & (-x))
int tree[N];
inline void upd(int i) {
    while (i < N) {
        ++tree[i];
        i += lowbit(i);
    }
}
inline ll query(int i) {
    ll ans = 0;
    while (i) {
        ans += tree[i];
        i -= lowbit(i);
    }
    return ans;
}
inline ll query(int a, int b) { return query(b) - query(a - 1); }
int main() {
    ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
    ll n, x, ans = 0;
    cin >> n;
    rep(i, 1, n) {
        cin >> x;
        ans += query(x + 1, N - 1);
        upd(x);
    }
    cout << ans << '\n';
    return 0;
}

线段树做法

线段树做法和BIT是一个道理

// https://blog.nowcoder.net/n/fd7f7193086c4a4b97e0dee0b031ac38
#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 7;
int W[N << 2];
inline void insert(int now, int l, int r, int x) {
    ++W[now];
    if (l == r) return;
    int mid = (l + r) >> 1;
    if (x <= mid)
        insert(now << 1, l, mid, x);
    else
        insert(now << 1 | 1, mid + 1, r, x);
}
inline int find(int now, int l, int r, int lc, int rc) {
    if (lc <= l && r <= rc) return W[now];
    int mid = (l + r) >> 1, res = 0;
    if (lc <= mid) res += find(now << 1, l, mid, lc, rc);
    if (rc > mid) res += find(now << 1 | 1, mid + 1, r, lc, rc);
    return res;
}
int main() {
    int n;
    long long res = 0;
    scanf("%d", &n);
    for (int i = 1; i <= n; ++i) {
        int x;
        scanf("%d", &x);
        ++x;
        if (x != N) res += find(1, 1, N, x + 1, N);
        insert(1, 1, N, x);
    }
    printf("%lld", res);
    return 0;
}