归并排序做法
什么是归并排序呢?用一张图来说明:
(本图引用自浙江大学数据结构MOOC)
归并排序可以理解为:将两个有序的序列合并成一个有序的序列。
我们递归地执行,直到区间分割到单个元素,然后再递归回去,去执行有序序列的合并,就完成了归并排序。
当出现a[x] > a[y]的情况时,出现逆序
#include <bits/stdc++.h>
#define sc(x) scanf("%d", &x)
const int N = 1e5 + 7;
typedef long long ll;
int n, a[N], b[N];
ll ans;
void mergeSort(int L, int R) {
if (L == R) return;
int mid = (L + R) >> 1;
mergeSort(L, mid);
mergeSort(mid + 1, R);
// 此时两个区间内已经有序
int x = L, y = mid + 1; // 指向两个有序子列的下标
for (int i = 1; i <= R - L + 1; ++i) {
if (x <= mid && y <= R) { // 两个有序子列合并成一个有序列
if (a[x] > a[y]) {
ans += mid - x + 1; // 统计答案
// 注意这里用mid来减而非它正确的位置
b[i] = a[y++];
} else
b[i] = a[x++];
} else {
if (x <= mid) b[i] = a[x++];
if (y <= R) b[i] = a[y++];
}
}
for (int i = R - L + 1; i; --i) a[R--] = b[i]; // 覆盖回去
}
int main() {
sc(n);
for (int i = 1; i <= n; i++) sc(a[i]);
mergeSort(1, n);
printf("%lld\n", ans);
return 0;
} 树状数组做法
考虑维护一个树状数组。
在每插入一个值之前,查询此时有多少个比它大的值。
#include <bits/stdc++.h>
#define rep(i, l, r) for (int i = (l); i <= (r); ++i)
using namespace std;
typedef long long ll;
const int N = 1e5 + 7;
#define lowbit(x) ((x) & (-x))
int tree[N];
inline void upd(int i) {
while (i < N) {
++tree[i];
i += lowbit(i);
}
}
inline ll query(int i) {
ll ans = 0;
while (i) {
ans += tree[i];
i -= lowbit(i);
}
return ans;
}
inline ll query(int a, int b) { return query(b) - query(a - 1); }
int main() {
ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
ll n, x, ans = 0;
cin >> n;
rep(i, 1, n) {
cin >> x;
ans += query(x + 1, N - 1);
upd(x);
}
cout << ans << '\n';
return 0;
} 线段树做法
线段树做法和BIT是一个道理
// https://blog.nowcoder.net/n/fd7f7193086c4a4b97e0dee0b031ac38
#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 7;
int W[N << 2];
inline void insert(int now, int l, int r, int x) {
++W[now];
if (l == r) return;
int mid = (l + r) >> 1;
if (x <= mid)
insert(now << 1, l, mid, x);
else
insert(now << 1 | 1, mid + 1, r, x);
}
inline int find(int now, int l, int r, int lc, int rc) {
if (lc <= l && r <= rc) return W[now];
int mid = (l + r) >> 1, res = 0;
if (lc <= mid) res += find(now << 1, l, mid, lc, rc);
if (rc > mid) res += find(now << 1 | 1, mid + 1, r, lc, rc);
return res;
}
int main() {
int n;
long long res = 0;
scanf("%d", &n);
for (int i = 1; i <= n; ++i) {
int x;
scanf("%d", &x);
++x;
if (x != N) res += find(1, 1, N, x + 1, N);
insert(1, 1, N, x);
}
printf("%lld", res);
return 0;
} 
京公网安备 11010502036488号