量子门

题意

个量子比特,初始状态用一个长度为 的 01 向量表示。有 个量子门可以操作,施加第 个门会翻转比特 本身,同时还会翻转所有与它有纠缠关系的比特 (给定 条纠缠关系 表示门 额外翻转比特 )。每个门最多用一次(用两次等于没用)。

要求选出一个门的子集,使得施加后所有比特变为 0。若无解输出 -1;若有解,输出操作数最少、字典序最小的方案(升序输出门编号)。

数据范围:

思路

每个门要么用要么不用,用了就把一组比特全部异或 1——这不就是 (模 2 域)上的线性方程组吗?

把每个门 的效果写成一个列向量 ,其中 表示门 会翻转比特 。设 表示是否使用门 ,问题就是:

$$

其中 是初始状态向量。这是一个 线性方程组。

第一步:高斯消元

对增广矩阵做 上的高斯消元(行简化阶梯形)。消元后:

  • 如果出现 的矛盾行,说明无解,输出 -1
  • 否则得到秩 ,有 个自由变量。

第二步:枚举自由变量

高斯消元之后,每个主元变量(pivot variable)可以用自由变量来表达。自由变量的每一种取值组合都对应一组合法解。

我们需要在所有合法解中找最少门数,再取字典序最小的。怎么做?枚举自由变量的所有 种赋值,每次 算出完整解,统计 1 的个数(门数)和对应的位掩码(字典序)。

,实际测试中自由变量通常不会太多( 就完全可以暴力枚举)。

字典序的判断技巧

把解表示成一个 位的 bitmask(第 0 位 = 门 1,第 1 位 = 门 2 ...)。对于两个门数相同的解,bitmask 数值更小的那个,恰好就是字典序更小的——因为低位对应编号更小的门,优先选低编号门会使数值更小。

