题目链接

密码锁

题目描述

有一个密码锁,由一个 的矩阵构成。该密码锁只有在每一列上,每个数均不相同的情况下,才能被打开。

此外,可以对矩阵的任意一行进行水平翻转操作,但每行最多只能翻转一次。

现在给出密码锁上的密码矩阵,请你编写一个程序判断是否可以打开该密码锁。如果可以,请给出一组翻转方案。

输入:

  • 第一行输入矩阵的大小
  • 随后 列输入密码矩阵上的数字。

输出:

  • 若无法打开密码锁,输出 "No"。
  • 若可以打开,第一行输出 "Yes",第二行输出需要翻转的行数 ,第三行输出 个正整数,代表需要翻转的行号(行号从 1 开始计)。若有多种方案,输出任意一种即可。

解题思路

这是一个典型的约束满足问题,其结构非常适合使用 2-SAT 模型来解决。

每行都有两种状态(翻转或不翻转),共有 种组合。当 较大时,暴力枚举所有组合是不可行的。2-SAT 提供了一个多项式时间复杂度的解法。

  1. 2-SAT 模型建立:

    • 我们为每一行 (从 0 到 )引入一个布尔变量
    • 为真 () 表示第 不翻转
    • 为真 () 表示第 翻转
  2. 约束转化:

    • 密码锁能打开的条件是:对于任意两行 ),它们在最终状态下不能在任何一列上有相同的数字。
    • 这个条件意味着,如果两种选择(例如,“ 不翻转”和“ 不翻转”)会导致冲突,那么这两种选择不能同时成立。
    • 为原始的第 行, 为翻转后的第 行。如果 在第 列有冲突(即 ),那么我们不能同时选择
    • 这个约束可以用逻辑表达式表示为 ,它等价于 。这是一个标准的 2-SAT 子句。
  3. 四种冲突与对应的子句: 对于每一对不同的行 ,我们需要检查以下四种潜在的冲突: a. 原行 vs 原行 : 若存在某列 使 ,则添加子句 。 b. 原行 vs 翻转行 : 若存在某列 使 ,则添加子句 。 c. 翻转行 vs 原行 : 若存在某列 使 ,则添加子句 。 d. 翻转行 vs 翻转行 : 若存在某列 使 ,则添加子句

  4. 求解 2-SAT:

    • 一个子句 等价于两个蕴含式:。我们可以根据这些蕴含式构建一个图。
    • 图中有 个节点,分别代表 。每个蕴含式 对应图中的一条有向边
    • 构建完图后,使用 Tarjan 算法寻找图中的所有强连通分量 (SCC)。
    • 2-SAT 问题有解的充要条件是:对于任何变量 ,其对应的两个节点 在同一个强连通分量中。
    • 如果满足此条件,则可以构造出一组解。一种构造方法是:比较 所在 SCC 的拓扑序。Tarjan 算法生成的 SCC 编号是逆拓扑序,所以我们可以直接比较 scc_id。如果 scc_id[2*i] < scc_id[2*i+1](Java/Python实现,C++实现是 >),则选择 (翻转),否则选择 (不翻转)。
  5. 建图优化 (从 ):

    • 上述第 3 点中,对每一对行 都进行检查,会导致构建一个拥有 条边的图,这在 较大时会超时。
    • 一个关键的优化是,对每一列 单独处理约束。在列 中,冲突的根源是某个数值出现了多次
    • 我们可以遍历每一列,用一个哈希表记录下每个数值对应的所有选择(例如,数值 100 可能由“第 行不翻转”或“第 行翻转”得到)。
    • 如果一个数值只出现了一次,它不会产生任何冲突。
    • 如果一个数值 v 出现了 次,对应 个选择(我们称之为原子命题 ),那么约束就变成了 “这 个选择中,最多只能有一个为真”,即 AtMostOne(p_1, ..., p_k)
    • 这个 AtMostOne 约束可以使用辅助变量,在 的时间和空间内转化为标准的 2-CNF 子句,从而避免了 的暴力两两配对。
    • 通过这个优化,对每一列建图的时间复杂度从 降为了 ,总的建图时间复杂度也就降为了
  6. 实现细节:

    • 为了处理 Python/Java 中递归深度过大的问题,Tarjan 算法采用了迭代(非递归)的实现方式。
    • 为了处理内存限制(MLE)问题,采用了“两阶段建图”:先遍历计算出所有边和所需的最大节点数,再一次性精确分配内存并构建邻接表。

