H题 - 关于交换求和次序的一点思考

前言

本场比赛的 H 题可以使用 交换求和次序 的方法从数学角度上来直观理解贡献。我之前一直不会交换求和次序来着(拖走)。

权值计算

链接:2026牛客寒假算法基础集训营2 - H

来源:牛客网


题目描述

 function f(l, r, s)
     distinct ← ∅
     total ← 0
     current_count ← 0
     for i ← l to r do
         if s[i] ∉ distinct then
             current_count ← current_count + 1
             distinct ← distinct ∪ {s[i]}
         end if
         total ← total + current_count
     end for
     return total
 end function

如上是一段计算数组权值的伪代码,通过调用 计算一个长度为 的数组 的权值,现在有一个长度为 的数组 ,请你求出所有非空 子数组 的权值之和。

【名词解释】

子数组:从原数组中,连续的选择一段元素(可以全选,可以不选)得到的新数组。


输入描述

每个测试文件均包含多组测试数据。第一行输入一个整数 ,代表数据组数,每组测试数据描述如下:

第一行输入一个整数 ,表示数组长度。

第二行输入 个整数 ,表示数组中的元素。

除此之外,保证单个测试文件的 之和不超过


输出描述

对于每一组测试数据,新起一行输出一个整数,表示所有子数组的权值之和。


示例 1

输入

2
3
1 3 1
6
1 1 4 5 1 4

输出

14
102

思路

把上面伪代码中函数的作用转为成人话就是:求出 范围上所有前缀中,不同元素个数的累加和。例如 ,它的前缀分别是:

C++ 不难实现上述伪代码中的函数功能,代码如下:

int f(int l, int r, std::vector<int>& arr) {
    std::set<int> s;
    int sum = 0;
    for (int i = l; i <= r; i++) {
        s.insert(arr[i]);
        sum += s.size();
    }
    return sum;
}

我们需要枚举所有子数组,注意我这里的数组是 1-based

int BF(std::vector<int>& arr) {
    int n = arr.size() - 1;
    int sum = 0;
    for (int l = 1; l <= n; l++) {
        for (int r = l; r <= n; r++) {
            sum += f(l, r, arr);
        }
    } 
    return sum;
}

可以用数学语言来重新表述这道题,其中 表示 范围上不同元素的个数。

不难看出当前式子的时间复杂度为 ,所以我们需要化简这个累加式,一个常见的技巧就是交换求和次序。比如说我们有这个恒等式成立:

上面的等式就是说改变了计算顺序,依旧可以得到相同的答案。从直观上可以用矩阵来理解,左边的式子就是按行累加的结果,右边的式子就是按列累加的结果。

但是每次推导交换次序后的等式都画一个图的话,会很麻烦,并且这种方法无法适用于一些复杂的情况。这个时候我们就需要换一种更本质思路来解决才行。我们考虑上面等式中 的范围,不难得到:

对于左边的式子而言,它的意思可以看作当 上移动时,固定 ,那么 的移动范围就是 ;同样的也可以看作当 上移动,固定 ,那么 的移动范围就是 。这样就能直接写出交换求和次序之后的式子。

让我们重新回到题目本身,要化简这个三重求和式,那么先化简内层的那个二重求和式。

上移动时,固定 ,那么 上移动。根据刚刚讲的方法,可以立即得到:

带回原式,于是有:

上移动的时,固定 ,那么 上移动。可以立即得到:

最终我们便把原式化简,其中 表示 范围上不同元素的个数:

为了把 的时间复杂度进一步优化,我们考虑内层的求和式可不可以优化。不难看出内层求和式实际上是一个前缀和,考虑能不能从 的值推出 的值。我们尝试先模拟一下,这里还是以 为例:








不难看出这个前缀和每次增加的值等于这个数 当前位置 减去这个数 上一次出现的位置 。令 为当前位置, 表示 上一次出现的位置,定义函数:

则有状态转移方程

代码

#include <bits/stdc++.h>
#define int long long

void solve() {
    int n; std::cin >> n;
    std::vector<int> a(n + 1);
    std::map<int, int> last;
    for (int i = 1; i <= n; i++) std::cin >> a[i];
    int ans = 0, sum = 0;
    for (int k = 1; k <= n; k++) {
        sum += k - last[a[k]];
        ans += (n - k + 1) * sum; 
        last[a[k]] = k;
    }    
    std::cout << ans << '\n';
}

signed main() {
    std::ios::sync_with_stdio(false); std::cin.tie(nullptr);
    int T = 1;
    std::cin >> T;
    while (T--) solve();
    return 0;
}