题目链接

遥控无人机

题目描述

初始时,无人机位于原点 ,目标是到达点 。给定 条指令,每条指令包含一个位移向量 。对于所有可能的 ,需要计算出有多少种方法,可以从 条指令中恰好选择 条,使得执行后无人机正好到达目标点。

解题思路

本题要求我们对每个可能的指令数量 ,找出选择 条指令使得位移总和为 的方案数。这是一个典型的组合计数问题,与子集和问题相关。

暴力方法与DP的局限性

一个朴素的想法是检查所有 个子集,但这对于 来说太慢了。

另一个思路是动态规划。我们可以定义一个状态 dp[i][j][x][y] 表示从前 条指令中选择 条,使得总位移为 的方案数。然而,位移坐标 的范围可能很大(可达 ),导致状态空间过大,无法在内存和时间限制内完成。

中间相遇 (Meet-in-the-Middle)

的范围在 40-50 左右时,一个经典的优化策略是“中间相遇”算法。该算法的核心思想是将问题一分为二,分别计算两半的结果,然后合并它们。

  1. 分割指令集 我们将 条指令分成两半:

    • 第一组:前 条指令。
    • 第二组:后 条指令。
  2. 暴力枚举第一组 我们使用深度优先搜索(DFS)来遍历第一组指令的所有可能组合(即所有子集)。对于每个子集,我们记录下所选指令的数量 ,以及产生的总位移

    我们将这些结果存储在一个数据结构中,例如 map1map1 是一个数组,其中 map1[k1] 是一个哈希表(map),它将位移向量 映射到其出现的次数。即 map1[k1][(x1, y1)] 表示用 条指令达到位移 的方案数。

  3. 暴力枚举第二组并合并结果 接着,我们用同样的方式(DFS)遍历第二组指令的所有组合,得到每个子集的指令数 和总位移

    对于第二组的每一个结果 ,我们要在第一组中寻找一个“互补”的结果 ,使得:

    对于每一个 ,我们遍历所有可能的 (从 )。我们在 map1[k1] 中查找键 (Tx - x2, Ty - y2)。如果找到了,说明存在匹配的方案。设查找到的方案数为 count,那么我们就为总指令数为 的最终答案 ans[k] 加上 count

    通过这种方式,我们将两半的结果有效地合并起来,得到了所有 的答案。

代码

#include <iostream>
#include <vector>
#include <map>

using namespace std;

int n;
long long tx, ty;
vector<pair<int, int>> instructions;
// 优化后的数据结构:(x, y) -> (k1 -> count)
map<pair<long long, long long>, map<int, long long>> half1_results;
vector<long long> ans;

// 搜索前半部分指令
void dfs1(int index, int count, long long current_x, long long current_y) {
    if (index == n / 2) {
        half1_results[{current_x, current_y}][count]++;
        return;
    }

    // 方案1:不选择当前指令
    dfs1(index + 1, count, current_x, current_y);

    // 方案2:选择当前指令
    dfs1(index + 1, count + 1, current_x + instructions[index].first, current_y + instructions[index].second);
}

// 搜索后半部分指令并合并结果
void dfs2(int index, int count, long long current_x, long long current_y) {
    if (index == n) {
        long long target_x = tx - current_x;
        long long target_y = ty - current_y;
        if (half1_results.count({target_x, target_y})) {
            for (auto const& [k1, num_ways] : half1_results.at({target_x, target_y})) {
                int total_k = k1 + count;
                if (total_k > 0) {
                    ans[total_k] += num_ways;
                }
            }
        }
        return;
    }

    // 方案1:不选择当前指令
    dfs2(index + 1, count, current_x, current_y);

    // 方案2:选择当前指令
    dfs2(index + 1, count + 1, current_x + instructions[index].first, current_y + instructions[index].second);
}

int main() {
    cin >> n;
    cin >> tx >> ty;
    instructions.resize(n);
    for (int i = 0; i < n; ++i) {
        cin >> instructions[i].first >> instructions[i].second;
    }

    ans.resize(n + 1, 0);

    dfs1(0, 0, 0, 0);
    dfs2(n / 2, 0, 0, 0);

    for (int k = 1; k <= n; ++k) {
        cout << ans[k] << "\n";
    }

    return 0;
}
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Scanner;

public class Main {
    static int n;
    static long tx, ty;
    static int[][] instructions;
    // 优化后的数据结构:(x, y) -> (k1 -> count)
    static Map<Pair, Map<Integer, Long>> half1Results;
    static long[] ans;
    
    // 用于哈希表键的坐标对类
    static class Pair {
        long x, y;

        public Pair(long x, long y) {
            this.x = x;
            this.y = y;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            Pair pair = (Pair) o;
            return x == pair.x && y == pair.y;
        }