谁出的牛魔卡常题,纯神人啊。
java 代码调了一万年,才极限通过。
python 代码,python3 超时,pypy3 内存超限。
烂活。

代码

#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>
#include <map>

using namespace std;

// 使用 vector 代替静态数组以支持动态数量的节点
vector<vector<int>> adj;
vector<int> dfn, low, scc_id;
vector<int> stk;
vector<bool> in_stk;
int timestamp, scc_count;

void tarjan(int u) {
    dfn[u] = low[u] = ++timestamp;
    stk.push_back(u);
    in_stk[u] = true;

    for (int v : adj[u]) {
        if (!dfn[v]) {
            tarjan(v);
            low[u] = min(low[u], low[v]);
        } else if (in_stk[v]) {
            low[u] = min(low[u], dfn[v]);
        }
    }

    if (dfn[u] == low[u]) {
        ++scc_count;
        int y;
        do {
            y = stk.back();
            stk.pop_back();
            in_stk[y] = false;
            scc_id[y] = scc_count;
        } while (y != u);
    }
}

// u^1
int neg(int u) {
    return u ^ 1;
}

void solve() {
    int n, m;
    cin >> n >> m;
    vector<vector<int>> matrix(n, vector<int>(m));
    vector<vector<int>> reversed_matrix(n, vector<int>(m));

    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < m; ++j) {
            cin >> matrix[i][j];
        }
        reversed_matrix[i] = matrix[i];
        reverse(reversed_matrix[i].begin(), reversed_matrix[i].end());
    }

    // Pass 1: 收集所有边并确定节点总数
    vector<pair<int, int>> edge_list;
    int next_var_idx = n;
    int max_node = 2 * n - 1;

    for (int j = 0; j < m; ++j) {
        map<int, vector<int>> value_to_nodes;
        for (int i = 0; i < n; ++i) {
            value_to_nodes[matrix[i][j]].push_back(2 * i);
            value_to_nodes[reversed_matrix[i][j]].push_back(2 * i + 1);
        }

        for (auto const& [val, nodes] : value_to_nodes) {
            if (nodes.size() > 1) {
                int k = nodes.size();
                int base_aux_var = next_var_idx;
                next_var_idx += (k - 1);
                
                max_node = max(max_node, 2 * (base_aux_var + k - 2) + 1);

                // (!p1 V s1)
                edge_list.push_back({neg(nodes[0]), 2 * base_aux_var});
                // (!pk V !s_{k-1})
                edge_list.push_back({neg(nodes[k - 1]), neg(2 * (base_aux_var + k - 2))});
                
                for (int i = 1; i < k - 1; ++i) {
                    int p_i = nodes[i];
                    int s_i = 2 * (base_aux_var + i);
                    int s_i_minus_1 = 2 * (base_aux_var + i - 1);
                    edge_list.push_back({neg(p_i), s_i});
                    edge_list.push_back({neg(s_i_minus_1), s_i});
                    edge_list.push_back({neg(p_i), neg(s_i_minus_1)});
                }
            }
        }
    }

    // Pass 2: 构建图并求解
    int num_nodes = max_node + 1;
    adj.assign(num_nodes, vector<int>());
    for(const auto& edge : edge_list) {
        // (u V v) <=> (!u => v) AND (!v => u)
        adj[neg(edge.first)].push_back(edge.second);
        adj[neg(edge.second)].push_back(edge.first);
    }
    
    dfn.assign(num_nodes, 0);
    low.assign(num_nodes, 0);
    scc_id.assign(num_nodes, 0);
    in_stk.assign(num_nodes, false);
    stk.clear();
    timestamp = scc_count = 0;

    for (int i = 0; i < num_nodes; ++i) {
        if (!dfn[i]) {
            tarjan(i);
        }
    }

    for (int i = 0; i < n; ++i) {
        if (scc_id[2 * i] == scc_id[2 * i + 1] && scc_id[2 * i] != 0) {
            cout << "No" << endl;
            return;
        }
    }

    cout << "Yes" << endl;
    vector<int> flipped_rows;
    for (int i = 0; i < n; ++i) {
        // Tarjan SCC 编号是逆拓扑序, 编号小的拓扑序靠后
        // 我们选择赋值为真的那个文字, 其 SCC 编号必须更大
        if (scc_id[2 * i] < scc_id[2 * i + 1]) { // scc_id(!x_i) > scc_id(x_i)
            flipped_rows.push_back(i + 1);
        }
    }
    cout << flipped_rows.size() << endl;
    if (!flipped_rows.empty()) {
        for (int i = 0; i < flipped_rows.size(); ++i) {
            cout << flipped_rows[i] << (i == flipped_rows.size() - 1 ? "" : " ");
        }
    }
    cout << endl;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);
    solve();
    return 0;
}

