题目链接

数组取精

题目描述

给定两个长度为 的正整数序列

需要找出一个下标子集 ),该子集被称为“精华子集”,当且仅当同时满足以下两个条件:
a.
b.

你需要找到并输出这样任意一个精华子集。

解题思路

这是一个构造性问题,目标是找到一个满足条件的子集。问题的核心在于,选出的子集 中元素的 之和与 之和,都必须超过各自序列总和的一半。

我们可以采用一种贪心构造的策略。这个策略的核心思想是:首先在一个维度(比如数组 )上取得绝对优势,然后再在另一个维度(数组 )上进行明智的选择,以确保第二个条件也尽可能满足。事实证明,这种策略总是能找到一个解。

算法步骤如下:

  1. 排序:将所有元素的原始下标 根据它们对应的 值进行降序排序。这样,我们就获得了一个下标序列 ,其中

  2. 强制选择最大元:由于 是所有 中的最大值(或之一),它对于满足第一个条件(关于 的和)的贡献最大。因此,我们首先将下标 放入我们的解集 中。

  3. 配对选择:接下来,我们将剩余的排好序的下标从 开始,两两配对:。对于每一对 ,我们比较它们对应的 值,即 。我们选择其中 值较大的那个下标加入解集 。这样做是为了在满足第二个条件(关于 的和)上做出最优的局部选择。

  4. 处理剩余元素

    • 如果 是奇数,那么下标 共有 个(偶数个),它们可以被完美地分成 对。
    • 如果 是偶数,那么 共有 个(奇数个)。配对后会剩下最后一个下标 。由于我们需要构造一个总和超过一半的集合,将这个剩余的元素也加入解集 是一个稳妥的选择。
  5. 输出:最后,将解集 的大小和其中的所有下标(按升序排列)输出即可。

这个构造性算法的巧妙之处在于,它通过在 值上排序来建立一个基本盘,然后通过在配对中选择更优的 值来确保第二个条件也得到满足,从而保证总能构造出一个符合要求的解。

代码

#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>

using namespace std;

struct Node {
    int a, b, id;
};

int main() {
    int n;
    cin >> n;
    vector<Node> nodes(n);
    for (int i = 0; i < n; ++i) {
        cin >> nodes[i].a;
        nodes[i].id = i + 1;
    }
    for (int i = 0; i < n; ++i) {
        cin >> nodes[i].b;
    }

    sort(nodes.begin(), nodes.end(), [](const Node& x, const Node& y) {
        return x.a > y.a;
    });

    vector<int> result;
    result.push_back(nodes[0].id);

    for (int i = 1; i + 1 < n; i += 2) {
        if (nodes[i].b >= nodes[i + 1].b) {
            result.push_back(nodes[i].id);
        } else {
            result.push_back(nodes[i + 1].id);
        }
    }

    if (n % 2 == 0) {
        result.push_back(nodes[n - 1].id);
    }

    sort(result.begin(), result.end());

    cout << result.size() << endl;
    for (int i = 0; i < result.size(); ++i) {
        cout << result[i] << (i == result.size() - 1 ? "" : " ");
    }
    cout << endl;

    return 0;
}
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Scanner;

class Node implements Comparable<Node> {
    int a, b, id;

    public Node(int a, int b, int id) {
        this.a = a;
        this.b = b;
        this.id = id;
    }

    @Override
    public int compareTo(Node other) {
        return Integer.compare(other.a, this.a); // 降序
    }
}

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        List<Node> nodes = new ArrayList<>();
        int[] a = new int[n];
        int[] b = new int[n];

        for (int i = 0; i < n; i++) {
            a[i] = sc.nextInt();
        }
        for (int i = 0; i < n; i++) {
            b[i] = sc.nextInt();
        }

        for (int i = 0; i < n; i++) {
            nodes.add(new Node(a[i], b[i], i + 1));
        }

        Collections.sort(nodes);

        List<Integer> result = new ArrayList<>();
        result.add(nodes.get(0).id);

        for (int i = 1; i + 1 < n; i += 2) {
            if (nodes.get(i).b >= nodes.get(i + 1).b) {
                result.add(nodes.get(i).id);
            } else {
                result.add(nodes.get(i + 1).id);
            }
        }

        if (n % 2 == 0) {
            result.add(nodes.get(n - 1).id);
        }

        Collections.sort(result);

        System.out.println(result.size());
        for (int i = 0; i < result.size(); i++) {
            System.out.print(result.get(i) + (i == result.size() - 1 ? "" : " "));
        }
        System.out.println();
    }
}
import sys

def main():
    n = int(sys.stdin.readline())
    a = list(map(int, sys.stdin.readline().split()))
    b = list(map(int, sys.stdin.readline().split()))

    nodes = []
    for i in range(n):
        nodes.append({'a': a[i], 'b': b[i], 'id': i + 1})

    nodes.sort(key=lambda x: x['a'], reverse=True)

    result = []
    result.append(nodes[0]['id'])

    i = 1
    while i + 1 < n:
        if nodes[i]['b'] >= nodes[i + 1]['b']:
            result.append(nodes[i]['id'])
        else:
            result.append(nodes[i + 1]['id'])
        i += 2
    
    if n % 2 == 0:
        result.append(nodes[n - 1]['id'])
        
    result.sort()

    print(len(result))
    print(*result)

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:贪心、排序
  • 时间复杂度,主要开销在于对包含 个元素的数组进行排序。
  • 空间复杂度,需要额外的空间来存储每个元素的值及其原始下标。