复杂度

  • 高斯消元:(用位运算压缩为
  • 枚举自由变量:,其中

代码

#include <bits/stdc++.h>
using namespace std;

int main(){
    int n, m;
    scanf("%d%d", &n, &m);
    vector<int> state(n);
    for(int i = 0; i < n; i++) scanf("%d", &state[i]);

    // col[j] = 门 j 翻转哪些比特的位掩码
    vector<long long> col(n, 0);
    for(int j = 0; j < n; j++) col[j] = 1LL << j;
    for(int i = 0; i < m; i++){
        int u, v; scanf("%d%d", &u, &v); u--; v--;
        col[u] |= 1LL << v;
    }

    long long target = 0;
    for(int i = 0; i < n; i++) if(state[i]) target |= 1LL << i;
    if(target == 0){ printf("\n"); return 0; }

    // 增广矩阵:row[i] 的前 n 位是各门对第 i 位的系数,第 n 位是 target
    vector<long long> row(n, 0);
    for(int i = 0; i < n; i++){
        for(int j = 0; j < n; j++)
            if(col[j] & (1LL << i)) row[i] |= 1LL << j;
        if(target & (1LL << i)) row[i] |= 1LL << n;
    }

    // GF(2) 高斯消元
    vector<int> pivotCol(n, -1), pivotRow(n, -1);
    int rk = 0;
    for(int c = 0; c < n && rk < n; c++){
        int pr = -1;
        for(int r = rk; r < n; r++)
            if(row[r] & (1LL << c)){ pr = r; break; }
        if(pr == -1) continue;
        swap(row[rk], row[pr]);
        pivotCol[rk] = c; pivotRow[c] = rk;
        for(int r = 0; r < n; r++)
            if(r != rk && (row[r] & (1LL << c))) row[r] ^= row[rk];
        rk++;
    }

    // 检查是否有矛盾
    for(int r = rk; r < n; r++)
        if(row[r] & (1LL << n)){ printf("-1\n"); return 0; }

    // 自由变量
    vector<int> fv;
    for(int c = 0; c < n; c++) if(pivotRow[c] == -1) fv.push_back(c);
    int nf = fv.size();

    long long bestMask = -1;
    int bestCnt = n + 1;

    // 枚举自由变量的所有赋值
    for(long long fm = 0; fm < (1LL << nf); fm++){
        long long sol = 0;
        for(int i = 0; i < nf; i++)
            if(fm & (1LL << i)) sol |= 1LL << fv[i];
        for(int r = 0; r < rk; r++){
            int val = (row[r] >> n) & 1;
            for(int i = 0; i < nf; i++)
                if((fm & (1LL << i)) && (row[r] & (1LL << fv[i]))) val ^= 1;
            if(val) sol |= 1LL << pivotCol[r];
        }
        int cnt = __builtin_popcountll(sol);
        if(cnt < bestCnt || (cnt == bestCnt && sol < bestMask)){
            bestCnt = cnt; bestMask = sol;
        }
    }

    bool first = true;
    for(int i = 0; i < n; i++)
        if(bestMask & (1LL << i)){ if(!first) printf(" "); printf("%d", i+1); first = false; }
    printf("\n");
    return 0;
}
import java.util.*;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt(), m = sc.nextInt();
        int[] state = new int[n];
        for (int i = 0; i < n; i++) state[i] = sc.nextInt();

        long[] col = new long[n];
        for (int j = 0; j < n; j++) col[j] = 1L << j;
        for (int i = 0; i < m; i++) {
            int u = sc.nextInt() - 1, v = sc.nextInt() - 1;
            col[u] |= 1L << v;
        }

        long target = 0;
        for (int i = 0; i < n; i++) if (state[i] == 1) target |= 1L << i;
        if (target == 0) { System.out.println(); return; }

        long[] row = new long[n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++)
                if ((col[j] & (1L << i)) != 0) row[i] |= 1L << j;
            if ((target & (1L << i)) != 0) row[i] |= 1L << n;
        }

        int[] pivotCol = new int[n], pivotRow = new int[n];
        Arrays.fill(pivotCol, -1); Arrays.fill(pivotRow, -1);
        int rk = 0;
        for (int c = 0; c < n && rk < n; c++) {
            int pr = -1;
            for (int r = rk; r < n; r++)
                if ((row[r] & (1L << c)) != 0) { pr = r; break; }
            if (pr == -1) continue;
            long tmp = row[rk]; row[rk] = row[pr]; row[pr] = tmp;
            pivotCol[rk] = c; pivotRow[c] = rk;
            for (int r = 0; r < n; r++)
                if (r != rk && (row[r] & (1L << c)) != 0) row[r] ^= row[rk];
            rk++;
        }

        for (int r = rk; r < n; r++)
            if ((row[r] & (1L << n)) != 0) { System.out.println(-1); return; }

        List<Integer> fv = new ArrayList<>();
        for (int c = 0; c < n; c++) if (pivotRow[c] == -1) fv.add(c);
        int nf = fv.size();

        long bestMask = -1; int bestCnt = n + 1;
        for (long fm = 0; fm < (1L << nf); fm++) {
            long sol = 0;
            for (int i = 0; i < nf; i++)
                if ((fm & (1L << i)) != 0) sol |= 1L << fv.get(i);
            for (int r = 0; r < rk; r++) {
                int val = (int)((row[r] >> n) & 1);
                for (int i = 0; i < nf; i++)
                    if ((fm & (1L << i)) != 0 && (row[r] & (1L << fv.get(i))) != 0) val ^= 1;
                if (val == 1) sol |= 1L << pivotCol[r];
            }
            int cnt = Long.bitCount(sol);
            if (cnt < bestCnt || (cnt == bestCnt && Long.compareUnsigned(sol, bestMask) < 0)) {
                bestCnt = cnt; bestMask = sol;
            }
        }

        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < n; i++)
            if ((bestMask & (1L << i)) != 0) { if (sb.length() > 0) sb.append(' '); sb.append(i + 1); }
        System.out.println(sb);
    }
}
import sys
input = sys.stdin.readline

