题目链接

小O的矩阵变换

题目描述

给定两个大小为 的 01 矩阵 。每次操作可以选择矩阵 的某一行或某一列,并将其中的所有元素进行翻转(0 变 1,1 变 0)。求将矩阵 变换为矩阵 所需的最少操作次数。如果无法完成变换,则输出 -1。

解题思路

这是一个经典的矩阵变换问题,其核心在于理解操作的性质和它们之间的相互关系。

1. 操作性质

  • 幂等性: 对同一行或同一列操作两次,相当于没有操作。这意味着每一行和每一列最多只需要被操作一次。

  • 交换律: 操作的顺序不影响最终结果。先翻转第 行再翻转第 列,与先翻转第 列再翻转第 行的结果是相同的。

2. 核心思路:基准行/列

由于操作顺序无关,我们可以固定一个基准,通过确定对这个基准的操作来推导出所有其他必要的操作。最简单的方法是选择第一行作为基准。

对于第一行,我们只有两种可能性:不翻转它,或者翻转它。我们可以分别探讨这两种情况,并计算出各自的最小代价,然后取其中的较小值。

场景一:假设不翻转第一行

  1. 确定列操作: 我们的目标是让矩阵 的第一行 A[0] 变得和 B[0] 完全一样。既然我们决定了不翻转第一行,那么唯一能做的就是通过翻转列来达成目的。具体地,对于第一行的第 列:

    • 如果 A[0][j] == B[0][j],则第 不能翻转。

    • 如果 A[0][j] != B[0][j],则第 必须翻转。

    这样,我们就唯一地确定了对所有列的操作。

  2. 确定并验证行操作: 在确定了所有列的操作后,我们来处理剩下的每一行 (从 1 到 )。

    • 首先,根据已确定的列翻转,将 的第 A[i] 变换得到一个中间状态 A'[i]

    • 此时我们已不能再调整列操作。为了让 A'[i] 匹配 B 的第 B[i],我们只能选择是否翻转第 行。

    • 这里只有两种可能的情况是可解的:

      • A'[i]B[i] 完全相同。此时我们不需要翻转第 行。

      • A'[i]B[i] 完全相反(即 A'[i]B[i] 的按位取反)。此时我们必须翻转第 行。

    • 如果 A'[i]B[i] 既不相同也不相反,那么在当前“不翻转第一行”的假设下,该问题无解。

  3. 计算代价: 如果所有行都能通过上述验证,我们就找到了一个可行的方案。其总操作次数为(所有必须翻转的列数)+(所有必须翻转的行数)。

场景二:假设翻转第一行

这个场景的分析过程与场景一完全相同,唯一的区别是初始时我们先对第一行进行翻转,并计 1 次操作。之后,同样根据翻转后的第一行去确定所有列的操作,再验证并确定其他行的操作,最终计算出总代价。

最终答案

比较从这两个场景中得到的代价(如果某个场景无解,则其代价视为无穷大)。

  • 如果两个场景都无解,则输出 -1。

  • 否则,输出两个代价中的最小值。

代码

#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>

using namespace std;

const int INF = 1e9;

int solve() {
    int n;
    cin >> n;
    vector<vector<int>> a(n, vector<int>(n));
    vector<vector<int>> b(n, vector<int>(n));
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            cin >> a[i][j];
        }
    }
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            cin >> b[i][j];
        }
    }

    auto calculate_cost = [&](vector<vector<int>> current_a, int initial_ops) {
        int ops = initial_ops;
        vector<int> col_flips(n, 0);

        // 确定列操作
        for (int j = 0; j < n; ++j) {
            if (current_a[0][j] != b[0][j]) {
                col_flips[j] = 1;
                ops++;
            }
        }

        // 应用列操作并确定行操作
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                if (col_flips[j]) {
                    current_a[i][j] ^= 1;
                }
            }
        }

        for (int i = 1; i < n; ++i) {
            bool all_same = true;
            for (int j = 0; j < n; ++j) {
                if (current_a[i][j] != b[i][j]) {
                    all_same = false;
                    break;
                }
            }
            if (all_same) continue;

            bool all_diff = true;
            for (int j = 0; j < n; ++j) {
                if (current_a[i][j] == b[i][j]) {
                    all_diff = false;
                    break;
                }
            }
            if (all_diff) {
                ops++;
            } else {
                return INF;
            }
        }
        return ops;
    };

    // 场景1: 不翻转第一行
    int cost1 = calculate_cost(a, 0);

    // 场景2: 翻转第一行
    vector<vector<int>> a_flipped = a;
    for (int j = 0; j < n; ++j) {
        a_flipped[0][j] ^= 1;
    }
    int cost2 = calculate_cost(a_flipped, 1);

    int min_ops = min(cost1, cost2);
    return (min_ops == INF) ? -1 : min_ops;
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    int t;
    cin >> t;
    while (t--) {
        cout << solve() << endl;
    }
    return 0;
}
import java.util.Scanner;

