奶油蛋糕的进阶配方

[题目链接](https://www.nowcoder.com/practice/89d12d153bb54d39be4b1e53ad485cb4)

思路

本题需要在门槛约束、体力约束和数量约束下,选择一组蛋糕使最终甜度等级最大,并在此基础上使制作数量最少。

确定制作顺序

首先思考一个关键问题:如果已经决定了要制作哪些蛋糕,最优的制作顺序是什么?

制作蛋糕只会增加甜度等级,越早积累甜度,越容易满足后续蛋糕的门槛。因此,按门槛从小到大的顺序制作一定是最优的——如果一个方案可行,那么把它按门槛排序后仍然可行。

转化为选择问题

将所有蛋糕按门槛排序后,问题变成:从排好序的蛋糕序列中选出至多 个,总体力消耗不超过 ,使得每个被选中的蛋糕在轮到它时,累计甜度已经达到其门槛,且总甜度增益最大。

为什么不能用普通背包?

体力值可能很大,无法直接作为 DP 维度开数组。我们需要一个不依赖体力值大小的方法。

Pareto 前沿 DP

对于"选了恰好 个蛋糕"这一状态,我们关心两个量:已消耗的总体力当前甜度等级。若存在两个状态 ,满足 ,则 支配,永远不会比 更优。

因此,对每个 ,只需维护一组 Pareto 最优的状态:按体力升序排列,甜度严格递增。任何被支配的状态都可以丢弃。

转移过程

按门槛从小到大处理每个蛋糕 (门槛 ,增益 ,消耗 )。从大到小枚举 (和 0-1 背包一样逆序枚举,防止同一蛋糕被选多次):

  • 遍历 中的每个 Pareto 状态
  • (满足门槛)且 (体力足够),产生新状态
  • 用二分查找将新状态插入 的 Pareto 前沿,并剔除被它支配的旧状态。

提取答案

遍历所有 ),取每个前沿中最大的甜度值(即最后一个元素)。全局最大甜度即为答案的第一个数;若多个 达到相同最大甜度,取最小的 作为第二个数。

样例演示

5 种蛋糕,初始甜度 0,体力 100,最多做 3 个。按门槛排序后穷举所有组合,最优方案是选第 1、3、4 种蛋糕(门槛 0、5、8),消耗体力 ,甜度变化 ,输出 19 3

代码

#include <bits/stdc++.h>
using namespace std;
int main(){
    int n;
    long long s, p, k;
    scanf("%d%lld%lld%lld", &n, &s, &p, &k);
    vector<long long> x(n), y(n), c(n);
    for(int i=0;i<n;i++) scanf("%lld",&x[i]);
    for(int i=0;i<n;i++) scanf("%lld",&y[i]);
    for(int i=0;i<n;i++) scanf("%lld",&c[i]);

    vector<int> idx(n);
    iota(idx.begin(), idx.end(), 0);
    sort(idx.begin(), idx.end(), [&](int a, int b){ return x[a]<x[b]; });

    int K = (int)min((long long)n, k);

    // dp[j] = Pareto front of (cost, sweetness), sorted by cost asc, sweetness strictly increasing
    vector<vector<pair<long long,long long>>> dp(K+1);
    dp[0].push_back({0, s});

    for(int ii=0; ii<n; ii++){
        int i = idx[ii];
        long long xi=x[i], yi=y[i], ci=c[i];
        for(int j=min(ii+1,K); j>=1; j--){
            auto& prev = dp[j-1];
            auto& cur = dp[j];
            for(int pi=0; pi<(int)prev.size(); pi++){
                long long co = prev[pi].first, sw = prev[pi].second;
                if(sw >= xi && co+ci <= p){
                    long long nc = co+ci, ns = sw+yi;
                    auto it = upper_bound(cur.begin(), cur.end(), make_pair(nc, LLONG_MAX));
                    if(it != cur.begin() && (it-1)->second >= ns) continue;
                    while(it != cur.end() && it->second <= ns)
                        it = cur.erase(it);
                    cur.insert(it, {nc, ns});
                }
            }
        }
    }

    long long ans_sweet = s;
    int ans_count = 0;
    for(int j=1; j<=K; j++){
        if(!dp[j].empty()){
            long long sw = dp[j].back().second;
            if(sw > ans_sweet || (sw == ans_sweet && j < ans_count)){
                ans_sweet = sw;
                ans_count = j;
            }
        }
    }
    printf("%lld %d\n", ans_sweet, ans_count);
}
import java.util.*;
import java.io.*;