def main():
    n, m = map(int, input().split())
    state = list(map(int, input().split()))

    col = [1 << j for j in range(n)]
    for _ in range(m):
        u, v = map(int, input().split())
        col[u-1] |= 1 << (v-1)

    target = sum(state[i] << i for i in range(n))
    if target == 0:
        print(); return

    row = [0] * n
    for i in range(n):
        for j in range(n):
            if col[j] & (1 << i): row[i] |= 1 << j
        if target & (1 << i): row[i] |= 1 << n

    pivotCol = [-1] * n; pivotRow = [-1] * n; rk = 0
    for c in range(n):
        if rk >= n: break
        pr = -1
        for r in range(rk, n):
            if row[r] & (1 << c): pr = r; break
        if pr == -1: continue
        row[rk], row[pr] = row[pr], row[rk]
        pivotCol[rk] = c; pivotRow[c] = rk
        for r in range(n):
            if r != rk and row[r] & (1 << c): row[r] ^= row[rk]
        rk += 1

    for r in range(rk, n):
        if row[r] & (1 << n): print(-1); return

    fv = [c for c in range(n) if pivotRow[c] == -1]
    nf = len(fv)
    bestMask, bestCnt = -1, n + 1

    for fm in range(1 << nf):
        sol = 0
        for i in range(nf):
            if fm & (1 << i): sol |= 1 << fv[i]
        for r in range(rk):
            val = (row[r] >> n) & 1
            for i in range(nf):
                if (fm & (1 << i)) and (row[r] & (1 << fv[i])): val ^= 1
            if val: sol |= 1 << pivotCol[r]
        cnt = bin(sol).count('1')
        if cnt < bestCnt or (cnt == bestCnt and sol < bestMask):
            bestCnt, bestMask = cnt, sol

    print(' '.join(str(i+1) for i in range(n) if bestMask & (1 << i)))

main()
const readline = require('readline');
const rl = readline.createInterface({ input: process.stdin });
const lines = [];
rl.on('line', l => lines.push(l.trim()));
rl.on('close', () => {
    let idx = 0;
    const [n, m] = lines[idx++].split(' ').map(Number);
    const state = lines[idx++].split(' ').map(Number);

    const col = Array.from({length: n}, (_, j) => 1n << BigInt(j));
    for (let i = 0; i < m; i++) {
        const [u, v] = lines[idx++].split(' ').map(x => Number(x) - 1);
        col[u] |= 1n << BigInt(v);
    }

    let target = 0n;
    for (let i = 0; i < n; i++) if (state[i]) target |= 1n << BigInt(i);
    if (target === 0n) { console.log(''); return; }

    const row = new Array(n).fill(0n);
    for (let i = 0; i < n; i++) {
        for (let j = 0; j < n; j++)
            if (col[j] & (1n << BigInt(i))) row[i] |= 1n << BigInt(j);
        if (target & (1n << BigInt(i))) row[i] |= 1n << BigInt(n);
    }

    const pivotCol = new Array(n).fill(-1), pivotRow = new Array(n).fill(-1);
    let rk = 0;
    for (let c = 0; c < n && rk < n; c++) {
        let pr = -1;
        for (let r = rk; r < n; r++)
            if (row[r] & (1n << BigInt(c))) { pr = r; break; }
        if (pr === -1) continue;
        [row[rk], row[pr]] = [row[pr], row[rk]];
        pivotCol[rk] = c; pivotRow[c] = rk;
        for (let r = 0; r < n; r++)
            if (r !== rk && (row[r] & (1n << BigInt(c)))) row[r] ^= row[rk];
        rk++;
    }

    for (let r = rk; r < n; r++)
        if (row[r] & (1n << BigInt(n))) { console.log(-1); return; }

    const fv = [];
    for (let c = 0; c < n; c++) if (pivotRow[c] === -1) fv.push(c);
    const nf = fv.length;

    function popcount(x) { let c = 0; while (x > 0n) { x &= x - 1n; c++; } return c; }

    let bestMask = -1n, bestCnt = n + 1;
    for (let fm = 0; fm < (1 << nf); fm++) {
        let sol = 0n;
        for (let i = 0; i < nf; i++)
            if (fm & (1 << i)) sol |= 1n << BigInt(fv[i]);
        for (let r = 0; r < rk; r++) {
            let val = Number((row[r] >> BigInt(n)) & 1n);
            for (let i = 0; i < nf; i++)
                if ((fm & (1 << i)) && (row[r] & (1n << BigInt(fv[i])))) val ^= 1;
            if (val) sol |= 1n << BigInt(pivotCol[r]);
        }
        const cnt = popcount(sol);
        if (cnt < bestCnt || (cnt === bestCnt && sol < bestMask)) {
            bestCnt = cnt; bestMask = sol;
        }
    }

    const ans = [];
    for (let i = 0; i < n; i++)
        if (bestMask & (1n << BigInt(i))) ans.push(i + 1);
    console.log(ans.join(' '));
});