量子门
题意
有 个量子比特,初始状态用一个长度为
的 01 向量表示。有
个量子门可以操作,施加第
个门会翻转比特
本身,同时还会翻转所有与它有纠缠关系的比特
(给定
条纠缠关系
表示门
额外翻转比特
)。每个门最多用一次(用两次等于没用)。
要求选出一个门的子集,使得施加后所有比特变为 0。若无解输出 -1;若有解,输出操作数最少、字典序最小的方案(升序输出门编号)。
数据范围: ,
思路
每个门要么用要么不用,用了就把一组比特全部异或 1——这不就是 (模 2 域)上的线性方程组吗?
把每个门 的效果写成一个列向量
,其中
表示门
会翻转比特
。设
表示是否使用门
,问题就是:
$$
其中 是初始状态向量。这是一个
的
线性方程组。
第一步:高斯消元
对增广矩阵做 上的高斯消元(行简化阶梯形)。消元后:
- 如果出现
的矛盾行,说明无解,输出
-1。 - 否则得到秩
,有
个自由变量。
第二步:枚举自由变量
高斯消元之后,每个主元变量(pivot variable)可以用自由变量来表达。自由变量的每一种取值组合都对应一组合法解。
我们需要在所有合法解中找最少门数,再取字典序最小的。怎么做?枚举自由变量的所有 种赋值,每次
算出完整解,统计 1 的个数(门数)和对应的位掩码(字典序)。
,实际测试中自由变量通常不会太多(
就完全可以暴力枚举)。
字典序的判断技巧
把解表示成一个 位的 bitmask(第 0 位 = 门 1,第 1 位 = 门 2 ...)。对于两个门数相同的解,bitmask 数值更小的那个,恰好就是字典序更小的——因为低位对应编号更小的门,优先选低编号门会使数值更小。
复杂度
- 高斯消元:
(用位运算压缩为
)
- 枚举自由变量:
,其中
代码
#include <bits/stdc++.h>
using namespace std;
int main(){
int n, m;
scanf("%d%d", &n, &m);
vector<int> state(n);
for(int i = 0; i < n; i++) scanf("%d", &state[i]);
// col[j] = 门 j 翻转哪些比特的位掩码
vector<long long> col(n, 0);
for(int j = 0; j < n; j++) col[j] = 1LL << j;
for(int i = 0; i < m; i++){
int u, v; scanf("%d%d", &u, &v); u--; v--;
col[u] |= 1LL << v;
}
long long target = 0;
for(int i = 0; i < n; i++) if(state[i]) target |= 1LL << i;
if(target == 0){ printf("\n"); return 0; }
// 增广矩阵:row[i] 的前 n 位是各门对第 i 位的系数,第 n 位是 target
vector<long long> row(n, 0);
for(int i = 0; i < n; i++){
for(int j = 0; j < n; j++)
if(col[j] & (1LL << i)) row[i] |= 1LL << j;
if(target & (1LL << i)) row[i] |= 1LL << n;
}
// GF(2) 高斯消元
vector<int> pivotCol(n, -1), pivotRow(n, -1);
int rk = 0;
for(int c = 0; c < n && rk < n; c++){
int pr = -1;
for(int r = rk; r < n; r++)
if(row[r] & (1LL << c)){ pr = r; break; }
if(pr == -1) continue;
swap(row[rk], row[pr]);
pivotCol[rk] = c; pivotRow[c] = rk;
for(int r = 0; r < n; r++)
if(r != rk && (row[r] & (1LL << c))) row[r] ^= row[rk];
rk++;
}
// 检查是否有矛盾
for(int r = rk; r < n; r++)
if(row[r] & (1LL << n)){ printf("-1\n"); return 0; }
// 自由变量
vector<int> fv;
for(int c = 0; c < n; c++) if(pivotRow[c] == -1) fv.push_back(c);
int nf = fv.size();
long long bestMask = -1;
int bestCnt = n + 1;
// 枚举自由变量的所有赋值
for(long long fm = 0; fm < (1LL << nf); fm++){
long long sol = 0;
for(int i = 0; i < nf; i++)
if(fm & (1LL << i)) sol |= 1LL << fv[i];
for(int r = 0; r < rk; r++){
int val = (row[r] >> n) & 1;
for(int i = 0; i < nf; i++)
if((fm & (1LL << i)) && (row[r] & (1LL << fv[i]))) val ^= 1;
if(val) sol |= 1LL << pivotCol[r];
}
int cnt = __builtin_popcountll(sol);
if(cnt < bestCnt || (cnt == bestCnt && sol < bestMask)){
bestCnt = cnt; bestMask = sol;
}
}
bool first = true;
for(int i = 0; i < n; i++)
if(bestMask & (1LL << i)){ if(!first) printf(" "); printf("%d", i+1); first = false; }
printf("\n");
return 0;
}
import java.util.*;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt(), m = sc.nextInt();
int[] state = new int[n];
for (int i = 0; i < n; i++) state[i] = sc.nextInt();
long[] col = new long[n];
for (int j = 0; j < n; j++) col[j] = 1L << j;
for (int i = 0; i < m; i++) {
int u = sc.nextInt() - 1, v = sc.nextInt() - 1;
col[u] |= 1L << v;
}
long target = 0;
for (int i = 0; i < n; i++) if (state[i] == 1) target |= 1L << i;
if (target == 0) { System.out.println(); return; }
long[] row = new long[n];
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++)
if ((col[j] & (1L << i)) != 0) row[i] |= 1L << j;
if ((target & (1L << i)) != 0) row[i] |= 1L << n;
}
int[] pivotCol = new int[n], pivotRow = new int[n];
Arrays.fill(pivotCol, -1); Arrays.fill(pivotRow, -1);
int rk = 0;
for (int c = 0; c < n && rk < n; c++) {
int pr = -1;
for (int r = rk; r < n; r++)
if ((row[r] & (1L << c)) != 0) { pr = r; break; }
if (pr == -1) continue;
long tmp = row[rk]; row[rk] = row[pr]; row[pr] = tmp;
pivotCol[rk] = c; pivotRow[c] = rk;
for (int r = 0; r < n; r++)
if (r != rk && (row[r] & (1L << c)) != 0) row[r] ^= row[rk];
rk++;
}
for (int r = rk; r < n; r++)
if ((row[r] & (1L << n)) != 0) { System.out.println(-1); return; }
List<Integer> fv = new ArrayList<>();
for (int c = 0; c < n; c++) if (pivotRow[c] == -1) fv.add(c);
int nf = fv.size();
long bestMask = -1; int bestCnt = n + 1;
for (long fm = 0; fm < (1L << nf); fm++) {
long sol = 0;
for (int i = 0; i < nf; i++)
if ((fm & (1L << i)) != 0) sol |= 1L << fv.get(i);
for (int r = 0; r < rk; r++) {
int val = (int)((row[r] >> n) & 1);
for (int i = 0; i < nf; i++)
if ((fm & (1L << i)) != 0 && (row[r] & (1L << fv.get(i))) != 0) val ^= 1;
if (val == 1) sol |= 1L << pivotCol[r];
}
int cnt = Long.bitCount(sol);
if (cnt < bestCnt || (cnt == bestCnt && Long.compareUnsigned(sol, bestMask) < 0)) {
bestCnt = cnt; bestMask = sol;
}
}
StringBuilder sb = new StringBuilder();
for (int i = 0; i < n; i++)
if ((bestMask & (1L << i)) != 0) { if (sb.length() > 0) sb.append(' '); sb.append(i + 1); }
System.out.println(sb);
}
}
import sys
input = sys.stdin.readline
def main():
n, m = map(int, input().split())
state = list(map(int, input().split()))
col = [1 << j for j in range(n)]
for _ in range(m):
u, v = map(int, input().split())
col[u-1] |= 1 << (v-1)
target = sum(state[i] << i for i in range(n))
if target == 0:
print(); return
row = [0] * n
for i in range(n):
for j in range(n):
if col[j] & (1 << i): row[i] |= 1 << j
if target & (1 << i): row[i] |= 1 << n
pivotCol = [-1] * n; pivotRow = [-1] * n; rk = 0
for c in range(n):
if rk >= n: break
pr = -1
for r in range(rk, n):
if row[r] & (1 << c): pr = r; break
if pr == -1: continue
row[rk], row[pr] = row[pr], row[rk]
pivotCol[rk] = c; pivotRow[c] = rk
for r in range(n):
if r != rk and row[r] & (1 << c): row[r] ^= row[rk]
rk += 1
for r in range(rk, n):
if row[r] & (1 << n): print(-1); return
fv = [c for c in range(n) if pivotRow[c] == -1]
nf = len(fv)
bestMask, bestCnt = -1, n + 1
for fm in range(1 << nf):
sol = 0
for i in range(nf):
if fm & (1 << i): sol |= 1 << fv[i]
for r in range(rk):
val = (row[r] >> n) & 1
for i in range(nf):
if (fm & (1 << i)) and (row[r] & (1 << fv[i])): val ^= 1
if val: sol |= 1 << pivotCol[r]
cnt = bin(sol).count('1')
if cnt < bestCnt or (cnt == bestCnt and sol < bestMask):
bestCnt, bestMask = cnt, sol
print(' '.join(str(i+1) for i in range(n) if bestMask & (1 << i)))
main()
const readline = require('readline');
const rl = readline.createInterface({ input: process.stdin });
const lines = [];
rl.on('line', l => lines.push(l.trim()));
rl.on('close', () => {
let idx = 0;
const [n, m] = lines[idx++].split(' ').map(Number);
const state = lines[idx++].split(' ').map(Number);
const col = Array.from({length: n}, (_, j) => 1n << BigInt(j));
for (let i = 0; i < m; i++) {
const [u, v] = lines[idx++].split(' ').map(x => Number(x) - 1);
col[u] |= 1n << BigInt(v);
}
let target = 0n;
for (let i = 0; i < n; i++) if (state[i]) target |= 1n << BigInt(i);
if (target === 0n) { console.log(''); return; }
const row = new Array(n).fill(0n);
for (let i = 0; i < n; i++) {
for (let j = 0; j < n; j++)
if (col[j] & (1n << BigInt(i))) row[i] |= 1n << BigInt(j);
if (target & (1n << BigInt(i))) row[i] |= 1n << BigInt(n);
}
const pivotCol = new Array(n).fill(-1), pivotRow = new Array(n).fill(-1);
let rk = 0;
for (let c = 0; c < n && rk < n; c++) {
let pr = -1;
for (let r = rk; r < n; r++)
if (row[r] & (1n << BigInt(c))) { pr = r; break; }
if (pr === -1) continue;
[row[rk], row[pr]] = [row[pr], row[rk]];
pivotCol[rk] = c; pivotRow[c] = rk;
for (let r = 0; r < n; r++)
if (r !== rk && (row[r] & (1n << BigInt(c)))) row[r] ^= row[rk];
rk++;
}
for (let r = rk; r < n; r++)
if (row[r] & (1n << BigInt(n))) { console.log(-1); return; }
const fv = [];
for (let c = 0; c < n; c++) if (pivotRow[c] === -1) fv.push(c);
const nf = fv.length;
function popcount(x) { let c = 0; while (x > 0n) { x &= x - 1n; c++; } return c; }
let bestMask = -1n, bestCnt = n + 1;
for (let fm = 0; fm < (1 << nf); fm++) {
let sol = 0n;
for (let i = 0; i < nf; i++)
if (fm & (1 << i)) sol |= 1n << BigInt(fv[i]);
for (let r = 0; r < rk; r++) {
let val = Number((row[r] >> BigInt(n)) & 1n);
for (let i = 0; i < nf; i++)
if ((fm & (1 << i)) && (row[r] & (1n << BigInt(fv[i])))) val ^= 1;
if (val) sol |= 1n << BigInt(pivotCol[r]);
}
const cnt = popcount(sol);
if (cnt < bestCnt || (cnt === bestCnt && sol < bestMask)) {
bestCnt = cnt; bestMask = sol;
}
}
const ans = [];
for (let i = 0; i < n; i++)
if (bestMask & (1n << BigInt(i))) ans.push(i + 1);
console.log(ans.join(' '));
});

京公网安备 11010502036488号