题目链接

建物流中转站

题目描述

给定一个二维平面网格,1 代表房子,0 代表空地。

需要找到一个空地来修建一个物流中转站,使得这个中转站到所有房子的曼哈顿距离之和最小。

  • 曼哈顿距离:
  • 如果能修建,返回最小的距离和。
  • 如果网格中没有空地,返回 -1。

解题思路

这是一个经典的优化问题,可以通过维度分离动态规划的思想,在 的时间内高效解决,其中 是网格的边长。

1. 核心思想:维度分离

我们希望最小化的目标函数是所有房子到中转站 (r, c) 的距离之和 S

这个公式可以被拆分为两个完全独立的一维问题:

这意味着,我们可以独立地计算如果中转站建在第 r 行的总行距,和建在第 c 列的总列距。任何一个候选位置 (r, c) 的总距离和就是这两部分之和。

2. 高效的距离和计算

暴力做法是遍历每个空地,再遍历所有房子来计算距离和,复杂度为 ,效率低下。

我们可以预先计算出中转站建在任意一行或一列时的距离和,从而将查询成本降为

算法步骤:

  1. 预处理

    • 遍历一次网格,统计每一行 row_counts[i] 和每一列 col_counts[j] 的房子数量。
    • 同时记录房子的总数 total_houses 和空地的总数 empty_lands
    • 如果 empty_lands 为 0,直接返回 -1。
  2. 计算行距离数组 dist_r

    • dist_r[i] 表示如果中转站建在第 i 行,它到所有房子的行距离之和。
    • 初始化:首先计算 dist_r[0]。`dist_r[0] = \sum_{k=0}^{N-1} k \cdot \text{row_counts}[k]$。
    • 递推:利用 dist_r[i-1] 时间内计算 dist_r[i]。当中转站从第 i-1 行移动到第 i 行时:
      • 对于所有在 i 行上方的房子,距离都增加了1。
      • 对于所有在 i 行及下方的房子,距离都减少了1。
      • 递推公式为:dist_r[i] = dist_r[i-1] + houses_above - houses_below
  3. 计算列距离数组 dist_c

    • 使用与计算 dist_r 完全相同的方法,计算出 dist_c 数组。
  4. 寻找最优解

    • 再次遍历网格。对于每一个是空地 (0) 的格子 (i, j),其对应的总距离和就是 dist_r[i] + dist_c[j]
    • 在所有空地对应的距离和中,找到最小值即为答案。

这个算法的总时间复杂度和空间复杂度都是 ,可以高效地处理非常大的网格。

代码

#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>

using namespace std;

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n;
    cin >> n;

    vector<vector<int>> grid(n, vector<int>(n));
    vector<long long> row_counts(n, 0);
    vector<long long> col_counts(n, 0);
    long long total_houses = 0;
    bool has_empty_land = false;

    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            cin >> grid[i][j];
            if (grid[i][j] == 1) {
                row_counts[i]++;
                col_counts[j]++;
                total_houses++;
            } else {
                has_empty_land = true;
            }
        }
    }

    if (!has_empty_land) {
        cout << -1 << endl;
        return 0;
    }
    if (total_houses == 0) {
        cout << 0 << endl;
        return 0;
    }

    vector<long long> dist_r(n, 0);
    vector<long long> dist_c(n, 0);

    // 计算 dist_r[0] 和 dist_c[0]
    for (int i = 0; i < n; ++i) {
        dist_r[0] += i * row_counts[i];
        dist_c[0] += i * col_counts[i];
    }

    // 递推计算 dist_r 和 dist_c
    long long houses_above = 0;
    long long houses_below = total_houses;
    for (int i = 1; i < n; ++i) {
        houses_above += row_counts[i - 1];
        houses_below -= row_counts[i - 1];
        dist_r[i] = dist_r[i - 1] + houses_above - houses_below;
    }

    long long houses_left = 0;
    long long houses_right = total_houses;
    for (int i = 1; i < n; ++i) {
        houses_left += col_counts[i - 1];
        houses_right -= col_counts[i - 1];
        dist_c[i] = dist_c[i - 1] + houses_left - houses_right;
    }

    long long min_dist = -1;

    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            if (grid[i][j] == 0) {
                long long current_dist = dist_r[i] + dist_c[j];
                if (min_dist == -1 || current_dist < min_dist) {
                    min_dist = current_dist;
                }
            }
        }
    }

    cout << min_dist << endl;

    return 0;
}
import java.util.Scanner;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();

        int[][] grid = new int[n][n];
        long[] rowCounts = new long[n];
        long[] colCounts = new long[n];
        long totalHouses = 0;
        boolean hasEmptyLand = false;

        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                grid[i][j] = sc.nextInt();
                if (grid[i][j] == 1) {
                    rowCounts[i]++;
                    colCounts[j]++;
                    totalHouses++;
                } else {
                    hasEmptyLand = true;
                }
            }
        }

        if (!hasEmptyLand) {
            System.out.println(-1);
            return;
        }
        if (totalHouses == 0) {
            System.out.println(0);
            return;
        }

        long[] distR = new long[n];
        long[] distC = new long[n];

        for (int i = 0; i < n; i++) {
            distR[0] += (long)i * rowCounts[i];
            distC[0] += (long)i * colCounts[i];
        }

        long housesAbove = 0;
        long housesBelow = totalHouses;
        for (int i = 1; i < n; i++) {
            housesAbove += rowCounts[i - 1];
            housesBelow -= rowCounts[i - 1];
            distR[i] = distR[i - 1] + housesAbove - housesBelow;
        }

        long housesLeft = 0;
        long housesRight = totalHouses;
        for (int i = 1; i < n; i++) {
            housesLeft += colCounts[i - 1];
            housesRight -= colCounts[i - 1];
            distC[i] = distC[i - 1] + housesLeft - housesRight;
        }

        long minDist = -1;

        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (grid[i][j] == 0) {
                    long currentDist = distR[i] + distC[j];
                    if (minDist == -1 || currentDist < minDist) {
                        minDist = currentDist;
                    }
                }
            }
        }

        System.out.println(minDist);
    }
}
import sys

def solve():
    try:
        n_str = sys.stdin.readline()
        if not n_str: return
        n = int(n_str)
        
        grid = []
        row_counts = [0] * n
        col_counts = [0] * n
        total_houses = 0
        has_empty_land = False

        for i in range(n):
            row = list(map(int, sys.stdin.readline().split()))
            grid.append(row)
            for j in range(n):
                if row[j] == 1:
                    row_counts[i] += 1
                    col_counts[j] += 1
                    total_houses += 1
                else:
                    has_empty_land = True
        
        if not has_empty_land:
            print(-1)
            return
        if total_houses == 0:
            print(0)
            return

        dist_r = [0] * n
        dist_c = [0] * n

        for i in range(n):
            dist_r[0] += i * row_counts[i]
            dist_c[0] += i * col_counts[i]

        houses_above = 0
        houses_below = total_houses
        for i in range(1, n):
            houses_above += row_counts[i-1]
            houses_below -= row_counts[i-1]
            dist_r[i] = dist_r[i-1] + houses_above - houses_below

        houses_left = 0
        houses_right = total_houses
        for i in range(1, n):
            houses_left += col_counts[i-1]
            houses_right -= col_counts[i-1]
            dist_c[i] = dist_c[i-1] + houses_left - houses_right

        min_dist = -1

        for i in range(n):
            for j in range(n):
                if grid[i][j] == 0:
                    current_dist = dist_r[i] + dist_c[j]
                    if min_dist == -1 or current_dist < min_dist:
                        min_dist = current_dist
                        
        print(min_dist)
        
    except (IOError, ValueError):
        return

solve()

算法及复杂度

  • 算法:动态规划 / 维度分离

  • 时间复杂度: 。算法主要包括几次对 网格的遍历和对长度为 的数组的遍历,总的时间开销与网格大小成正比。

  • 空间复杂度: ,主要用于存储输入的网格。如果可以边读输入边处理,可以将空间优化到 ,只存储行列计数和距离数组。