题目链接

礼品派发

题目描述

给定一个包含n个礼品价格的数组 prices 和一个粉丝人数 k。需要判断是否能将所有礼品全部分配给 k 位粉丝,并使得每位粉丝获得的礼品总价值完全相等。

思路分析

这是一个经典的K个相等子集划分问题,属于NP难问题,通常使用回溯搜索(DFS)算法来解决。

首先,我们可以进行一些基本的可行性判断来提前排除无解的情况:

  1. 计算所有礼品的总价值 total_sum。如果 total_sum 不能被粉丝数 k 整除,那么无论如何都不可能公平分配,可以直接返回 false
  2. 如果可以整除,我们可以计算出每位粉丝应该获得的礼品总价值 target = total_sum / k

之后,问题就转化为:能否从 prices 数组中找到 k 个不相交的子集,使得每个子集的和都等于 target

我们采用深度优先搜索(DFS)的策略来寻找这样一种划分。基本思路是尝试为 k 位粉丝逐一凑齐价值为 target 的礼品包。为了避免因搜索空间过大而超时,我们需要进行剪枝优化:

  1. 对价格排序:将 prices 数组从大到小排序。这是一个非常有效的剪枝策略。在搜索时优先尝试分配价值较大的礼品,如果发现某个大礼品无法被分配,就可以更早地发现此路不通并回溯,从而避免大量无效的搜索。
  2. 提前剪枝:排序后,如果最大的礼品价格 prices[0] 就已经大于 target,那么显然不可能完成分配,可以直接返回 false
  3. 去重剪枝:如果在搜索的某一层,我们尝试了 prices[i] 发现无法成功,并且 prices[i+1] 的价格与 prices[i] 相同,那么我们就可以跳过 prices[i+1],因为它必然也会导致同样失败的结果,从而避免重复搜索。

我们的递归函数 dfs(k, current_sum, start_index) 的含义是:尝试在尚未分配的礼品中,为当前粉丝凑齐一个礼品包。

  • k 表示还剩下多少位粉丝需要分配。
  • current_sum 是当前粉丝已选礼品的总价。
  • start_index 是本次从哪个礼品开始尝试。

current_sum 达到 target 时,说明一个粉丝的礼品包已凑齐。然后我们开始为下一个粉丝凑包,即调用 dfs(k-1, 0, 0)。当 k 最终减为0时,说明所有粉丝都分配完毕,返回 true

代码

#include <vector>
#include <numeric>
#include <algorithm>
#include <functional>

using namespace std;

class Solution {
public:
    bool canEqualDistribution(vector<int>& prices, int k) {
        int sum = accumulate(prices.begin(), prices.end(), 0);
        if (sum % k != 0) {
            return false;
        }
        int target = sum / k;
        
        sort(prices.begin(), prices.end(), greater<int>());
        if (prices[0] > target) {
            return false;
        }
        
        vector<bool> used(prices.size(), false);
        return dfs(prices, used, k, 0, 0, target);
    }

private:
    bool dfs(const vector<int>& prices, vector<bool>& used, int k, int current_sum, int start_index, int target) {
        if (k == 0) {
            return true;
        }
        if (current_sum == target) {
            return dfs(prices, used, k - 1, 0, 0, target);
        }

        for (int i = start_index; i < prices.size(); ++i) {
            if (used[i]) {
                continue;
            }
            if (i > 0 && prices[i] == prices[i - 1] && !used[i - 1]) {
                continue;
            }
            if (current_sum + prices[i] <= target) {
                used[i] = true;
                if (dfs(prices, used, k, current_sum + prices[i], i + 1, target)) {
                    return true;
                }
                used[i] = false;
            }
        }
        return false;
    }
};
import java.util.Arrays;
import java.util.Collections;
import java.util.stream.IntStream;

class Solution {
    public boolean canEqualDistribution(int[] prices, int k) {
        int sum = IntStream.of(prices).sum();
        if (sum % k != 0) {
            return false;
        }
        int target = sum / k;

        // Java primitive array needs boxing for reverse order sort
        Integer[] prices_Integer = Arrays.stream(prices).boxed().toArray(Integer[]::new);
        Arrays.sort(prices_Integer, Collections.reverseOrder());
        
        if (prices_Integer[0] > target) {
            return false;
        }

        boolean[] used = new boolean[prices.length];
        return dfs(prices_Integer, used, k, 0, 0, target);
    }

    private boolean dfs(Integer[] prices, boolean[] used, int k, int currentSum, int startIndex, int target) {
        if (k == 0) {
            return true;
        }
        if (currentSum == target) {
            return dfs(prices, used, k - 1, 0, 0, target);
        }

        for (int i = startIndex; i < prices.length; i++) {
            if (used[i]) {
                continue;
            }
            if (i > 0 && prices[i].equals(prices[i - 1]) && !used[i - 1]) {
                continue;
            }
            if (currentSum + prices[i] <= target) {
                used[i] = true;
                if (dfs(prices, used, k, currentSum + prices[i], i + 1, target)) {
                    return true;
                }
                used[i] = false;
            }
        }
        return false;
    }
}
class Solution:
    def canEqualDistribution(self, prices, k):
        total_sum = sum(prices)
        if total_sum % k != 0:
            return False
        target = total_sum // k
        
        prices.sort(reverse=True)
        if prices[0] > target:
            return False
            
        used = [False] * len(prices)

        def dfs(k_rem, current_sum, start_index):
            if k_rem == 0:
                return True
            if current_sum == target:
                return dfs(k_rem - 1, 0, 0)
            
            for i in range(start_index, len(prices)):
                if used[i]:
                    continue
                # 去重剪枝
                if i > 0 and prices[i] == prices[i-1] and not used[i-1]:
                    continue
                if current_sum + prices[i] <= target:
                    used[i] = True
                    if dfs(k_rem, current_sum + prices[i], i + 1):
                        return True
                    used[i] = False
            return False

        return dfs(k, 0, 0)

算法及复杂度

  • 算法:回溯搜索 (DFS) + 剪枝
  • 时间复杂度:难以精确分析,最坏情况下为指数级 。但由于大量的剪枝操作,实际运行效率会远高于理论最坏情况,足以通过通常较弱的测试数据。
  • 空间复杂度,主要由递归栈的深度和 used 数组所占用。