题目链接

小美的序列问题

题目描述

给定一个由 个整数组成的数组 ,计算其中有多少个三元组 满足

输入:

  • 第一行输入一个整数 (),代表数组中的元素个数。
  • 第二行输入 个整数 ()。

输出:

  • 输出一个整数,表示满足条件的三元组个数。

解题思路

本题是一道复杂的计数问题,由于 的范围较大,必须使用 O() 的高效算法。正确的解法是基于容斥原理 (Inclusion-Exclusion Principle)树状数组 (Fenwick Tree)

直接统计满足 a[i] > a[k] > a[j] 的三元组非常困难。因此,我们采用容斥思想:先统计一个更大、更容易计算的集合,然后减去其中不符合条件的部分。

第一步:计算超集(Inclusion) 我们首先计算所有满足 i < j < ka[i] 是三者中最大值的三元组,即 a[i] > a[j]a[i] > a[k]。 这个集合包含了我们想要的 a[i] > a[k] > a[j],但也包含了我们不想要的 a[i] > a[j] >= a[k]

我们可以通过一次从右向左的遍历,并使用树状数组来高效计算这个超集的大小:

  1. 初始化 ans = 0 和一个空的树状数组 bit
  2. 遍历 in-10: a. 查询 bit 中有多少个值小于 a[i]。设这个数量为 cnt。这些值都在 i 的右侧。 b. 从这 cnt 个元素中任意挑选两个作为 jk,都能满足 a[i] > a[j]a[i] > a[k]。因此,i 对总数的贡献是组合数 。 c. 将 a[i] 的出现次数在 bit 中加 1。

第二步:减去多余部分(Exclusion) 现在,我们需要从 ans 中减去所有不符合条件的三元组,即满足 i < j < ka[i] > a[j] >= a[k] 的部分。

我们可以通过固定中间元素 j 来统计这部分的大小。对于每个 j,其贡献为: (在 j 左侧大于 a[j] 的元素个数) * (在 j 右侧小于等于 a[j] 的元素个数)

这两部分的数量可以通过两次遍历和树状数组高效计算:

  1. 预计算:在第一步从右向左遍历时,我们可以顺便计算并存储 (在 j 右侧小于等于 a[j] 的元素个数) 到一个辅助数组 b 中。
  2. 主计算:再进行一次从左向右的遍历(以 j 为当前索引),使用一个新的树状数组 bit2 实时计算 (在 j 左侧大于 a[j] 的元素个数),然后乘以预计算好的 b[j],从 ans 中减去。

最终算法流程:

  1. 离散化 a 数组。
  2. 第一遍 (从右向左): a. 用 bit1 查询 i 右侧小于 a[i] 的个数 cntans += cnt*(cnt-1)/2。 b. 用 bit1 查询 i 右侧小于等于 a[i] 的个数,存入 b[i]。 c. 更新 bit1
  3. 第二遍 (从左向右): a. 用 bit2 查询 j 左侧大于 a[j] 的个数 cnt_left_greater。 b. ans -= cnt_left_greater * b[j]。 c. 更新 bit2

代码

#include <iostream>
#include <vector>
#include <algorithm>
#include <map>

using namespace std;

const int MAXN = 200005;
long long bit[MAXN];
int n, m;

void add(int pos, long long val) {
    for (; pos <= m; pos += pos & -pos) {
        bit[pos] += val;
    }
}

long long query(int pos) {
    long long res = 0;
    for (; pos > 0; pos -= pos & -pos) {
        res += bit[pos];
    }
    return res;
}

