题解:BISHI40 数组取精

题目链接

数组取精

题目描述

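给定两个长度为 $n$ 的正整数序列 $\{a_i\}$ 与 $\{b_i\}$。定义子集 $P=\{p_1,p_2,\dots,p_k\}$($1\le p_i\le n$,元素互异)为“精华子集”,当且仅当同时满足:

$$k\le\left\lfloor\frac{n}{2}\right\rfloor+1,\qquad 2\sum_{j=1}^{k}a_{p_j}>\sum_{i=1}^{n}a_i,\qquad 2\sum_{j=1}^{k}b_{p_j}>\sum_{i=1}^{n}b_i .$$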

请构造任意一个满足条件的精华子集并输出其大小及下标。

解题思路

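分块配对 + 贪心:

  • 将下标按 $a_i$ 降序排序为 $p_1,p_2,\dots,p_n$,先选入 $p_1$,再将其余下标两两配成区间对 $(p_2,p_3),(p_4,p_5),\dots$;
  • 记选中集合为 $P$,初始 $P=\{p_1\}$,维护当前选中和;
  • 依次处理每一对 $(p_{2j},p_{2j+1})$:选入其中 $b$ 值较大者(相同则任取其一);
  • 若 $n-1$ 为奇数(即 $n$ 为偶数),末尾还会有一个单点 $p_n$,直接加入。

这样在每个配对块上都做出对两维半和同时有利的选择,最终得到 $|P|=\lfloor n/2\rfloor+1$ 的集合,且两维均严格大于全体的一半:

  • 维度 $b$:每一对中选中者的 $b$ 不小于未选者。另外,额外选入的 $p_1$(以及可能的单点 $p_n$)均为正数。故 $2\sum_{i\in P}b_i>\sum_{i=1}^{n}b_i$。
  • 维度 $a$:由排序可知,$p_1$ 的 $a$ 不小于第 1 对中的未选者,第 $j$ 对中选中者的 $a$ 不小于第 $j+1$ 对中的未选者。把每个未选者与这样的选中者一一对应后,$P$ 中至少还剩一个元素未被用到,其值为正。故 $2\sum_{i\in P}a_i>\sum_{i=1}^{n}a_i$。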

代码

#include <bits/stdc++.h>
using namespace std;

// Builds an "essence" subset for the 1-indexed positive sequences a[1..n] and
// b[1..n]: at most n/2 + 1 distinct indices whose a-sum and b-sum each exceed
// half of the respective totals.
//
// Construction: sort indices by a descending and always take the first one.
// Then take the member with the larger b from each consecutive pair
// (2nd,3rd), (4th,5th), ... When n is even, one unpaired index remains at the
// end and is taken too.
//
// Why it works:
// - In b, every pair contributes at least as much as it leaves behind, plus
//   the extra positive first index.
// - In a, each skipped index is dominated by the chosen index just before its
//   pair in sorted order, and at least one chosen index is left over.
static vector<int> chooseSubset(const vector<long long>& a, const vector<long long>& b, int n) {
    vector<int> sel;
    if (n <= 0) return sel;

    vector<int> ord(n);
    iota(ord.begin(), ord.end(), 1);
    sort(ord.begin(), ord.end(), [&](int x, int y) { return a[x] > a[y]; });

    sel.reserve(n / 2 + 1);
    sel.push_back(ord[0]);
    for (int j = 1; j < n; j += 2) {
        if (j + 1 < n) {
            int u = ord[j], v = ord[j + 1];
            sel.push_back(b[u] >= b[v] ? u : v);
        } else {
            sel.push_back(ord[j]);  // lone trailing index (n even)
        }
    }
    return sel;
}

// Reads n, a[1..n], b[1..n] and prints the subset size followed by the chosen
// 1-based indices on one line. Empty input produces no output.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n;
    if (!(cin >> n)) return 0;
    vector<long long> a(n + 1), b(n + 1);
    for (int i = 1; i <= n; ++i) cin >> a[i];
    for (int i = 1; i <= n; ++i) cin >> b[i];

    vector<int> sel = chooseSubset(a, b, n);

    cout << (int)sel.size() << '\n';
    for (int i = 0; i < (int)sel.size(); ++i) {
        if (i) cout << ' ';
        cout << sel[i];
    }
    cout << '\n';
    return 0;
}
import java.io.*;
import java.util.*;

public class Main {
    /**
     * Minimal whitespace-delimited token reader over a raw byte stream.
     *
     * <p>Any byte whose sign-extended value is {@code <= 32} is treated as a separator. This covers
     * ASCII control/space characters and every byte {@code >= 0x80}.
     */
    static class FastScanner {
        /** Returned by {@link #read()} at end of stream; distinct from every byte value. */
        private static final int EOF = Integer.MIN_VALUE;

        private final InputStream in;
        private final byte[] buffer = new byte[1 << 16];
        private int ptr = 0;
        private int len = 0;

        FastScanner(InputStream is) {
            this.in = is;
        }

        /** Returns the next byte, sign-extended, or {@link #EOF} once the stream is exhausted. */
        private int read() throws IOException {
            if (ptr >= len) {
                len = in.read(buffer);
                ptr = 0;
                if (len <= 0) {
                    return EOF;
                }
            }
            return buffer[ptr++];
        }

        /**
         * Parses the next decimal integer, allowing one leading {@code '-'}.
         *
         * @return the parsed value
         * @throws EOFException if the stream ends before another token begins, instead of
         *     spinning forever on truncated input
         * @throws IOException if the underlying stream fails
         */
        long nextLong() throws IOException {
            int c;
            do {
                c = read();
                if (c == EOF) {
                    throw new EOFException("unexpected end of input");
                }
            } while (c <= 32);
            long sgn = 1;
            long x = 0;
            if (c == '-') {
                sgn = -1;
                c = read();
            }
            while (c > 32) { // EOF (Integer.MIN_VALUE) also terminates the token
                x = x * 10 + (c - '0');
                c = read();
            }
            return x * sgn;
        }

        /** Parses the next token as a {@code long} and narrows it to {@code int}. */
        int nextInt() throws IOException {
            return (int) nextLong();
        }
    }

    /**
     * Reads {@code n}, then {@code a[1..n]} and {@code b[1..n]}. Prints the size of an "essence"
     * subset, then its 1-based indices on one line. Empty input produces no output.
     */
    public static void main(String[] args) throws Exception {
        FastScanner fs = new FastScanner(System.in);
        int n;
        try {
            n = fs.nextInt();
        } catch (EOFException ignored) {
            return; // empty input: nothing to answer
        }
        long[] a = new long[n + 1];
        long[] b = new long[n + 1];
        for (int i = 1; i <= n; i++) a[i] = fs.nextLong();
        for (int i = 1; i <= n; i++) b[i] = fs.nextLong();

        List<Integer> sel = chooseSubset(a, b, n);

        StringBuilder out = new StringBuilder();
        out.append(sel.size()).append('\n');
        for (int i = 0; i < sel.size(); i++) {
            if (i > 0) out.append(' ');
            out.append(sel.get(i));
        }
        out.append('\n');
        System.out.print(out.toString());
    }

    /**
     * Chooses at most {@code n / 2 + 1} distinct indices from {@code 1..n}. Their {@code a}-sum
     * and {@code b}-sum are each strictly greater than half of the corresponding total, assuming
     * all values are positive.
     *
     * <p>Construction:
     * <ul>
     *   <li>Sort indices by {@code a} descending and always take the first one.
     *   <li>Pair the rest as (2nd, 3rd), (4th, 5th), ... and take the member with the larger
     *       {@code b} from each pair.
     *   <li>When {@code n} is even, one index is left over at the end; take it too.
     * </ul>
     *
     * <p>Why it works:
     * <ul>
     *   <li>Every pair contributes at least as much {@code b} as it leaves behind.
     *   <li>In {@code a}, each skipped index is dominated by the chosen index just before its
     *       pair in sorted order.
     * </ul>
     * Both sums therefore end up strictly above half, with at most {@code n / 2 + 1} picks.
     *
     * @param a values {@code a[1..n]} ({@code a[0]} unused)
     * @param b values {@code b[1..n]} ({@code b[0]} unused)
     * @param n number of elements
     * @return the chosen 1-based indices, in selection order (empty when {@code n <= 0})
     */
    private static List<Integer> chooseSubset(long[] a, long[] b, int n) {
        List<Integer> sel = new ArrayList<>();
        if (n <= 0) {
            return sel;
        }
        Integer[] ord = new Integer[n];
        for (int i = 0; i < n; i++) ord[i] = i + 1;
        Arrays.sort(ord, (x, y) -> Long.compare(a[y], a[x]));

        sel.add(ord[0]);
        for (int j = 1; j < n; j += 2) {
            if (j + 1 < n) {
                int u = ord[j];
                int v = ord[j + 1];
                sel.add(b[u] >= b[v] ? u : v);
            } else {
                sel.add(ord[j]); // lone trailing index (n even)
            }
        }
        return sel;
    }
}
import sys

def choose_subset(a, b):
    """Pick an "essence" subset for the 1-indexed positive sequences a and b.

    a[0] and b[0] are unused placeholders, so n = len(a) - 1.  Returns a list
    of at most n // 2 + 1 distinct indices P such that
    2 * sum(a[P]) > sum(a[1:]) and 2 * sum(b[P]) > sum(b[1:]).

    Construction: sort indices by a descending and always take the first one.
    Then take the member with the larger b from each consecutive pair
    (2nd, 3rd), (4th, 5th), ...  When n is even, a lone trailing index
    remains and is taken as well.

    Why it works: every pair keeps at least as much b as it drops.  In a,
    each dropped index is dominated by the chosen index just before its pair
    in sorted order.
    """
    n = len(a) - 1
    order = sorted(range(1, n + 1), key=lambda i: a[i], reverse=True)
    if not order:
        return []
    sel = [order[0]]
    for j in range(1, n, 2):
        if j + 1 < n:
            u, v = order[j], order[j + 1]
            sel.append(u if b[u] >= b[v] else v)
        else:
            sel.append(order[j])  # lone trailing index (n even)
    return sel


def main():
    """Read n, a[1..n], b[1..n] from stdin; print k and the chosen indices.

    Empty input produces no output.
    """
    data = sys.stdin.buffer.read().split()
    if not data:
        return
    it = iter(data)
    n = int(next(it))
    a = [0] * (n + 1)
    b = [0] * (n + 1)
    for i in range(1, n + 1):
        a[i] = int(next(it))
    for i in range(1, n + 1):
        b[i] = int(next(it))
    sel = choose_subset(a, b)
    print(len(sel))
    print(' '.join(map(str, sel)))


if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:按 $a_i$ 降序排序 + 两两配对贪心选取
  • 时间复杂度:$O(n\log n)$(排序),输出 $O(k)$
  • 空间复杂度:$O(n)$