public class Main {
    private static final int INF = Integer.MAX_VALUE;
    private static int n;
    private static int[][] a, b;

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int t = sc.nextInt();
        while (t-- > 0) {
            n = sc.nextInt();
            a = new int[n][n];
            b = new int[n][n];
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < n; j++) {
                    a[i][j] = sc.nextInt();
                }
            }
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < n; j++) {
                    b[i][j] = sc.nextInt();
                }
            }
            solve();
        }
    }

    private static void solve() {
        // 场景1: 不翻转第一行
        int cost1 = calculateCost(false);

        // 场景2: 翻转第一行
        int cost2 = calculateCost(true);
        
        int minOps = Math.min(cost1, cost2);
        if (minOps == INF) {
            System.out.println(-1);
        } else {
            System.out.println(minOps);
        }
    }

    private static int calculateCost(boolean flipFirstRow) {
        int ops = 0;
        int[][] currentA = new int[n][n];
        for(int i = 0; i < n; i++) {
            System.arraycopy(a[i], 0, currentA[i], 0, n);
        }

        if (flipFirstRow) {
            ops++;
            for (int j = 0; j < n; j++) {
                currentA[0][j] ^= 1;
            }
        }

        int[] colFlips = new int[n];
        for (int j = 0; j < n; j++) {
            if (currentA[0][j] != b[0][j]) {
                colFlips[j] = 1;
                ops++;
            }
        }
        
        for (int i = 1; i < n; i++) {
            boolean allSame = true;
            boolean allDiff = true;
            for (int j = 0; j < n; j++) {
                int valA = currentA[i][j] ^ colFlips[j];
                if (valA != b[i][j]) allSame = false;
                if (valA == b[i][j]) allDiff = false;
            }

            if (allSame) {
                continue;
            } else if (allDiff) {
                ops++;
            } else {
                return INF;
            }
        }
        return ops;
    }
}
import sys

def solve():
    n = int(sys.stdin.readline())
    a = [list(map(int, sys.stdin.readline().split())) for _ in range(n)]
    b = [list(map(int, sys.stdin.readline().split())) for _ in range(n)]

    INF = float('inf')

    def calculate_cost(flip_first_row):
        ops = 0
        current_a = [row[:] for row in a]

        if flip_first_row:
            ops += 1
            for j in range(n):
                current_a[0][j] ^= 1

        col_flips = [0] * n
        for j in range(n):
            if current_a[0][j] != b[0][j]:
                col_flips[j] = 1
        
        # 计算行和列总操作
        total_ops = ops + sum(col_flips)

        for i in range(1, n):
            # 检查其他行是否可以通过一次行翻转匹配
            first_element_a = current_a[i][0] ^ col_flips[0]
            first_element_b = b[i][0]
            
            row_needs_flip = (first_element_a != first_element_b)

            for j in range(n):
                val_a = current_a[i][j] ^ col_flips[j]
                val_b = b[i][j]
                if (val_a ^ row_needs_flip) != val_b:
                    return INF
            
            if row_needs_flip:
                total_ops += 1
        
        return total_ops

    # 场景1: 不翻转第一行
    cost1 = calculate_cost(False)

    # 场景2: 翻转第一行
    cost2 = calculate_cost(True)
    
    result = min(cost1, cost2)
    
    if result == INF:
        print(-1)
    else:
        print(result)


def main():
    t = int(sys.stdin.readline())
    for _ in range(t):
        solve()

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法: 枚举、构造

  • 时间复杂度: 对于每个测试用例,我们执行两次计算(翻转第一行和不翻转第一行)。在每次计算中,我们确定列翻转需要 ,然后遍历剩余的 行,每行需要 来验证。因此,单次计算的复杂度是 。总的时间复杂度为 ,其中 是测试用例的数量。

  • 空间复杂度: 需要存储两个 的矩阵,因此空间复杂度为