归并排序做法
什么是归并排序呢?用一张图来说明:
(本图引用自浙江大学数据结构MOOC)
归并排序可以理解为:将两个有序的序列合并成一个有序的序列。
我们递归地执行,直到区间分割到单个元素,然后再递归回去,去执行有序序列的合并,就完成了归并排序。
当出现a[x] > a[y]
的情况时,出现逆序
#include <bits/stdc++.h> #define sc(x) scanf("%d", &x) const int N = 1e5 + 7; typedef long long ll; int n, a[N], b[N]; ll ans; void mergeSort(int L, int R) { if (L == R) return; int mid = (L + R) >> 1; mergeSort(L, mid); mergeSort(mid + 1, R); // 此时两个区间内已经有序 int x = L, y = mid + 1; // 指向两个有序子列的下标 for (int i = 1; i <= R - L + 1; ++i) { if (x <= mid && y <= R) { // 两个有序子列合并成一个有序列 if (a[x] > a[y]) { ans += mid - x + 1; // 统计答案 // 注意这里用mid来减而非它正确的位置 b[i] = a[y++]; } else b[i] = a[x++]; } else { if (x <= mid) b[i] = a[x++]; if (y <= R) b[i] = a[y++]; } } for (int i = R - L + 1; i; --i) a[R--] = b[i]; // 覆盖回去 } int main() { sc(n); for (int i = 1; i <= n; i++) sc(a[i]); mergeSort(1, n); printf("%lld\n", ans); return 0; }
树状数组做法
考虑维护一个树状数组。
在每插入一个值之前,查询此时有多少个比它大的值。
#include <bits/stdc++.h> #define rep(i, l, r) for (int i = (l); i <= (r); ++i) using namespace std; typedef long long ll; const int N = 1e5 + 7; #define lowbit(x) ((x) & (-x)) int tree[N]; inline void upd(int i) { while (i < N) { ++tree[i]; i += lowbit(i); } } inline ll query(int i) { ll ans = 0; while (i) { ans += tree[i]; i -= lowbit(i); } return ans; } inline ll query(int a, int b) { return query(b) - query(a - 1); } int main() { ios::sync_with_stdio(false), cin.tie(0), cout.tie(0); ll n, x, ans = 0; cin >> n; rep(i, 1, n) { cin >> x; ans += query(x + 1, N - 1); upd(x); } cout << ans << '\n'; return 0; }
线段树做法
线段树做法和BIT是一个道理
// https://blog.nowcoder.net/n/fd7f7193086c4a4b97e0dee0b031ac38 #include <bits/stdc++.h> using namespace std; const int N = 1e5 + 7; int W[N << 2]; inline void insert(int now, int l, int r, int x) { ++W[now]; if (l == r) return; int mid = (l + r) >> 1; if (x <= mid) insert(now << 1, l, mid, x); else insert(now << 1 | 1, mid + 1, r, x); } inline int find(int now, int l, int r, int lc, int rc) { if (lc <= l && r <= rc) return W[now]; int mid = (l + r) >> 1, res = 0; if (lc <= mid) res += find(now << 1, l, mid, lc, rc); if (rc > mid) res += find(now << 1 | 1, mid + 1, r, lc, rc); return res; } int main() { int n; long long res = 0; scanf("%d", &n); for (int i = 1; i <= n; ++i) { int x; scanf("%d", &x); ++x; if (x != N) res += find(1, 1, N, x + 1, N); insert(1, 1, N, x); } printf("%lld", res); return 0; }