import java.util.*;
import java.io.*;

public class Main {
    // 2-SAT 图和 Tarjan 算法所需的全局变量
    static List<Integer>[] adj;
    static int[] dfn, low, sccId;
    static Stack<Integer> tarjanStack;
    static boolean[] inStack;
    static int timestamp, sccCount;
    static int n; // 行数
    static int nextVarIndex; // 下一个可用的辅助变量节点索引

    // 新的节点编号: 2*i (true), 2*i+1 (false)
    // neg(u) is u ^ 1
    static int neg(int u) {
        return u ^ 1;
    }

    // 添加 2-SAT 子句 (a V b), 等价于 (!a -> b) AND (!b -> a)
    static void addClause(int u, int v) {
        // (u V v) <=> (!u => v) AND (!v => u)
        adj[neg(u)].add(v);
        adj[neg(v)].add(u);
    }

    // Tarjan 算法的迭代实现
    static void tarjanIterative(int startNode, int numNodes) {
        Stack<int[]> dfsStack = new Stack<>();
        dfsStack.push(new int[]{startNode, 0});

        while (!dfsStack.isEmpty()) {
            int[] state = dfsStack.peek();
            int u = state[0];
            int neighborIdx = state[1];

            if (dfn[u] == 0) {
                timestamp++;
                dfn[u] = low[u] = timestamp;
                tarjanStack.push(u);
                inStack[u] = true;
            }

            if (neighborIdx < adj[u].size()) {
                int v = adj[u].get(neighborIdx);
                state[1]++;
                if (dfn[v] == 0) {
                    dfsStack.push(new int[]{v, 0});
                } else if (inStack[v]) {
                    low[u] = Math.min(low[u], dfn[v]);
                }
            } else {
                dfsStack.pop();
                if (dfn[u] == low[u]) {
                    sccCount++;
                    int node;
                    do {
                        node = tarjanStack.pop();
                        inStack[node] = false;
                        sccId[node] = sccCount;
                    } while (node != u);
                }
                if (!dfsStack.isEmpty()) {
                    int parent = dfsStack.peek()[0];
                    low[parent] = Math.min(low[parent], low[u]);
                }
            }
        }
    }

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        PrintWriter pw = new PrintWriter(new BufferedWriter(new OutputStreamWriter(System.out)));
        StringTokenizer st = new StringTokenizer(br.readLine());
        n = Integer.parseInt(st.nextToken());
        int m = Integer.parseInt(st.nextToken());
        
        int[][] matrix = new int[n][m];
        int[][] reversedMatrix = new int[n][m];

