题目链接

混乱的奶牛

题目描述

头奶牛,每头奶牛都有一个唯一的整数编号。如果一个排列中,任意相邻两头奶牛的编号之差的绝对值都严格大于 ,则称这个排列为“混乱队伍”。

给定 头奶牛的编号和整数 ,计算有多少种不同的排列是混乱队伍。

解题思路

这是一个带有特定约束条件的排列计数问题。鉴于奶牛的数量 非常小(),这强烈地暗示了一个指数级复杂度的算法。解决这类问题的经典方法是状态压缩动态规划 (DP on Subsets)

我们可以将这个问题建模为:在一个图中寻找哈密尔顿路径的数量。图的每个顶点代表一头奶牛,如果两头奶牛 的编号之差 abs(c[i] - c[j]) > k,则它们之间存在一条边。问题就转化为,在这个图上,有多少条经过所有 个顶点的简单路径。

算法步骤

  1. 预处理:

    • 首先,读入所有奶牛的编号。对编号进行排序不是必需的,但通常是一个好习惯,可以使问题结构更清晰。
  2. DP 状态定义: 我们定义 dp[mask][i] 为:

    • mask 是一个二进制整数,表示当前已参与排列的奶牛集合。如果 mask 的第 j 位是 1,代表第 j 头奶牛已被使用。
    • i 表示这个排列的最后一头奶牛是第 i 头。
    • dp[mask][i] 的值就是满足上述条件的有效排列(混乱队伍)的数量。
  3. DP 初始化:

    • 对于每头奶牛 i,它自身可以构成一个长度为 1 的排列。
    • 因此,我们初始化 dp[1 << i][i] = 1,表示只包含奶牛 i 的集合,且排列以 i 结尾的方案数为 1。
  4. DP 转移:

    • 我们按 mask 从小到大的顺序进行迭代。
    • 对于一个给定的 mask 和其中的终点 i,如果 dp[mask][i] > 0,说明我们已经有了一个有效的局部排列。
    • 我们可以尝试在这个排列的末尾添加一头新的奶牛 j
    • 遍历所有尚未mask 中的奶牛 j
    • 如果奶牛 j 可以合法地接在奶牛 i 的后面(即 abs(c[i] - c[j]) > k),我们就可以进行状态转移: dp[mask | (1 << j)][j] += dp[mask][i]
    • 这表示,所有以 i 结尾、包含 mask 集合的有效排列,都可以通过在末尾加上 j,形成一个新的以 j 结尾、包含 mask | (1 << j) 集合的有效排列。
  5. 计算结果:

    • 当 DP 表格填充完毕后,所有奶牛都已排列(即 mask(1 << n) - 1)的情况就是我们要求的最终排列。
    • 最终答案是所有 dp[(1 << n) - 1][i] 的总和(其中 i0n-1),因为任何一头奶牛都可以作为排列的结尾。

这种算法的时间复杂度为 ,对于 是完全可以接受的。

代码

#include <iostream>
#include <vector>
#include <cmath>
#include <numeric>
#include <algorithm>

using namespace std;

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n, k;
    cin >> n >> k;

    vector<int> c(n);
    for (int i = 0; i < n; ++i) {
        cin >> c[i];
    }
    
    vector<vector<long long>> dp(1 << n, vector<long long>(n, 0));

    for (int i = 0; i < n; ++i) {
        dp[1 << i][i] = 1;
    }

    for (int mask = 1; mask < (1 << n); ++mask) {
        for (int i = 0; i < n; ++i) {
            if ((mask >> i) & 1) { // If cow i is in the current set
                if (dp[mask][i] > 0) {
                    for (int j = 0; j < n; ++j) {
                        if (!((mask >> j) & 1)) { // If cow j is not in the set
                            if (abs(c[i] - c[j]) > k) {
                                int next_mask = mask | (1 << j);
                                dp[next_mask][j] += dp[mask][i];
                            }
                        }
                    }
                }
            }
        }
    }

    long long total_permutations = 0;
    for (int i = 0; i < n; ++i) {
        total_permutations += dp[(1 << n) - 1][i];
    }

    cout << total_permutations << endl;

    return 0;
}
import java.util.Scanner;
import java.util.Arrays;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int k = sc.nextInt();

        int[] c = new int[n];
        for (int i = 0; i < n; i++) {
            c[i] = sc.nextInt();
        }

        long[][] dp = new long[1 << n][n];

        for (int i = 0; i < n; i++) {
            dp[1 << i][i] = 1;
        }

        for (int mask = 1; mask < (1 << n); mask++) {
            for (int i = 0; i < n; i++) {
                if (((mask >> i) & 1) == 1) { // If cow i is in the current set
                    if (dp[mask][i] > 0) {
                        for (int j = 0; j < n; j++) {
                            if (((mask >> j) & 1) == 0) { // If cow j is not in the set
                                if (Math.abs(c[i] - c[j]) > k) {
                                    int nextMask = mask | (1 << j);
                                    dp[nextMask][j] += dp[mask][i];
                                }
                            }
                        }
                    }
                }
            }
        }

        long totalPermutations = 0;
        for (int i = 0; i < n; i++) {
            totalPermutations += dp[(1 << n) - 1][i];
        }

        System.out.println(totalPermutations);
    }
}
import sys

def main():
    try:
        input = sys.stdin.readline
        n, k = map(int, input().split())
        c = [int(input()) for _ in range(n)]
        
        dp = [[0] * n for _ in range(1 << n)]

        for i in range(n):
            dp[1 << i][i] = 1
        
        for mask in range(1, 1 << n):
            for i in range(n):
                if (mask >> i) & 1: # If cow i is in the current set
                    if dp[mask][i] > 0:
                        for j in range(n):
                            if not ((mask >> j) & 1): # If cow j is not in the set
                                if abs(c[i] - c[j]) > k:
                                    next_mask = mask | (1 << j)
                                    dp[next_mask][j] += dp[mask][i]

        total_permutations = sum(dp[(1 << n) - 1])
        sys.stdout.write(str(total_permutations) + '\n')

    except (IOError, ValueError):
        return

main()

算法及复杂度

  • 算法:状态压缩动态规划 (DP on Subsets)
  • 时间复杂度:。我们有 个状态,每个状态的转移需要 的时间。
  • 空间复杂度:,用于存储 DP 表。