题目链接

小红的显存清理挑战

题目描述

小红需要从 个张量中挑选一部分清理,以释放至少 单位的显存空间。每个张量 有空间 和两种清理方式(成本分别为 )。小红会选择成本较低的一种,即 。求满足释放空间 的最小总成本。若无法满足,输出 error

解题思路

由于 可达 ,传统的以空间为下标的 DP 数组会造成内存溢出。我们需要使用稀疏 DP,并配合正确的支配解剪枝

  1. 有效成本:每个张量的清理成本为
  2. 状态定义:维护一个非支配状态集 dp,存储 (space, cost)
    • 一个状态 若能被另一个状态 替代(即 ),则 是被支配的,应剔除。
  3. 状态转移与合并
    • 遍历每个张量,生成新状态集。将所有空间 的状态空间统一设为
    • 合并新旧状态集后,按空间 降序排序。
    • 遍历排序后的列表,维护一个 min_cost。只有当当前状态的 时,该状态才是有效的(因为它在更小的空间下提供了更低的成本)。
  4. 最终检查:若最终状态集中不存在空间达到 的状态,输出 error

代码

#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

struct State {
    long long s, c;
};

// 比较函数:按空间降序,空间相同时按成本升序
bool compareStates(const State& a, const State& b) {
    if (a.s != b.s) return a.s > b.s;
    return a.c < b.c;
}

int main() {
    long long m;
    int n;
    cin >> m >> n;

    vector<long long> s(n), c1(n), c2(n);
    for (int i = 0; i < n; ++i) cin >> s[i];
    for (int i = 0; i < n; ++i) cin >> c1[i];
    for (int i = 0; i < n; ++i) cin >> c2[i];

    vector<State> dp;
    dp.push_back({0, 0});

    for (int i = 0; i < n; ++i) {
        long long cur_s = s[i];
        long long cur_c = min(c1[i], c2[i]);
        
        int original_size = dp.size();
        for (int j = 0; j < original_size; ++j) {
            long long ns = min(m, dp[j].s + cur_s);
            dp.push_back({ns, dp[j].c + cur_c});
        }

        // 排序并剪枝
        sort(dp.begin(), dp.end(), compareStates);
        
        vector<State> next_dp;
        long long min_c = -1;
        for (auto& st : dp) {
            if (min_c == -1 || st.c < min_c) {
                if (!next_dp.empty() && next_dp.back().s == st.s) {
                    // 同一空间取最小成本,由于已按 c 升序,此处不需操作
                } else {
                    next_dp.push_back(st);
                    min_c = st.c;
                }
            }
        }
        dp = move(next_dp);
    }

    long long ans = -1;
    for (auto& st : dp) {
        if (st.s >= m) {
            if (ans == -1 || st.c < ans) ans = st.c;
        }
    }

    if (ans == -1) cout << "error" << endl;
    else cout << ans << endl;

    return 0;
}
import java.util.*;

public class Main {
    static class State {
        long s, c;
        State(long s, long c) { this.s = s; this.c = c; }
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        long m = sc.nextLong();
        int n = sc.nextInt();
        long[] s = new long[n];
        for (int i = 0; i < n; i++) s[i] = sc.nextLong();
        long[] c1 = new long[n];
        for (int i = 0; i < n; i++) c1[i] = sc.nextLong();
        long[] c2 = new long[n];
        for (int i = 0; i < n; i++) c2[i] = sc.nextLong();

        List<State> dp = new ArrayList<>();
        dp.add(new State(0, 0));

        for (int i = 0; i < n; i++) {
            long curS = s[i];
            long curC = Math.min(c1[i], c2[i]);
            
            int size = dp.size();
            for (int j = 0; j < size; j++) {
                State st = dp.get(j);
                dp.add(new State(Math.min(m, st.s + curS), st.c + curC));
            }

            // 按空间降序,成本升序
            dp.sort((a, b) -> {
                if (a.s != b.s) return Long.compare(b.s, a.s);
                return Long.compare(a.c, b.c);
            });

            List<State> nextDp = new ArrayList<>();
            long minC = Long.MAX_VALUE;
            for (State st : dp) {
                if (st.c < minC) {
                    if (!nextDp.isEmpty() && nextDp.get(nextDp.size() - 1).s == st.s) {
                        continue;
                    }
                    nextDp.add(st);
                    minC = st.c;
                }
            }
            dp = nextDp;
        }

        long ans = -1;
        for (State st : dp) {
            if (st.s >= m) {
                if (ans == -1 || st.c < ans) ans = st.c;
            }
        }
        System.out.println(ans == -1 ? "error" : ans);
    }
}
def solve():
    m = int(input())
    n = int(input())
    s_list = list(map(int, input().split()))
    c1 = list(map(int, input().split()))
    c2 = list(map(int, input().split()))

    # dp 存储 (space, cost)
    dp = [(0, 0)]

    for i in range(n):
        si = s_list[i]
        ci = min(c1[i], c2[i])
        
        # 生成新状态
        new_states = []
        for s, c in dp:
            new_states.append((min(m, s + si), c + ci))
        
        dp.extend(new_states)
        
        # 降序排列空间,升序排列成本
        # 这样我们可以通过线性扫描剔除被支配的解
        dp.sort(key=lambda x: (-x[0], x[1]))
        
        next_dp = []
        min_c = float('inf')
        for s, c in dp:
            if c < min_c:
                if next_dp and next_dp[-1][0] == s:
                    continue
                next_dp.append((s, c))
                min_c = c
        dp = next_dp

    ans = -1
    for s, c in dp:
        if s >= m:
            if ans == -1 or c < ans:
                ans = c
    
    if ans == -1:
        print("error")
    else:
        print(ans)

solve()

算法及复杂度

  • 算法:稀疏 0/1 背包(基于非支配解集的单调性优化)。
  • 时间复杂度:,其中 为每一轮保留的非支配状态数。由于剔除了大量无效状态并锁定了空间上限 ,实际运行效率很高。
  • 空间复杂度:。仅存储有效状态,解决了内存溢出问题。