题目链接

拼接木棍

题目描述

给定 根小木棍的长度,这些小木棍是由若干根等长的大木棍截断而来的。你需要找出这些大木棍原始的、最小的可能长度。

解题思路

  1. 问题转化与搜索空间

    这个问题可以转化为一个搜索问题:我们尝试每一种可能的“原始长度” ,然后验证是否能用所有小木棍恰好拼接成若干根长度为 的大木棍。我们要求的,就是能满足条件的最小的

    的可能取值范围是什么?

    • 下界 必须至少等于最长的那根小木棍的长度。
    • 上界:在最极端的情况下,所有小木棍都来自同一根大木棍,所以 不会超过所有小木棍的总长度。
    • 约束:所有小木棍的总长度 total_sum 必须能被 整除。

    因此,我们的整体算法是:

    1. 计算所有小木棍的总长度 total_sum 和最长小木棍的长度 max_len
    2. max_len 开始,向上枚举到 total_sum,找到第一个满足 total_sum % L == 0 并且可以通过验证的 ,这个 就是答案。
  2. 验证函数:深度优先搜索 (DFS) + 剪枝

    验证的核心在于,给定一个目标长度 和所有小木棍,判断是否能将小木棍不重不漏地分成 total_sum / L 组,每组的和都恰好为 。这是一个经典的组合搜索问题,可以通过深度优先搜索(DFS)结合回溯来解决。

    然而,朴素的DFS会因为巨大的搜索空间而超时。为了通过本题,必须进行强有力的剪枝来优化搜索过程。

  3. 核心剪枝策略

    1. 降序排序:将所有小木棍按长度从大到小排序。在搜索时,优先尝试用较长的小木棍去拼接。这样做的好处是,如果一个组合不成立,可以更快地“撞南墙”,从而尽早回溯,大幅减少无效搜索。

    2. 跳过重复元素:在搜索的同一层,如果尝试了小木棍 sticks[i] 失败了,那么就应该跳过所有与 sticks[i] 长度相同的后续小木棍。因为使用这些等长的小木棍会进入完全一样的子问题,必然也会失败。

    3. 新木棍的“第一根”剪枝:当我们开始拼接一根新的大木棍时(即当前拼接长度为0),如果尝试的第一根小木棍 sticks[i] 最终无法构成一个成功的方案(即以 sticks[i] 开头的这根大木棍拼好后,剩下的木棍无法成功分组),那么整个搜索可以直接宣告失败。

      • 原因:由于木棍已降序排序,sticks[i] 是当前可用的最长木棍。如果连它都无法成为任何一个成功方案的一部分,那么换成任何更短的木棍 sticks[j] (j>i) 来开头,只会让剩下的木棍组合变得更“困难”(因为留下了更长的 sticks[i]),同样不可能成功。
    4. 当前木棍的“最后一根”剪枝:当我们为当前正在拼接的大木棍放入一根小木棍 sticks[i] 后,恰好使其长度达到了目标 ,此时我们转入对剩余小木棍的递归搜索。如果这个后续的搜索失败了,那么本次尝试也直接宣告失败。

      • 原因:如果用 sticks[i] 作为“最后一根”木棍完成当前大木棍的拼接,而剩下的木棍无法成功分组,那么换用任何其他更短的木棍 sticks[j] (j>i) 作为“最后一根”来完成拼接,虽然也能凑成长度 ,但却会把更长的 sticks[i] 留给后续,使得后续拼接问题变得“更容易”。如果连“更容易”的情况都无法成功,那么我们最初用 sticks[i] 结尾导致失败的情况就更不可能有解了。因此,可以直接断定当前分支失败。

代码

#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>
#include <functional>

using namespace std;

int n;
vector<int> sticks;
vector<bool> visited;
int target_len;
int num_groups;