public class Main {
    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        int n = Integer.parseInt(st.nextToken());
        long s = Long.parseLong(st.nextToken());
        long p = Long.parseLong(st.nextToken());
        long k = Long.parseLong(st.nextToken());

        long[] x = new long[n], y = new long[n], c = new long[n];
        st = new StringTokenizer(br.readLine());
        for (int i = 0; i < n; i++) x[i] = Long.parseLong(st.nextToken());
        st = new StringTokenizer(br.readLine());
        for (int i = 0; i < n; i++) y[i] = Long.parseLong(st.nextToken());
        st = new StringTokenizer(br.readLine());
        for (int i = 0; i < n; i++) c[i] = Long.parseLong(st.nextToken());

        Integer[] idx = new Integer[n];
        for (int i = 0; i < n; i++) idx[i] = i;
        Arrays.sort(idx, (a, b) -> Long.compare(x[a], x[b]));

        int K = (int) Math.min(n, k);

        @SuppressWarnings("unchecked")
        ArrayList<long[]>[] dp = new ArrayList[K + 1];
        for (int j = 0; j <= K; j++) dp[j] = new ArrayList<>();
        dp[0].add(new long[]{0, s});

        for (int ii = 0; ii < n; ii++) {
            int i = idx[ii];
            long xi = x[i], yi = y[i], ci = c[i];
            for (int j = Math.min(ii + 1, K); j >= 1; j--) {
                ArrayList<long[]> prev = dp[j - 1];
                ArrayList<long[]> cur = dp[j];
                for (int pi = 0; pi < prev.size(); pi++) {
                    long co = prev.get(pi)[0], sw = prev.get(pi)[1];
                    if (sw >= xi && co + ci <= p) {
                        long nc = co + ci, ns = sw + yi;
                        int pos = upperBound(cur, nc);
                        if (pos > 0 && cur.get(pos - 1)[1] >= ns) continue;
                        int end = pos;
                        while (end < cur.size() && cur.get(end)[1] <= ns) end++;
                        cur.subList(pos, end).clear();
                        cur.add(pos, new long[]{nc, ns});
                    }
                }
            }
        }

        long ansSw = s;
        int ansCnt = 0;
        for (int j = 1; j <= K; j++) {
            if (!dp[j].isEmpty()) {
                long sw = dp[j].get(dp[j].size() - 1)[1];
                if (sw > ansSw || (sw == ansSw && j < ansCnt)) {
                    ansSw = sw;
                    ansCnt = j;
                }
            }
        }
        System.out.println(ansSw + " " + ansCnt);
    }

    static int upperBound(ArrayList<long[]> list, long cost) {
        int lo = 0, hi = list.size();
        while (lo < hi) {
            int mid = (lo + hi) / 2;
            if (list.get(mid)[0] <= cost) lo = mid + 1;
            else hi = mid;
        }
        return lo;
    }
}
import sys
from bisect import bisect_right

