题目链接
题目描述
给定一个由 个整数组成的数组
,计算其中有多少个三元组
满足
且
。
输入:
- 第一行输入一个整数
(
),代表数组中的元素个数。
- 第二行输入
个整数
(
)。
输出:
- 输出一个整数,表示满足条件的三元组个数。
解题思路
本题是一道复杂的计数问题,由于 的范围较大,必须使用 O(
) 的高效算法。正确的解法是基于容斥原理 (Inclusion-Exclusion Principle) 和 树状数组 (Fenwick Tree)。
直接统计满足 a[i] > a[k] > a[j]
的三元组非常困难。因此,我们采用容斥思想:先统计一个更大、更容易计算的集合,然后减去其中不符合条件的部分。
第一步:计算超集(Inclusion)
我们首先计算所有满足 i < j < k
且 a[i]
是三者中最大值的三元组,即 a[i] > a[j]
且 a[i] > a[k]
。
这个集合包含了我们想要的 a[i] > a[k] > a[j]
,但也包含了我们不想要的 a[i] > a[j] >= a[k]
。
我们可以通过一次从右向左的遍历,并使用树状数组来高效计算这个超集的大小:
- 初始化
ans = 0
和一个空的树状数组bit
。 - 遍历
i
从n-1
到0
: a. 查询bit
中有多少个值小于a[i]
。设这个数量为cnt
。这些值都在i
的右侧。 b. 从这cnt
个元素中任意挑选两个作为j
和k
,都能满足a[i] > a[j]
和a[i] > a[k]
。因此,i
对总数的贡献是组合数。 c. 将
a[i]
的出现次数在bit
中加 1。
第二步:减去多余部分(Exclusion)
现在,我们需要从 ans
中减去所有不符合条件的三元组,即满足 i < j < k
且 a[i] > a[j] >= a[k]
的部分。
我们可以通过固定中间元素 j
来统计这部分的大小。对于每个 j
,其贡献为:
(在 j 左侧大于 a[j] 的元素个数) * (在 j 右侧小于等于 a[j] 的元素个数)
这两部分的数量可以通过两次遍历和树状数组高效计算:
- 预计算:在第一步从右向左遍历时,我们可以顺便计算并存储
(在 j 右侧小于等于 a[j] 的元素个数)
到一个辅助数组b
中。 - 主计算:再进行一次从左向右的遍历(以
j
为当前索引),使用一个新的树状数组bit2
实时计算(在 j 左侧大于 a[j] 的元素个数)
,然后乘以预计算好的b[j]
,从ans
中减去。
最终算法流程:
- 离散化
a
数组。 - 第一遍 (从右向左):
a. 用
bit1
查询i
右侧小于a[i]
的个数cnt
,ans += cnt*(cnt-1)/2
。 b. 用bit1
查询i
右侧小于等于a[i]
的个数,存入b[i]
。 c. 更新bit1
。 - 第二遍 (从左向右):
a. 用
bit2
查询j
左侧大于a[j]
的个数cnt_left_greater
。 b.ans -= cnt_left_greater * b[j]
。 c. 更新bit2
。
代码
#include <iostream>
#include <vector>
#include <algorithm>
#include <map>
using namespace std;
const int MAXN = 200005;
long long bit[MAXN];
int n, m;
void add(int pos, long long val) {
for (; pos <= m; pos += pos & -pos) {
bit[pos] += val;
}
}
long long query(int pos) {
long long res = 0;
for (; pos > 0; pos -= pos & -pos) {
res += bit[pos];
}
return res;
}
void clear_bit() {
for (int i = 0; i <= m; ++i) {
bit[i] = 0;
}
}
int main() {
ios::sync_with_stdio(false);
cin.tie(NULL);
cin >> n;
vector<int> a(n);
vector<int> unique_vals;
for (int i = 0; i < n; ++i) {
cin >> a[i];
unique_vals.push_back(a[i]);
}
sort(unique_vals.begin(), unique_vals.end());
unique_vals.erase(unique(unique_vals.begin(), unique_vals.end()), unique_vals.end());
m = unique_vals.size();
map<int, int> val_to_rank;
for (int i = 0; i < m; ++i) {
val_to_rank[unique_vals[i]] = i + 1;
}
vector<int> b_discrete(n);
for (int i = 0; i < n; ++i) {
b_discrete[i] = val_to_rank[a[i]];
}
long long ans = 0;
vector<long long> right_le(n); // 存储右侧小于等于a[i]的个数
// 第一遍:从右向左,计算超集和辅助数组
clear_bit();
for (int i = n - 1; i >= 0; --i) {
long long less_count = query(b_discrete[i] - 1);
ans += less_count * (less_count - 1) / 2;
right_le[i] = query(b_discrete[i]);
add(b_discrete[i], 1);
}
// 第二遍:从左向右,减去多余部分
clear_bit();
for (int j = 0; j < n; ++j) {
long long left_total = j;
long long left_le = query(b_discrete[j]);
long long left_greater = left_total - left_le;
ans -= left_greater * right_le[j];
add(b_discrete[j], 1);
}
cout << ans << '\n';
return 0;
}
import java.util.*;
public class Main {
static int m;
static long[] bit;
static void add(int pos, long val) {
for (; pos <= m; pos += pos & -pos) {
bit[pos] += val;
}
}
static long query(int pos) {
long res = 0;
for (; pos > 0; pos -= pos & -pos) {
res += bit[pos];
}
return res;
}
static void clearBit() {
Arrays.fill(bit, 0);
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
int[] a = new int[n];
Set<Integer> uniqueSet = new HashSet<>();
for (int i = 0; i < n; i++) {
a[i] = sc.nextInt();
uniqueSet.add(a[i]);
}
List<Integer> sortedUnique = new ArrayList<>(uniqueSet);
Collections.sort(sortedUnique);
m = sortedUnique.size();
Map<Integer, Integer> valToRank = new HashMap<>();
for (int i = 0; i < m; i++) {
valToRank.put(sortedUnique.get(i), i + 1);
}
int[] bDiscrete = new int[n];
for (int i = 0; i < n; i++) {
bDiscrete[i] = valToRank.get(a[i]);
}
bit = new long[m + 1];
long ans = 0;
long[] rightLe = new long[n];
// 第一遍:从右向左
clearBit();
for (int i = n - 1; i >= 0; i--) {
long lessCount = query(bDiscrete[i] - 1);
ans += lessCount * (lessCount - 1) / 2;
rightLe[i] = query(bDiscrete[i]);
add(bDiscrete[i], 1);
}
// 第二遍:从左向右
clearBit();
for (int j = 0; j < n; j++) {
long leftTotal = j;
long leftLe = query(bDiscrete[j]);
long leftGreater = leftTotal - leftLe;
ans -= leftGreater * rightLe[j];
add(bDiscrete[j], 1);
}
System.out.println(ans);
}
}
def solve():
n = int(input())
a = list(map(int, input().split()))
# 离散化
unique_vals = sorted(list(set(a)))
m = len(unique_vals)
val_to_rank = {val: i + 1 for i, val in enumerate(unique_vals)}
b_discrete = [val_to_rank[val] for val in a]
bit = [0] * (m + 1)
def add(pos, val):
while pos <= m:
bit[pos] += val
pos += pos & -pos
def query(pos):
res = 0
while pos > 0:
res += bit[pos]
pos -= pos & -pos
return res
ans = 0
right_le = [0] * n
# 第一遍:从右向左
for i in range(n - 1, -1, -1):
less_count = query(b_discrete[i] - 1)
ans += less_count * (less_count - 1) // 2
right_le[i] = query(b_discrete[i])
add(b_discrete[i], 1)
# 第二遍:从左向右
bit = [0] * (m + 1) # 重置 bit
for j in range(n):
left_total = j
left_le = query(b_discrete[j])
left_greater = left_total - left_le
ans -= left_greater * right_le[j]
add(b_discrete[j], 1)
print(ans)
solve()
算法及复杂度
- 算法:容斥原理 + 树状数组。通过两次 O(
) 的遍历,第一次计算一个包含目标解的超集,并预计算辅助信息;第二次减去多算的部分,从而得到精确解。
- 时间复杂度:
。离散化需要
。两次主循环都各需要 O(
) 次树状数组操作,总复杂度为
。
- 空间复杂度:
。用于存储输入数组、离散化后的数组、辅助数组以及树状数组。