题目链接

动态最长连续子序列

题目描述

给定一个长度为 的非负整数数列 和一个 的排列 。排列 表示一个“摧毁”序列,在第 步,位于原数列第 个位置的数将被摧毁。

每当一个数被摧毁后,你需要找出当前数列中未被摧毁的数所构成的和最大的连续子序列。如果当前数列中已没有未被摧毁的数,则最大和为 。你需要输出每一步操作后的最大和。

解题思路

  1. 逆向思维

    直接模拟“摧毁”过程非常复杂,因为每次摧毁都会将一个连续的段断裂成两个或零个,维护最大和非常困难。

    我们可以逆向思考这个问题。不按顺序摧毁元素,而是从一个全被摧毁的空序列开始,按照摧毁顺序的逆序,一步步地将元素“恢复”回来。

    • 原问题第 步:摧毁 ,求最大和。
    • 逆问题第 步:在序列中添加 ,求最大和。

    这样,在逆向过程的第 步计算出的最大和,就对应原问题中第 步(即摧毁了前 个元素后)的结果。

  2. 使用并查集维护连续段

    当我们在逆向过程中添加一个元素 时,它可能会与它左边()或右边()已经存在的连续段合并。

    • 如果 的左边和右边都已经是恢复的元素,那么添加 会将左边的连续段、 本身、右边的连续段这三者合并成一个更大的连续段。
    • 如果只有一侧是已恢复的元素,则 会加入那一侧的连续段。
    • 如果两侧都没有已恢复的元素,则 自己形成一个新的、长度为1的连续段。

    这个“合并相邻集合”的操作正是并查集 (Union-Find) 数据结构的专长。我们可以用并查集来维护这些动态形成的连续段。

  3. 算法实现

    1. 初始化

      • 建立一个并查集,每个元素(数组下标 0n-1)初始时都是一个独立的集合。
      • 为并查集的每个集合(根节点)关联一个 sum 属性,用于记录该集合(连续段)中所有元素的总和。
      • 设置一个 active 数组,标记哪些位置的元素已经被“恢复”。
    2. 逆向添加

      • 我们从摧毁序列 的末尾开始,向前遍历。
      • 在第 步( 遍历到 ),我们将要恢复的元素是
      • 在恢复它之前,记录下当前全局的最大连续子序列和。这个值就是原问题摧毁了前 个元素后的答案。
      • 恢复元素
        • 将其标记为 active
        • 其所在的集合 sum 初始化为 本身。
        • 检查其左邻居 是否 active。如果是,则将 所在的集合合并 (unite)。合并时,需要将两个集合的 sum 相加。
        • 检查其右邻居 是否 active。如果是,则将 所在的集合合并。
        • 更新全局的最大连续子序列和。
    3. 输出

      • 将过程中记录的所有答案倒序输出,即为原问题的最终解。

代码

#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>

using namespace std;

vector<int> parent;
vector<long long> set_sum;

int find_set(int v) {
    if (v == parent[v])
        return v;
    return parent[v] = find_set(parent[v]);
}