        @Override
        public int hashCode() {
            return Objects.hash(x, y);
        }
    }
    
    // 搜索前半部分
    static void dfs1(int index, int count, long currentX, long currentY) {
        if (index == n / 2) {
            Pair p = new Pair(currentX, currentY);
            half1Results.computeIfAbsent(p, k -> new HashMap<>())
                        .merge(count, 1L, Long::sum);
            return;
        }
        // 方案1:不选
        dfs1(index + 1, count, currentX, currentY);
        // 方案2:选
        dfs1(index + 1, count + 1, currentX + instructions[index][0], currentY + instructions[index][1]);
    }

    // 搜索后半部分
    static void dfs2(int index, int count, long currentX, long currentY) {
        if (index == n) {
            long targetX = tx - currentX;
            long targetY = ty - currentY;
            Pair targetPair = new Pair(targetX, targetY);
            if (half1Results.containsKey(targetPair)) {
                Map<Integer, Long> k1Counts = half1Results.get(targetPair);
                for (Map.Entry<Integer, Long> entry : k1Counts.entrySet()) {
                    int k1 = entry.getKey();
                    long numWays = entry.getValue();
                    int totalK = k1 + count;
                    if (totalK > 0) {
                        ans[totalK] += numWays;
                    }
                }
            }
            return;
        }
        // 方案1:不选
        dfs2(index + 1, count, currentX, currentY);
        // 方案2:选
        dfs2(index + 1, count + 1, currentX + instructions[index][0], currentY + instructions[index][1]);
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        n = sc.nextInt();
        tx = sc.nextLong();
        ty = sc.nextLong();
        
        instructions = new int[n][2];
        for (int i = 0; i < n; i++) {
            instructions[i][0] = sc.nextInt();
            instructions[i][1] = sc.nextInt();
        }

        half1Results = new HashMap<>();
        ans = new long[n + 1];

        dfs1(0, 0, 0, 0);
        dfs2(n / 2, 0, 0, 0);

        for (int k = 1; k <= n; k++) {
            System.out.println(ans[k]);
        }
    }
}
import sys

# 为深度递归做好准备
sys.setrecursionlimit(200005)

n = 0
tx, ty = 0, 0
instructions = []
# 优化后的数据结构:(x, y) -> {k1: count}
half1_results = {}
ans = []

# 搜索前半部分
def dfs1(idx, count, x, y):
    if idx == n // 2:
        pos = (x, y)
        if pos not in half1_results:
            half1_results[pos] = {}
        half1_results[pos][count] = half1_results[pos].get(count, 0) + 1
        return

    # 方案1:不选
    dfs1(idx + 1, count, x, y)
    
    # 方案2:选
    dx, dy = instructions[idx]
    dfs1(idx + 1, count + 1, x + dx, y + dy)

# 搜索后半部分并合并
def dfs2(idx, count, x, y):
    if idx == n:
        target_x = tx - x
        target_y = ty - y
        target_pos = (target_x, target_y)
        if target_pos in half1_results:
            for k1, num_ways in half1_results[target_pos].items():
                total_k = k1 + count
                if total_k > 0:
                    ans[total_k] += num_ways
        return

    # 方案1:不选
    dfs2(idx + 1, count, x, y)
    
    # 方案2:选
    dx, dy = instructions[idx]
    dfs2(idx + 1, count + 1, x + dx, y + dy)

def main():
    global n, tx, ty, instructions, half1_results, ans
    
    # 注意题目N<=40,但之前代码的输入格式是N,然后是tx,ty
    # 按照之前的格式进行读取
    lines = sys.stdin.readlines()
    if not lines:
        return

    for line in lines:
        n_str, tx_ty_str = line.split()
        n = int(n_str)
        tx, ty = int(tx_ty_str[0]), int(tx_ty_str[1])
        
        instructions = [list(map(int, input().split())) for _ in range(n)]
    except (IOError, ValueError):
        return

    ans = [0] * (n + 1)
    
    dfs1(0, 0, 0, 0)
    dfs2(n // 2, 0, 0, 0)
    
    for k in range(1, n + 1):
        print(ans[k])

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:中间相遇 (Meet-in-the-Middle)
  • 时间复杂度
    • dfs1 遍历第一半的所有子集,需要 的时间,每次哈希表操作平均为 (Python/Java) 或 (C++)。
    • dfs2 遍历第二半的所有子集,也需要 的时间。在每次 dfs2 到达叶子节点时,进行一次哈希表查找,然后遍历一个最多只有 个元素的小哈希表。因此,总时间复杂度为
  • 空间复杂度
    • 主要开销在于存储第一半指令的所有组合结果。在最坏情况下,可能会有 个不同的状态(位移X, 位移Y, 组合数),因此需要相应的空间来存储 half1_results