def main():
    input_data = sys.stdin.buffer.read().split()
    ptr = 0
    n = int(input_data[ptr]); ptr+=1
    s = int(input_data[ptr]); ptr+=1
    p = int(input_data[ptr]); ptr+=1
    k = int(input_data[ptr]); ptr+=1

    x = [int(input_data[ptr+i]) for i in range(n)]; ptr+=n
    y = [int(input_data[ptr+i]) for i in range(n)]; ptr+=n
    c = [int(input_data[ptr+i]) for i in range(n)]; ptr+=n

    idx = sorted(range(n), key=lambda i: x[i])
    K = min(n, k)

    # dp[j] = Pareto front of (cost, sweetness), sorted by cost asc, sweetness strictly increasing
    dp = [[] for _ in range(K+1)]
    dp[0] = [(0, s)]

    for ii in range(n):
        i = idx[ii]
        xi, yi, ci = x[i], y[i], c[i]
        for j in range(min(ii+1, K), 0, -1):
            prev = dp[j-1]
            cur = dp[j]
            new_points = []
            for co, sw in prev:
                if sw >= xi and co + ci <= p:
                    new_points.append((co+ci, sw+yi))
            for nc, ns in new_points:
                pos = bisect_right(cur, (nc, float('inf')))
                if pos > 0 and cur[pos-1][1] >= ns:
                    continue
                end = pos
                while end < len(cur) and cur[end][1] <= ns:
                    end += 1
                cur[pos:end] = [(nc, ns)]

    ans_sw = s
    ans_cnt = 0
    for j in range(1, K+1):
        if dp[j]:
            sw = dp[j][-1][1]
            if sw > ans_sw or (sw == ans_sw and j < ans_cnt):
                ans_sw = sw
                ans_cnt = j

    print(ans_sw, ans_cnt)

main()
const readline = require('readline');
const rl = readline.createInterface({ input: process.stdin });
const lines = [];
rl.on('line', l => lines.push(l.trim()));
rl.on('close', () => {
    const [n, s, p, k] = lines[0].split(' ').map(Number);
    const x = lines[1].split(' ').map(Number);
    const y = lines[2].split(' ').map(Number);
    const c = lines[3].split(' ').map(Number);

    const idx = Array.from({length: n}, (_, i) => i);
    idx.sort((a, b) => x[a] - x[b]);

    const K = Math.min(n, k);

    const dp = Array.from({length: K + 1}, () => []);
    dp[0].push([0, s]);

    function upperBound(arr, cost) {
        let lo = 0, hi = arr.length;
        while (lo < hi) {
            const mid = (lo + hi) >> 1;
            if (arr[mid][0] <= cost) lo = mid + 1;
            else hi = mid;
        }
        return lo;
    }

    for (let ii = 0; ii < n; ii++) {
        const i = idx[ii];
        const xi = x[i], yi = y[i], ci = c[i];
        for (let j = Math.min(ii + 1, K); j >= 1; j--) {
            const prev = dp[j - 1];
            const cur = dp[j];
            for (let pi = 0; pi < prev.length; pi++) {
                const co = prev[pi][0], sw = prev[pi][1];
                if (sw >= xi && co + ci <= p) {
                    const nc = co + ci, ns = sw + yi;
                    let pos = upperBound(cur, nc);
                    if (pos > 0 && cur[pos - 1][1] >= ns) continue;
                    let end = pos;
                    while (end < cur.length && cur[end][1] <= ns) end++;
                    cur.splice(pos, end - pos, [nc, ns]);
                }
            }
        }
    }

    let ansSw = s, ansCnt = 0;
    for (let j = 1; j <= K; j++) {
        if (dp[j].length > 0) {
            const sw = dp[j][dp[j].length - 1][1];
            if (sw > ansSw || (sw === ansSw && j < ansCnt)) {
                ansSw = sw;
                ansCnt = j;
            }
        }
    }
    console.log(ansSw + ' ' + ansCnt);
});

复杂度分析

为蛋糕数量,

  • 时间复杂度,其中 是 Pareto 前沿的平均大小。由于支配关系的剪枝,前沿通常远小于理论上界,实际运行效率很高。
  • 空间复杂度,存储每个 对应的 Pareto 前沿。