小A的线段(easy version)

[题目链接](https://www.nowcoder.com/practice/03950e31758643d2924601b6dd466c24)

思路

这道题要求从 条线段中选若干条,使得坐标轴上 的每个整数点都被至少两条选中的线段覆盖。求方案数对 取模。

关键观察:

最多只有 10 条线段,这意味着选择方案最多 种——直接枚举所有子集即可。

如何快速判定一个子集是否合法?

朴素做法是对每个子集扫一遍 的每个点,但 可达 ,总复杂度 虽然能过,却不够优雅。

注意到所有线段的端点最多只有 个,加上边界 ,排序去重后最多分出 等价区间——同一个区间内,每条线段要么全部覆盖,要么全部不覆盖。因此只需要检查这些区间就够了。

具体做法:

  1. 收集所有"断点":,以及每条线段的 。排序去重。
  2. 相邻断点之间构成一个区间,在 范围内的区间共 个。
  3. 对每个区间,预处理一个覆盖掩码 covmask——用一个 位的二进制数记录哪些线段覆盖了它。
  4. 枚举 个子集 mask,对每个区间检查 popcount(covmask & mask) >= 2 是否成立。全部满足就计数。

样例演示

条线段:

断点排序后为 ,产生 5 个区间

  • 点 1:被线段 2、4 覆盖,掩码 0110
  • 点 2:被线段 2、4 覆盖,掩码 0110
  • 点 3:被线段 2、3、4 覆盖,掩码 0111(实际上相邻区间可合并,但不影响正确性)
  • 点 4:被线段 1、2、3、4 覆盖,掩码 1111
  • 点 5:被线段 1、2、3 覆盖,掩码 1110

枚举 16 个子集后,满足每个区间覆盖 的有 3 个:1111(全选)、0111(去掉线段 1)、1110(去掉线段 4)。答案为 3。

复杂度

  • 时间复杂度,枚举子集并检查每个区间。
  • 空间复杂度,存储断点和掩码。

代码

#include <bits/stdc++.h>
using namespace std;
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, m;
    cin >> n >> m;
    vector<int> st(m), ed(m);
    for(int i = 0; i < m; i++) cin >> st[i] >> ed[i];

    // 收集断点,构造等价区间
    vector<int> bp;
    bp.push_back(1);
    bp.push_back(n + 1);
    for(int i = 0; i < m; i++){
        bp.push_back(st[i]);
        bp.push_back(ed[i] + 1);
    }
    sort(bp.begin(), bp.end());
    bp.erase(unique(bp.begin(), bp.end()), bp.end());

    // 建立 [1,n] 内的区间,并计算覆盖掩码
    vector<int> covmask;
    for(int j = 0; j + 1 < (int)bp.size(); j++){
        int l = max(bp[j], 1), r = min(bp[j + 1] - 1, n);
        if(l > r) continue;
        int mask = 0;
        for(int i = 0; i < m; i++)
            if(st[i] <= l && l <= ed[i])
                mask |= (1 << i);
        covmask.push_back(mask);
    }

    // 枚举所有子集
    long long ans = 0;
    for(int mask = 0; mask < (1 << m); mask++){
        bool ok = true;
        for(int cm : covmask){
            if(__builtin_popcount(cm & mask) < 2){ ok = false; break; }
        }
        if(ok) ans++;
    }
    cout << ans % 998244353 << "\n";
}
import java.util.*;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt(), m = sc.nextInt();
        int[] st = new int[m], ed = new int[m];
        for (int i = 0; i < m; i++) {
            st[i] = sc.nextInt();
            ed[i] = sc.nextInt();
        }

        TreeSet<Integer> bpSet = new TreeSet<>();
        bpSet.add(1);
        bpSet.add(n + 1);
        for (int i = 0; i < m; i++) {
            bpSet.add(st[i]);
            bpSet.add(ed[i] + 1);
        }
        int[] bp = new int[bpSet.size()];
        int idx = 0;
        for (int v : bpSet) bp[idx++] = v;

        List<Integer> covmask = new ArrayList<>();
        for (int j = 0; j + 1 < bp.length; j++) {
            int l = Math.max(bp[j], 1);
            int r = Math.min(bp[j + 1] - 1, n);
            if (l > r) continue;
            int mask = 0;
            for (int i = 0; i < m; i++)
                if (st[i] <= l && l <= ed[i])
                    mask |= (1 << i);
            covmask.add(mask);
        }

        long ans = 0;
        for (int mask = 0; mask < (1 << m); mask++) {
            boolean ok = true;
            for (int cm : covmask) {
                if (Integer.bitCount(cm & mask) < 2) { ok = false; break; }
            }
            if (ok) ans++;
        }
        System.out.println(ans % 998244353);
    }
}
import sys
input = sys.stdin.readline

def main():
    n, m = map(int, input().split())
    segs = []
    for _ in range(m):
        l, r = map(int, input().split())
        segs.append((l, r))

    bp = sorted(set([1, n + 1] + [l for l, r in segs] + [r + 1 for l, r in segs]))

    covmask = []
    for j in range(len(bp) - 1):
        l = max(bp[j], 1)
        r = min(bp[j + 1] - 1, n)
        if l > r:
            continue
        mask = 0
        for i, (sl, sr) in enumerate(segs):
            if sl <= l <= sr:
                mask |= (1 << i)
        covmask.append(mask)

    ans = 0
    for mask in range(1 << m):
        ok = True
        for cm in covmask:
            if bin(cm & mask).count('1') < 2:
                ok = False
                break
        if ok:
            ans += 1
    print(ans % 998244353)

main()
const readline = require('readline');
const rl = readline.createInterface({ input: process.stdin });
const lines = [];

rl.on('line', line => lines.push(line.trim()));
rl.on('close', () => {
    const [n, m] = lines[0].split(' ').map(Number);
    const segs = [];
    for (let i = 1; i <= m; i++) {
        const [l, r] = lines[i].split(' ').map(Number);
        segs.push([l, r]);
    }

    const bpSet = new Set([1, n + 1]);
    for (const [l, r] of segs) {
        bpSet.add(l);
        bpSet.add(r + 1);
    }
    const bp = [...bpSet].sort((a, b) => a - b);

    const covmask = [];
    for (let j = 0; j + 1 < bp.length; j++) {
        const l = Math.max(bp[j], 1);
        const r = Math.min(bp[j + 1] - 1, n);
        if (l > r) continue;
        let mask = 0;
        for (let i = 0; i < m; i++) {
            if (segs[i][0] <= l && l <= segs[i][1])
                mask |= (1 << i);
        }
        covmask.push(mask);
    }

    function popcount(x) {
        x = x - ((x >> 1) & 0x55555555);
        x = (x & 0x33333333) + ((x >> 2) & 0x33333333);
        return ((x + (x >> 4)) & 0x0f0f0f0f) * 0x01010101 >> 24;
    }

    let ans = 0;
    for (let mask = 0; mask < (1 << m); mask++) {
        let ok = true;
        for (const cm of covmask) {
            if (popcount(cm & mask) < 2) { ok = false; break; }
        }
        if (ok) ans++;
    }
    console.log(ans % 998244353);
});