void clear_bit() {
    for (int i = 0; i <= m; ++i) {
        bit[i] = 0;
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    cin >> n;
    vector<int> a(n);
    vector<int> unique_vals;
    for (int i = 0; i < n; ++i) {
        cin >> a[i];
        unique_vals.push_back(a[i]);
    }

    sort(unique_vals.begin(), unique_vals.end());
    unique_vals.erase(unique(unique_vals.begin(), unique_vals.end()), unique_vals.end());
    m = unique_vals.size();

    map<int, int> val_to_rank;
    for (int i = 0; i < m; ++i) {
        val_to_rank[unique_vals[i]] = i + 1;
    }

    vector<int> b_discrete(n);
    for (int i = 0; i < n; ++i) {
        b_discrete[i] = val_to_rank[a[i]];
    }

    long long ans = 0;
    vector<long long> right_le(n); // 存储右侧小于等于a[i]的个数
    
    // 第一遍:从右向左,计算超集和辅助数组
    clear_bit();
    for (int i = n - 1; i >= 0; --i) {
        long long less_count = query(b_discrete[i] - 1);
        ans += less_count * (less_count - 1) / 2;
        right_le[i] = query(b_discrete[i]);
        add(b_discrete[i], 1);
    }

    // 第二遍:从左向右,减去多余部分
    clear_bit();
    for (int j = 0; j < n; ++j) {
        long long left_total = j;
        long long left_le = query(b_discrete[j]);
        long long left_greater = left_total - left_le;
        ans -= left_greater * right_le[j];
        add(b_discrete[j], 1);
    }

    cout << ans << '\n';

    return 0;
}
import java.util.*;

public class Main {
    static int m;
    static long[] bit;

    static void add(int pos, long val) {
        for (; pos <= m; pos += pos & -pos) {
            bit[pos] += val;
        }
    }

    static long query(int pos) {
        long res = 0;
        for (; pos > 0; pos -= pos & -pos) {
            res += bit[pos];
        }
        return res;
    }
    
    static void clearBit() {
        Arrays.fill(bit, 0);
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int[] a = new int[n];
        
        Set<Integer> uniqueSet = new HashSet<>();
        for (int i = 0; i < n; i++) {
            a[i] = sc.nextInt();
            uniqueSet.add(a[i]);
        }
        
        List<Integer> sortedUnique = new ArrayList<>(uniqueSet);
        Collections.sort(sortedUnique);
        
        m = sortedUnique.size();
        Map<Integer, Integer> valToRank = new HashMap<>();
        for (int i = 0; i < m; i++) {
            valToRank.put(sortedUnique.get(i), i + 1);
        }
        
        int[] bDiscrete = new int[n];
        for (int i = 0; i < n; i++) {
            bDiscrete[i] = valToRank.get(a[i]);
        }
        
        bit = new long[m + 1];
        long ans = 0;
        long[] rightLe = new long[n];

        // 第一遍:从右向左
        clearBit();
        for (int i = n - 1; i >= 0; i--) {
            long lessCount = query(bDiscrete[i] - 1);
            ans += lessCount * (lessCount - 1) / 2;
            rightLe[i] = query(bDiscrete[i]);
            add(bDiscrete[i], 1);
        }
        
        // 第二遍:从左向右
        clearBit();
        for (int j = 0; j < n; j++) {
            long leftTotal = j;
            long leftLe = query(bDiscrete[j]);
            long leftGreater = leftTotal - leftLe;
            ans -= leftGreater * rightLe[j];
            add(bDiscrete[j], 1);
        }
        
        System.out.println(ans);
    }
}
def solve():
    n = int(input())
    a = list(map(int, input().split()))

    # 离散化
    unique_vals = sorted(list(set(a)))
    m = len(unique_vals)
    val_to_rank = {val: i + 1 for i, val in enumerate(unique_vals)}
    b_discrete = [val_to_rank[val] for val in a]

    bit = [0] * (m + 1)
    def add(pos, val):
        while pos <= m:
            bit[pos] += val
            pos += pos & -pos

    def query(pos):
        res = 0
        while pos > 0:
            res += bit[pos]
            pos -= pos & -pos
        return res

    ans = 0
    right_le = [0] * n

    # 第一遍:从右向左
    for i in range(n - 1, -1, -1):
        less_count = query(b_discrete[i] - 1)
        ans += less_count * (less_count - 1) // 2
        right_le[i] = query(b_discrete[i])
        add(b_discrete[i], 1)

    # 第二遍:从左向右
    bit = [0] * (m + 1) # 重置 bit
    for j in range(n):
        left_total = j
        left_le = query(b_discrete[j])
        left_greater = left_total - left_le
        ans -= left_greater * right_le[j]
        add(b_discrete[j], 1)

    print(ans)

solve()

算法及复杂度

  • 算法:容斥原理 + 树状数组。通过两次 O() 的遍历,第一次计算一个包含目标解的超集,并预计算辅助信息;第二次减去多算的部分,从而得到精确解。
  • 时间复杂度:。离散化需要 。两次主循环都各需要 O() 次树状数组操作,总复杂度为
  • 空间复杂度:。用于存储输入数组、离散化后的数组、辅助数组以及树状数组。