题目链接
题目描述
给定一个长度为 的非负整数数列
和一个
到
的排列
。排列
表示一个“摧毁”序列,在第
步,位于原数列第
个位置的数将被摧毁。
每当一个数被摧毁后,你需要找出当前数列中未被摧毁的数所构成的和最大的连续子序列。如果当前数列中已没有未被摧毁的数,则最大和为 。你需要输出每一步操作后的最大和。
解题思路
-
逆向思维
直接模拟“摧毁”过程非常复杂,因为每次摧毁都会将一个连续的段断裂成两个或零个,维护最大和非常困难。
我们可以逆向思考这个问题。不按顺序摧毁元素,而是从一个全被摧毁的空序列开始,按照摧毁顺序的逆序,一步步地将元素“恢复”回来。
- 原问题第
步:摧毁
,求最大和。
- 逆问题第
步:在序列中添加
,求最大和。
这样,在逆向过程的第
步计算出的最大和,就对应原问题中第
步(即摧毁了前
个元素后)的结果。
- 原问题第
-
使用并查集维护连续段
当我们在逆向过程中添加一个元素
时,它可能会与它左边(
)或右边(
)已经存在的连续段合并。
- 如果
的左边和右边都已经是恢复的元素,那么添加
会将左边的连续段、
本身、右边的连续段这三者合并成一个更大的连续段。
- 如果只有一侧是已恢复的元素,则
会加入那一侧的连续段。
- 如果两侧都没有已恢复的元素,则
自己形成一个新的、长度为1的连续段。
这个“合并相邻集合”的操作正是并查集 (Union-Find) 数据结构的专长。我们可以用并查集来维护这些动态形成的连续段。
- 如果
-
算法实现
-
初始化:
- 建立一个并查集,每个元素(数组下标
0
到n-1
)初始时都是一个独立的集合。 - 为并查集的每个集合(根节点)关联一个
sum
属性,用于记录该集合(连续段)中所有元素的总和。 - 设置一个
active
数组,标记哪些位置的元素已经被“恢复”。
- 建立一个并查集,每个元素(数组下标
-
逆向添加:
- 我们从摧毁序列
的末尾开始,向前遍历。
- 在第
步(
从
遍历到
),我们将要恢复的元素是
。
- 在恢复它之前,记录下当前全局的最大连续子序列和。这个值就是原问题摧毁了前
个元素后的答案。
- 恢复元素
:
- 将其标记为
active
。 - 其所在的集合
sum
初始化为本身。
- 检查其左邻居
是否
active
。如果是,则将与
所在的集合合并 (
unite
)。合并时,需要将两个集合的sum
相加。 - 检查其右邻居
是否
active
。如果是,则将与
所在的集合合并。
- 更新全局的最大连续子序列和。
- 将其标记为
- 我们从摧毁序列
-
输出:
- 将过程中记录的所有答案倒序输出,即为原问题的最终解。
-
代码
#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>
using namespace std;
vector<int> parent;
vector<long long> set_sum;
int find_set(int v) {
if (v == parent[v])
return v;
return parent[v] = find_set(parent[v]);
}
void unite_sets(int a, int b) {
a = find_set(a);
b = find_set(b);
if (a != b) {
// 为了简化,这里让下标小的做根
if (a > b) swap(a, b);
parent[b] = a;
set_sum[a] += set_sum[b];
}
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int n;
cin >> n;
vector<long long> a(n);
for (int i = 0; i < n; ++i) cin >> a[i];
vector<int> p(n);
for (int i = 0; i < n; ++i) cin >> p[i];
parent.resize(n);
iota(parent.begin(), parent.end(), 0);
set_sum.assign(n, 0);
vector<bool> active(n, false);
vector<long long> results;
long long current_max_sum = 0;
for (int i = n - 1; i >= 0; --i) {
results.push_back(current_max_sum);
int idx = p[i] - 1;
active[idx] = true;
set_sum[idx] = a[idx];
if (idx > 0 && active[idx - 1]) {
unite_sets(idx, idx - 1);
}
if (idx < n - 1 && active[idx + 1]) {
unite_sets(idx, idx + 1);
}
int root = find_set(idx);
current_max_sum = max(current_max_sum, set_sum[root]);
}
reverse(results.begin(), results.end());
for (long long res : results) {
cout << res << "\n";
}
return 0;
}
import java.util.Scanner;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
public class Main {
private static int[] parent;
private static long[] setSum;
private static int find(int i) {
if (parent[i] == i) {
return i;
}
return parent[i] = find(parent[i]);
}
private static void unite(int i, int j) {
int rootI = find(i);
int rootJ = find(j);
if (rootI != rootJ) {
if (rootI < rootJ) { // Keep smaller index as root
parent[rootJ] = rootI;
setSum[rootI] += setSum[rootJ];
} else {
parent[rootI] = rootJ;
setSum[rootJ] += setSum[rootI];
}
}
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
long[] a = new long[n];
for (int i = 0; i < n; i++) {
a[i] = sc.nextLong();
}
int[] p = new int[n];
for (int i = 0; i < n; i++) {
p[i] = sc.nextInt();
}
parent = new int[n];
setSum = new long[n];
for (int i = 0; i < n; i++) {
parent[i] = i;
}
boolean[] active = new boolean[n];
List<Long> results = new ArrayList<>();
long currentMaxSum = 0;
for (int i = n - 1; i >= 0; i--) {
results.add(currentMaxSum);
int idx = p[i] - 1;
active[idx] = true;
setSum[idx] = a[idx];
if (idx > 0 && active[idx - 1]) {
unite(idx, idx - 1);
}
if (idx < n - 1 && active[idx + 1]) {
unite(idx, idx + 1);
}
int root = find(idx);
currentMaxSum = Math.max(currentMaxSum, setSum[root]);
}
Collections.reverse(results);
for (long res : results) {
System.out.println(res);
}
}
}
import sys
def find(parent, i):
if parent[i] == i:
return i
parent[i] = find(parent, parent[i])
return parent[i]
def unite(parent, set_sum, i, j):
root_i = find(parent, i)
root_j = find(parent, j)
if root_i != root_j:
# Keep smaller index as root to be deterministic
if root_i < root_j:
parent[root_j] = root_i
set_sum[root_i] += set_sum[root_j]
else:
parent[root_i] = root_j
set_sum[root_j] += set_sum[root_i]
def solve():
n_str = sys.stdin.readline()
if not n_str: return
n = int(n_str)
a = list(map(int, sys.stdin.readline().split()))
p = list(map(int, sys.stdin.readline().split()))
parent = list(range(n))
set_sum = [0] * n
active = [False] * n
results = []
current_max_sum = 0
for i in range(n - 1, -1, -1):
results.append(current_max_sum)
idx = p[i] - 1
active[idx] = True
set_sum[idx] = a[idx]
if idx > 0 and active[idx - 1]:
unite(parent, set_sum, idx, idx - 1)
if idx < n - 1 and active[idx + 1]:
unite(parent, set_sum, idx, idx + 1)
root = find(parent, idx)
current_max_sum = max(current_max_sum, set_sum[root])
for res in reversed(results):
print(res)
solve()
算法及复杂度
- 算法:逆向思维 + 并查集
- 时间复杂度:
。我们逆序遍历
个元素,每一步都执行常数次的并查集查找和合并操作。带有路径压缩和按秩(或大小)合并的并查集操作的平均时间复杂度为
(反阿克曼函数),其增长极其缓慢,可近似视为常数。
- 空间复杂度:
,用于存储原数组、摧毁序列、并查集(父节点数组和总和数组)以及结果数组。