题解:BISHI91 拼接木棍

题目链接

拼接木棍

题目描述

给定 段小木棍长度,原始有若干根同长大木棍,将其切割后得到这些小木棍。求原始大木棍的最小可能长度。

解题思路

设所有小木棍长度总和为 ,最长小木棍为 。原始大木棍长度 必须满足 。按 从小到大(即从 )枚举所有能整除 的长度,检验能否将所有小木棍恰好分成 组,每组长度和为

检验过程使用回溯搜索(经典“拼木棍/还原木棍”问题):

  • 将小木棍按长度从大到小排序,优先放长棍减少分支。
  • 递归尝试把当前未完成的一根大木棍(剩余长度为 )用若干小木棍拼满;若拼满则开始下一根。
  • 剪枝:
    • 同层跳过等长小木棍的重复尝试(避免等价状态)。
    • 若当前是新开一根大木棍()且放入某根小木棍失败,则直接回溯(因为此根作为首根都失败了)。
    • 若某根小木棍恰好等于 且失败,也可直接回溯(后续更短棍也无解)。

当某个 通过检验即为答案的最小值。

代码

#include <bits/stdc++.h>
using namespace std;
using i64 = long long;

int n;
vector<int> a;
vector<char> used;
int targetL, groups;

bool dfs(int built, int rem, int start) {
    if (built == groups) return true;
    if (rem == 0) return dfs(built + 1, targetL, 0);
    int prev = -1;
    for (int i = start; i < n; ++i) {
        if (used[i]) continue;
        int x = a[i];
        if (x == prev) continue;           // 同层去重
        if (x > rem) continue;
        used[i] = 1;
        if (dfs(built, rem - x, i + 1)) return true;
        used[i] = 0;
        prev = x;
        if (rem == targetL) return false;  // 新棍首段失败,剪枝
        if (x == rem) return false;        // 恰好填满仍失败,剪枝
    }
    return false;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    if (!(cin >> n)) return 0;
    a.resize(n);
    i64 S = 0; int Lmax = 0;
    for (int i = 0; i < n; ++i) { cin >> a[i]; S += a[i]; Lmax = max(Lmax, a[i]); }
    sort(a.begin(), a.end(), greater<int>());

    for (int L = Lmax; L <= S; ++L) {
        if (S % L != 0) continue;
        targetL = L; groups = (int)(S / L);
        used.assign(n, 0);
        if (dfs(0, targetL, 0)) { cout << L << '\n'; return 0; }
    }
    return 0;
}
import java.io.*;
import java.util.*;

public class Main {
    static class FastScanner {
        private final InputStream in; private final byte[] buf = new byte[1<<16];
        private int p=0,l=0; FastScanner(InputStream is){in=is;}
        private int read() throws IOException { if (p>=l){ l=in.read(buf); p=0; if (l<=0) return -1;} return buf[p++]; }
        int nextInt() throws IOException { int c; int s=1, x=0; do{c=read();}while(c<=32); if(c=='-'){s=-1;c=read();} while(c>32){ x = x*10 + (c-'0'); c=read(); } return x*s; }
    }

    static int n, targetL, groups;
    static int[] a;
    static boolean[] used;

    static boolean dfs(int built, int rem, int start){
        if (built == groups) return true;
        if (rem == 0) return dfs(built+1, targetL, 0);
        int prev = -1;
        for (int i = start; i < n; i++){
            if (used[i]) continue;
            int x = a[i];
            if (x == prev) continue;
            if (x > rem) continue;
            used[i] = true;
            if (dfs(built, rem - x, i + 1)) return true;
            used[i] = false;
            prev = x;
            if (rem == targetL) return false;
            if (x == rem) return false;
        }
        return false;
    }

    public static void main(String[] args) throws Exception {
        FastScanner fs = new FastScanner(System.in);
        n = fs.nextInt();
        a = new int[n];
        long S = 0; int Lmax = 0;
        for (int i = 0; i < n; i++){ a[i] = fs.nextInt(); S += a[i]; Lmax = Math.max(Lmax, a[i]); }
        Integer[] ord = new Integer[n];
        for (int i = 0; i < n; i++) ord[i] = i;
        Arrays.sort(ord, (i, j) -> Integer.compare(a[j], a[i]));
        int[] b = new int[n];
        for (int i = 0; i < n; i++) b[i] = a[ord[i]];
        a = b;
        for (int L = Lmax; L <= S; L++){
            if (S % L != 0) continue;
            targetL = L; groups = (int)(S / L);
            used = new boolean[n];
            if (dfs(0, targetL, 0)) { System.out.println(L); return; }
        }
    }
}
import sys
sys.setrecursionlimit(1 << 20)

data = sys.stdin.buffer.read().split()
it = iter(data)
n = int(next(it))
a = [int(next(it)) for _ in range(n)]
a.sort(reverse=True)

S = sum(a)
Lmax = a[0]
used = [False]*n

def dfs(built: int, rem: int, start: int) -> bool:
    if built == groups:
        return True
    if rem == 0:
        return dfs(built+1, targetL, 0)
    prev = -1
    i = start
    while i < n:
        if not used[i]:
            x = a[i]
            if x != prev and x <= rem:
                used[i] = True
                if dfs(built, rem - x, i + 1):
                    return True
                used[i] = False
                prev = x
                if rem == targetL:
                    return False
                if x == rem:
                    return False
        i += 1
    return False

for L in range(Lmax, S+1):
    if S % L != 0:
        continue
    targetL = L
    groups = S // L
    used = [False]*n
    if dfs(0, targetL, 0):
        print(L)
        break

算法及复杂度

  • 算法:枚举原始长度 (需整除 ),回溯分组并配合剪枝
  • 时间复杂度:最坏情况指数级,剪枝后实际表现优秀;排序
  • 空间复杂度:(标记数组与递归栈)