// count: 已拼好的大木棍数量
// current_sum: 当前正在拼的木棍的长度
// start_idx: 从哪根小木棍开始尝试
bool dfs(int count, int current_sum, int start_idx) {
    if (count == num_groups) {
        return true;
    }

    if (current_sum == target_len) {
        return dfs(count + 1, 0, 0);
    }

    for (int i = start_idx; i < n; ++i) {
        if (visited[i]) {
            continue;
        }
        if (current_sum + sticks[i] > target_len) {
            continue;
        }

        visited[i] = true;
        if (dfs(count, current_sum + sticks[i], i + 1)) {
            return true;
        }
        visited[i] = false;

        // 剪枝3: 新木棍的“第一根”
        if (current_sum == 0) {
            return false;
        }
        // 剪枝4: 当前木棍的“最后一根”
        if (current_sum + sticks[i] == target_len) {
            return false;
        }
        // 剪枝2: 跳过重复
        while (i + 1 < n && sticks[i] == sticks[i + 1]) {
            i++;
        }
    }

    return false;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    while (cin >> n && n != 0) {
        sticks.resize(n);
        int total_sum = 0;
        int max_len = 0;
        for (int i = 0; i < n; ++i) {
            cin >> sticks[i];
            total_sum += sticks[i];
            max_len = max(max_len, sticks[i]);
        }

        sort(sticks.begin(), sticks.end(), greater<int>());

        for (int len = max_len; len <= total_sum; ++len) {
            if (total_sum % len == 0) {
                target_len = len;
                num_groups = total_sum / len;
                visited.assign(n, false);
                if (dfs(0, 0, 0)) {
                    cout << len << endl;
                    break;
                }
            }
        }
    }

    return 0;
}
import java.util.Scanner;
import java.util.Arrays;
import java.util.Collections;

public class Main {
    private static int n;
    private static Integer[] sticks;
    private static boolean[] visited;
    private static int targetLen;
    private static int numGroups;

    private static boolean dfs(int count, int currentSum, int startIdx) {
        if (count == numGroups) {
            return true;
        }
        if (currentSum == targetLen) {
            return dfs(count + 1, 0, 0);
        }

        for (int i = startIdx; i < n; i++) {
            if (visited[i]) {
                continue;
            }
            if (currentSum + sticks[i] > targetLen) {
                continue;
            }

            visited[i] = true;
            if (dfs(count, currentSum + sticks[i], i + 1)) {
                return true;
            }
            visited[i] = false;

            if (currentSum == 0) return false;
            if (currentSum + sticks[i] == targetLen) return false;
            
            while (i + 1 < n && sticks[i].equals(sticks[i + 1])) {
                i++;
            }
        }
        return false;
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        while (sc.hasNextInt()) {
            n = sc.nextInt();
            if (n == 0) break;

            sticks = new Integer[n];
            int totalSum = 0;
            int maxLen = 0;
            for (int i = 0; i < n; i++) {
                sticks[i] = sc.nextInt();
                totalSum += sticks[i];
                maxLen = Math.max(maxLen, sticks[i]);
            }

            Arrays.sort(sticks, Collections.reverseOrder());

            for (int len = maxLen; len <= totalSum; len++) {
                if (totalSum % len == 0) {
                    targetLen = len;
                    numGroups = totalSum / len;
                    visited = new boolean[n];
                    if (dfs(0, 0, 0)) {
                        System.out.println(len);
                        break;
                    }
                }
            }
        }
    }
}
import sys

def dfs(count, current_sum, start_idx):
    if count == num_groups:
        return True
    
    if current_sum == target_len:
        return dfs(count + 1, 0, 0)

    i = start_idx
    while i < n:
        if visited[i]:
            i += 1
            continue
        if current_sum + sticks[i] > target_len:
            i += 1
            continue

        visited[i] = True
        if dfs(count, current_sum + sticks[i], i + 1):
            return True
        visited[i] = False

        if current_sum == 0:
            return False
        if current_sum + sticks[i] == target_len:
            return False
            
        # 跳过重复
        temp = sticks[i]
        while i < n and sticks[i] == temp:
            i += 1
    
    return False

def solve():
    global n, sticks, visited, target_len, num_groups
    
    lines = sys.stdin.readlines()
    input_idx = 0
    while input_idx < len(lines):
        n = int(lines[input_idx].strip())
        if n == 0:
            break
        input_idx += 1
        
        sticks = list(map(int, lines[input_idx].strip().split()))
        input_idx += 1
        
        total_sum = sum(sticks)
        max_len = max(sticks)
        
        sticks.sort(reverse=True)
        
        for length in range(max_len, total_sum + 1):
            if total_sum % length == 0:
                target_len = length
                num_groups = total_sum // length
                visited = [False] * n
                if dfs(0, 0, 0):
                    print(length)
                    break

# sys.setrecursionlimit(1000) # 可能需要根据题目调整
solve()

算法及复杂度

  • 算法:枚举 + 深度优先搜索 (DFS) + 剪枝
  • 时间复杂度:难以精确分析,因为剪枝的效果非常显著。在最坏情况下,搜索是指数级的,但由于强大的剪枝,对于此类问题的典型数据,该算法是高效的。
  • 空间复杂度,主要由递归栈的深度和小木棍数组 sticksvisited 占用。