思路1

差分,定义区间[a, b]为砍坐标a-b之间树的次数,这样得到的差分前缀和sums[i]为0时,表示当前坐标i的树未被移走,加入到答案中。

#include <iostream>
#include <cstring>
#include <algorithm>

using namespace std;

const int N = 10010;

int n, m;
int a[N];

int main() {
    scanf("%d%d", &n, &m);
    memset(a, 0, sizeof a);
    int ans = 0;
    while (m -- ) {
        int l, r;
        scanf("%d%d", &l, &r);
        if (l > r) swap(l ,r);
        a[l] ++ , a[r + 1] -- ;
    }
    for (int i = 0; i <= n; i ++ ) {
        if (i) a[i] += a[i - 1];
        if (!a[i]) ans ++ ;
    }
    printf("%d\n", ans);
    return 0;
}

思路2

贪心,将所有区间合并,这样就可以直接遍历合并的区间得到移走树的数目nums,答案就是n - nums + 1;

#include <iostream>
#include <cstring>
#include <algorithm>

using namespace std;

struct Node {
    int l, r;
    
    bool operator< (const Node &other)const {
        return l < other.l;
    }
};
Node a[100010];

int n, m;

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 0; i < m; i ++ ) {
        int l, r;
        scanf("%d%d", &l, &r);
        if (l > r) swap(l, r);
        a[i].l = l, a[i].r = r;
    }
    sort(a, a + m);
    int l = -1, r = -1;
    int len = 0;
    for (int i = 0; i < m; i ++ ) {
        if (a[i].l > r) {
            if (l != -1 && r != -1)
                len += r - l + 1;
            l = a[i].l, r = a[i].r;
        } else 
            r = max(r, a[i].r);
    }
    if (l != -1 && r != -1) 
        len += r - l + 1;
    printf("%d\n", n - len + 1);
    return 0;
}

思路3

对于差分,假如题目范围n <= 1e9, 直接开数组会MLE,这时用上离散化的思想。只用上区间端点,对于其他的点没必要浪费空间和时间。同时维护变量前缀和sums = 0,当sums:0->1, 且当前位置差分值为1时,说明当前位置前有一段区间是未被移走的树。

#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>

using namespace std;

const int N = 100010;

int n, m;

struct Node {
    int pos;
    int sums;
    
    bool operator< (const Node& other)const {
        if (pos == other.pos) return sums < other.sums;
        return pos < other.pos;
    }
};

vector<Node> a;

int main() {
    scanf("%d%d", &n, &m);
    while (m -- ) {
        int l, r;
        scanf("%d%d", &l, &r);
        if (l > r) swap(l, r);
        a.push_back({l, 1});
        a.push_back({r + 1, -1});
    }
    sort(a.begin(), a.end());
    int len = 0, j = 0;
    for (int i = 0, sums = 0; i < a.size(); i ++ ) {
        sums += a[i].sums;
        if (sums == 1 && a[i].sums == 1)
            len += a[i].pos - j;
        if (sums == 0) 
            j = a[i].pos;
    }
    len += n - a.back().pos + 1;
    printf("%d\n", len);
    return 0;
}