        for (int i = 0; i < n; i++) {
            st = new StringTokenizer(br.readLine());
            for (int j = 0; j < m; j++) {
                matrix[i][j] = Integer.parseInt(st.nextToken());
            }
            for (int j = 0; j < m; j++) {
                reversedMatrix[i][j] = matrix[i][m - 1 - j];
            }
        }
        
        // 两阶段建图以优化内存
        List<int[]> edgeList = new ArrayList<>();
        int nextVarIndex = n;
        int maxNode = 2 * n - 1;

        Map<Integer, List<Integer>> valueToNodes = new HashMap<>(n * 2);

        for (int j = 0; j < m; j++) {
            valueToNodes.clear();
            for (int i = 0; i < n; i++) {
                valueToNodes.computeIfAbsent(matrix[i][j], k -> new ArrayList<>()).add(2 * i);
                valueToNodes.computeIfAbsent(reversedMatrix[i][j], k -> new ArrayList<>()).add(2 * i + 1);
            }
            
            for (List<Integer> nodes : valueToNodes.values()) {
                if (nodes.size() > 1) {
                    int k = nodes.size();
                    int[] s = new int[k - 1];
                    for (int i = 0; i < k - 1; i++) s[i] = nextVarIndex++;
                    
                    maxNode = Math.max(maxNode, 2 * s[k-2] + 1);

                    // (!p1 V s1)
                    edgeList.add(new int[]{neg(nodes.get(0)), 2 * s[0]});
                    // (!pk V !s_{k-1})
                    edgeList.add(new int[]{neg(nodes.get(k - 1)), neg(2 * (s[k - 2]))});
                    for (int i = 1; i < k - 1; i++) {
                        edgeList.add(new int[]{neg(nodes.get(i)), 2 * s[i]});
                        edgeList.add(new int[]{neg(2 * s[i - 1]), 2 * s[i]});
                        edgeList.add(new int[]{neg(nodes.get(i)), neg(2 * s[i - 1])});
                    }
                }
            }
        }
        
        int numNodes = maxNode + 1;
        adj = new ArrayList[numNodes];
        for (int i = 0; i < numNodes; i++) adj[i] = new ArrayList<>();
        for (int[] edge : edgeList) {
            adj[neg(edge[0])].add(edge[1]);
            adj[neg(edge[1])].add(edge[0]);
        }
        
        dfn = new int[numNodes];
        low = new int[numNodes];
        sccId = new int[numNodes];
        tarjanStack = new Stack<>();
        inStack = new boolean[numNodes];
        
        for (int i = 0; i < numNodes; i++) {
            if (dfn[i] == 0) tarjanIterative(i, numNodes);
        }

        for (int i = 0; i < n; i++) {
            if (sccId[2 * i] == sccId[2 * i + 1] && sccId[2 * i] != 0) {
                pw.println("No");
                pw.flush();
                return;
            }
        }

        pw.println("Yes");
        List<Integer> flippedRows = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            if (sccId[2 * i] < sccId[2 * i + 1]) {
                flippedRows.add(i + 1);
            }
        }
        pw.println(flippedRows.size());
        if (!flippedRows.isEmpty()) {
            StringBuilder sb = new StringBuilder();
            for (int i = 0; i < flippedRows.size(); i++) {
                sb.append(flippedRows.get(i));
                if (i < flippedRows.size() - 1) sb.append(" ");
            }
            pw.println(sb.toString());
        } else {
             pw.println();
        }
        pw.flush();
    }
}

import sys
from collections import defaultdict

