背景
2019牛客杭电多校都把线段树当做最最最基础的知识点,杭电3甚至把线段树当做签到,所以线段树要多练啊
题目
算法
线段树入门题 O(Mlog(N))
一步步思考线段树过程
关键点
不能简单的对比对比左右子区间的dat和值(区间最大和值)来更新本节点的区间最大和值,还要对比右子树的rmax+左子树的lmax的和值
c.dat=max(max(a.dat,b.dat),a.rmax+b.lmax);
附图可以参考,关键看代码注释
手写AC代码
#include<bits/stdc++.h> using namespace std; #define ll long long const int SIZE = 5e5+7; struct SegmentTree{ int l,r; int lmax,rmax,sum; int dat; } t[SIZE<<2]; int a[SIZE],N,M; void pushup(int p){ t[p].sum = t[p*2].sum + t[p*2+1].sum; t[p].lmax = max(t[p*2].lmax,t[p*2].sum+t[p*2+1].lmax); t[p].rmax = max(t[p*2+1].rmax,t[p*2+1].sum+t[p*2].rmax); t[p].dat = max(t[p*2].dat,max(t[p*2+1].dat,t[p*2].rmax+t[p*2+1].lmax)); } void build(int p,int l,int r){ t[p].l=l,t[p].r=r; if(l==r){ t[p].sum=t[p].lmax=t[p].rmax=t[p].dat=a[l]; return ; } int mid = (l+r)/2; build(p*2,l,mid); build(p*2+1,mid+1,r); pushup(p); } void change(int p,int x,int v){ if(t[p].l==t[p].r){t[p].dat=t[p].sum=t[p].lmax=t[p].rmax=v;return ;} int mid = (t[p].l+t[p].r)/2; if(x<=mid) change(p<<1,x,v); else change(p<<1|1,x,v); pushup(p); } // int ask(int p,int l,int r){ // if(l <= t[p].l && r >= t[p].r) return t[p].dat; // int mid = (t[p].l + t[p].r)/2; // int val = -(1<<30); // if(l<=mid) val = max(val,ask(p<<1,l,r)); // if(r>mid) val = max(val,ask(p<<1|1,l,r)); // return val; // } /* 此处感谢队友帮忙debug */ SegmentTree ask(int p,int l,int r){ if (l<=t[p].l && r>=t[p].r) return t[p]; int mid=(t[p].l+t[p].r)>>1; int val=-(1<<30); SegmentTree a,b,c; a.dat=a.sum=a.lmax=a.rmax=val; b.dat=b.sum=b.lmax=b.rmax=val; c.dat=c.lmax=c.rmax=val; c.sum=0; /* 要么都在最左边,要么都在最右边,要么跨越了左右,跨越了左右就要判断 c.dat=max(max(a.dat,b.dat),a.rmax+b.lmax); 这个点没想到会卡住 --> 我就是从前面的注释的ask函数出错来的*/ if (l<=mid&&r<=mid){ a=ask(p<<1,l,r); c.sum+=a.sum; } /* 还要注意左区间r<=mid,右区间l>mid */ // else if (l>=mid&&r>=mid){ else if (l>mid&&r>mid){ b=ask(p*2+1,l,r); c.sum+=b.sum; } else{ a=ask(p<<1,l,r); b=ask(p*2+1,l,r); c.sum+=a.sum+b.sum; } c.dat=max(c.dat,max(max(a.dat,b.dat),a.rmax+b.lmax)); c.lmax=max(c.lmax,max(a.lmax,a.sum+b.lmax)); c.rmax=max(c.rmax,max(b.rmax,b.sum+a.rmax)); return c; } int main(){ ios::sync_with_stdio(false);cin.tie(0); cin>>N>>M; for(int i=1;i<=N;i++) cin>>a[i]; build(1,1,N); int i,x,y; while(M--){ cin>>i>>x>>y; if(i==1){ if(x>y) swap(x,y); cout << ask(1, x, y).dat << endl; } else change(1,x,y); // for(int i=1;i<=9;i++){ // cout<<"dat: "<<t[i].dat<<" sum: "<<t[i].sum<<" lmax: "<<t[i].lmax<<" rmax: "<<t[i].rmax<<endl; // } } return 0; }