void unite_sets(int a, int b) {
    a = find_set(a);
    b = find_set(b);
    if (a != b) {
        // 为了简化,这里让下标小的做根
        if (a > b) swap(a, b);
        parent[b] = a;
        set_sum[a] += set_sum[b];
    }
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n;
    cin >> n;

    vector<long long> a(n);
    for (int i = 0; i < n; ++i) cin >> a[i];

    vector<int> p(n);
    for (int i = 0; i < n; ++i) cin >> p[i];

    parent.resize(n);
    iota(parent.begin(), parent.end(), 0);
    set_sum.assign(n, 0);
    vector<bool> active(n, false);

    vector<long long> results;
    long long current_max_sum = 0;

    for (int i = n - 1; i >= 0; --i) {
        results.push_back(current_max_sum);
        
        int idx = p[i] - 1;
        active[idx] = true;
        set_sum[idx] = a[idx];
        
        if (idx > 0 && active[idx - 1]) {
            unite_sets(idx, idx - 1);
        }
        if (idx < n - 1 && active[idx + 1]) {
            unite_sets(idx, idx + 1);
        }
        
        int root = find_set(idx);
        current_max_sum = max(current_max_sum, set_sum[root]);
    }

    reverse(results.begin(), results.end());
    for (long long res : results) {
        cout << res << "\n";
    }

    return 0;
}
import java.util.Scanner;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class Main {
    private static int[] parent;
    private static long[] setSum;

    private static int find(int i) {
        if (parent[i] == i) {
            return i;
        }
        return parent[i] = find(parent[i]);
    }

    private static void unite(int i, int j) {
        int rootI = find(i);
        int rootJ = find(j);
        if (rootI != rootJ) {
            if (rootI < rootJ) { // Keep smaller index as root
                parent[rootJ] = rootI;
                setSum[rootI] += setSum[rootJ];
            } else {
                parent[rootI] = rootJ;
                setSum[rootJ] += setSum[rootI];
            }
        }
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        long[] a = new long[n];
        for (int i = 0; i < n; i++) {
            a[i] = sc.nextLong();
        }
        int[] p = new int[n];
        for (int i = 0; i < n; i++) {
            p[i] = sc.nextInt();
        }

        parent = new int[n];
        setSum = new long[n];
        for (int i = 0; i < n; i++) {
            parent[i] = i;
        }
        boolean[] active = new boolean[n];
        
        List<Long> results = new ArrayList<>();
        long currentMaxSum = 0;

        for (int i = n - 1; i >= 0; i--) {
            results.add(currentMaxSum);
            
            int idx = p[i] - 1;
            active[idx] = true;
            setSum[idx] = a[idx];

            if (idx > 0 && active[idx - 1]) {
                unite(idx, idx - 1);
            }
            if (idx < n - 1 && active[idx + 1]) {
                unite(idx, idx + 1);
            }
            
            int root = find(idx);
            currentMaxSum = Math.max(currentMaxSum, setSum[root]);
        }

        Collections.reverse(results);
        for (long res : results) {
            System.out.println(res);
        }
    }
}
import sys

def find(parent, i):
    if parent[i] == i:
        return i
    parent[i] = find(parent, parent[i])
    return parent[i]

def unite(parent, set_sum, i, j):
    root_i = find(parent, i)
    root_j = find(parent, j)
    if root_i != root_j:
        # Keep smaller index as root to be deterministic
        if root_i < root_j:
            parent[root_j] = root_i
            set_sum[root_i] += set_sum[root_j]
        else:
            parent[root_i] = root_j
            set_sum[root_j] += set_sum[root_i]

def solve():
    n_str = sys.stdin.readline()
    if not n_str: return
    n = int(n_str)
    
    a = list(map(int, sys.stdin.readline().split()))
    p = list(map(int, sys.stdin.readline().split()))
    
    parent = list(range(n))
    set_sum = [0] * n
    active = [False] * n
    
    results = []
    current_max_sum = 0
    
    for i in range(n - 1, -1, -1):
        results.append(current_max_sum)
        
        idx = p[i] - 1
        active[idx] = True
        set_sum[idx] = a[idx]
        
        if idx > 0 and active[idx - 1]:
            unite(parent, set_sum, idx, idx - 1)
        if idx < n - 1 and active[idx + 1]:
            unite(parent, set_sum, idx, idx + 1)
            
        root = find(parent, idx)
        current_max_sum = max(current_max_sum, set_sum[root])
        
    for res in reversed(results):
        print(res)

solve()

算法及复杂度

  • 算法:逆向思维 + 并查集
  • 时间复杂度。我们逆序遍历 个元素,每一步都执行常数次的并查集查找和合并操作。带有路径压缩和按秩(或大小)合并的并查集操作的平均时间复杂度为 (反阿克曼函数),其增长极其缓慢,可近似视为常数。
  • 空间复杂度,用于存储原数组、摧毁序列、并查集(父节点数组和总和数组)以及结果数组。