思路：先用差分数组加前缀和求出每个点被多少个区间覆盖；再把区间按右端点排序依次处理，贪心地优先选择当前被（尚未处理的）区间覆盖次数最多的点。用树状数组维护每个区间内已被选中的点数；用线段树维护区间内与其他区间交集最多的点是哪个。注意：一个区间选完后，必须把它对覆盖次数的贡献删去（即对该区间整体减 1），这需要带懒标记的区间修改——蒟蒻就是因为漏了这一步，通过率一直是 0%。
#include <algorithm>
#include <cstring>
#include <iostream>
#include <stdio.h>
#include <utility>
using namespace std;

// Size limits used by this solution: N for points (1..n), M for intervals.
const int N = 5e5 + 10, M = 1e6 + 10;
typedef long long ll;

// One requirement: at least k chosen points must lie inside [l, r].
struct Interval {
    int l, r, k;
};

Interval work[M];
int n, m;

// Order intervals by right endpoint, then left endpoint, then larger k first.
bool cmp(const Interval &a, const Interval &b) {
    if (a.r != b.r) return a.r < b.r;
    if (a.l != b.l) return a.l < b.l;
    return a.k > b.k;
}

// b[x]: number of intervals covering point x (difference array, then prefix-summed).
int b[N];

// Segment tree over points 1..n; node bounds are passed recursively instead of
// being stored, so each node only keeps three ints.
//   mx[u]  : maximum remaining coverage inside the node's range
//   pos[u] : a point attaining mx[u] (the right child wins ties)
//   tag[u] : pending range-add not yet pushed down to the children
int mx[4 * N], pos[4 * N], tag[4 * N];

// Add d to every point in node u's range (lazily).
void apply_add(int u, int d) {
    mx[u] += d;
    tag[u] += d;
}

// Push node u's pending add down to its two children.
void pushdown(int u) {
    if (tag[u]) {
        apply_add(u << 1, tag[u]);
        apply_add(u << 1 | 1, tag[u]);
        tag[u] = 0;
    }
}

// Recompute node u from its children; on equal maxima the right child is kept.
void pushup(int u) {
    int lc = u << 1, rc = u << 1 | 1;
    if (mx[lc] > mx[rc]) {
        mx[u] = mx[lc];
        pos[u] = pos[lc];
    } else {
        mx[u] = mx[rc];
        pos[u] = pos[rc];
    }
}

// Build node u covering [l, r] from the coverage counts in b.
void build(int u, int l, int r) {
    tag[u] = 0;
    if (l == r) {
        mx[u] = b[l];
        pos[u] = l;
        return;
    }
    int mid = (l + r) >> 1;
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
    pushup(u);
}

// Add d to every point of [ql, qr]; node u covers [l, r].
void modify(int u, int l, int r, int ql, int qr, int d) {
    if (ql <= l && r <= qr) {
        apply_add(u, d);
        return;
    }
    // Partial overlap: the node must be split, so settle its pending add first.
    pushdown(u);
    int mid = (l + r) >> 1;
    if (ql <= mid) modify(u << 1, l, mid, ql, qr, d);
    if (qr > mid) modify(u << 1 | 1, mid + 1, r, ql, qr, d);
    pushup(u);
}

// Return {maximum value, its position} over [ql, qr]; ties resolve to the right.
pair<int, int> query(int u, int l, int r, int ql, int qr) {
    if (ql <= l && r <= qr) return make_pair(mx[u], pos[u]);
    pushdown(u);
    int mid = (l + r) >> 1;
    if (qr <= mid) return query(u << 1, l, mid, ql, qr);
    if (ql > mid) return query(u << 1 | 1, mid + 1, r, ql, qr);
    pair<int, int> left = query(u << 1, l, mid, ql, qr);
    pair<int, int> right = query(u << 1 | 1, mid + 1, r, ql, qr);
    return left.first > right.first ? left : right;
}

// Fenwick tree counting chosen points.
int sum[N];
int lowbit(int x) { return x & -x; }
void add(int x, int d) {
    for (int i = x; i <= n; i += lowbit(i)) sum[i] += d;
}
int getsum(int x) {
    int res = 0;
    for (int i = x; i > 0; i -= lowbit(i)) res += sum[i];
    return res;
}

// Reads n, m and m intervals "l r k"; prints the minimum number of points to
// choose so that every interval contains at least k chosen points.
int main() {
    if (scanf("%d%d", &n, &m) != 2 || n < 1) {
        // No points to place (or unreadable header): building the tree on an
        // empty range would recurse forever, so answer 0 directly.
        printf("0\n");
        return 0;
    }
    for (int i = 0; i < m; i++) {
        scanf("%d%d%d", &work[i].l, &work[i].r, &work[i].k);
        b[work[i].l]++;
        b[work[i].r + 1]--;
    }
    for (int i = 1; i <= n; i++) b[i] += b[i - 1];
    build(1, 1, n);

    ll ans = 0;
    sort(work, work + m, cmp);
    for (int i = 0; i < m; i++) {
        const Interval &cur = work[i];
        int have = getsum(cur.r) - getsum(cur.l - 1);
        int need = max(cur.k - have, 0);
        ans += need;
        // NOTE(review): assumes k <= r - l + 1 (TODO confirm from the problem
        // statement); otherwise an already-chosen point would be picked again.
        for (int j = 0; j < need; j++) {
            pair<int, int> best = query(1, 1, n, cur.l, cur.r);
            add(best.second, 1);
            // Drop this point's remaining coverage to 0. Later range -1 updates
            // only push it lower, so it is never preferred over an unchosen point.
            modify(1, 1, n, best.second, best.second, -best.first);
        }
        // This interval is done: remove its contribution from every point it covers.
        modify(1, 1, n, cur.l, cur.r, -1);
    }
    printf("%lld\n", ans);
    return 0;
}