def solve():
   try:
       # 极限 I/O 优化:一次性读取所有输入,然后在内存中处理
       lines = sys.stdin.read().splitlines()
       if not lines: return

       n_str, m_str = lines[0].split()
       n, m = int(n_str), int(m_str)
       
       matrix = []
       for i in range(1, n + 1):
           matrix.append(list(map(int, lines[i].split())))
           
   except (IOError, ValueError, IndexError):
       return

   reversed_matrix = [row[::-1] for row in matrix]
   
   edge_list = []
   next_var_idx = n
   max_node = 2 * n - 1 # 确保 max_node 在循环外初始化

   def neg(u): return u ^ 1

   for j in range(m):
       value_to_nodes = defaultdict(list)
       for i in range(n):
           value_to_nodes[matrix[i][j]].append(2 * i)
           value_to_nodes[reversed_matrix[i][j]].append(2 * i + 1)

       for nodes in value_to_nodes.values():
           k = len(nodes)
           if k > 1:
               base_aux_var = next_var_idx
               next_var_idx += (k - 1)
               
               max_node = max(max_node, 2 * (base_aux_var + k - 2) + 1)
               
               edge_list.append((neg(nodes[0]), 2 * base_aux_var))
               edge_list.append((neg(nodes[k-1]), neg(2 * (base_aux_var + k - 2))))
               
               for i in range(1, k - 1):
                   p_i = nodes[i]
                   s_i = 2 * (base_aux_var + i)
                   s_i_minus_1 = 2 * (base_aux_var + i - 1)
                   
                   edge_list.append((neg(p_i), s_i))
                   edge_list.append((neg(s_i_minus_1), s_i))
                   edge_list.append((neg(p_i), neg(s_i_minus_1)))
   
   num_nodes = max_node + 1
   adj = [[] for _ in range(num_nodes)]
   for u, v in edge_list:
       adj[neg(u)].append(v)
       adj[neg(v)].append(u)

   dfn, low, scc_id = [0] * num_nodes, [0] * num_nodes, [-1] * num_nodes
   tarjan_stack, in_stack = [], [False] * num_nodes
   timestamp, scc_count = 0, 0

   for i in range(num_nodes):
       if dfn[i] == 0:
           dfs_stack = [(i, iter(adj[i]))]
           while dfs_stack:
               u, neighbors = dfs_stack[-1]
               if dfn[u] == 0:
                   timestamp += 1
                   dfn[u] = low[u] = timestamp
                   tarjan_stack.append(u)
                   in_stack[u] = True
               try:
                   v = next(neighbors)
                   if dfn[v] == 0:
                       dfs_stack.append((v, iter(adj[v])))
                   elif in_stack[v]:
                       low[u] = min(low[u], dfn[v])
               except StopIteration:
                   dfs_stack.pop()
                   if dfn[u] == low[u]:
                       scc_count += 1
                       while True:
                           node = tarjan_stack.pop()
                           in_stack[node] = False
                           scc_id[node] = scc_count
                           if node == u: break
                   if dfs_stack:
                       parent, _ = dfs_stack[-1]
                       low[parent] = min(low[parent], low[u])

   for i in range(n):
       if scc_id[2*i] == scc_id[2*i+1] and scc_id[2*i] != -1:
           sys.stdout.write("No\n")
           return

   sys.stdout.write("Yes\n")
   flipped_rows = [i + 1 for i in range(n) if scc_id[2*i] < scc_id[2*i+1]]
   sys.stdout.write(str(len(flipped_rows)) + "\n")
   if flipped_rows:
       sys.stdout.write(' '.join(map(str, flipped_rows)) + "\n")
   else:
       sys.stdout.write("\n")

solve()

算法及复杂度

  • 算法:带优化建图的 2-SAT (2-Satisfiability)。
  • 时间复杂度:
    • 建图: 对于 列中的每一列,我们遍历 行来填充哈希表,这需要 。然后,我们遍历哈希表中的值。对于一个出现 次的值,我们使用辅助变量来添加 条约束(边)。由于一列中所有值的出现次数总和为 ,所以处理一列的总时间是 。因此,总的建图(收集所有边)时间为
    • 求解: 构建邻接表需要 的时间,其中顶点数 和边数 都是 级别。运行 Tarjan 算法寻找强连通分量的时间复杂度也是
    • 综上,总时间复杂度由建图和求解主导,为
  • 空间复杂度:
    • 空间主要用于存储图的邻接表以及 Tarjan 算法所需的各种辅助数组(dfn, low, scc_id 等)。
    • 顶点数和边数都是 级别,因此空间复杂度为