Can you answer these queries III

  • 区间查询,查询区间 [x,y] 中的最大连续子段和
  • 单点修改,A[x] 改成 y
  • 维护三个属性lmax前缀最大和,rmax后缀最大和,tmax最大连续子段和,sum整段和
      u.sum = l.sum + r.sum;//父节点整段和 = 左儿子整段和+右儿子整段和
      u.lmax = max(l.lmax,l.sum + r.lmax);//父节点前缀最大值 = max(左儿子前缀最大值,左儿子整段和+右儿子前缀最大值)
      u.rmax = max(r.rmax,r.sum + l.rmax);//与上面同理
      u.tmax = max(max(l.tmax,r.tmax),l.rmax + r.lmax);//最大连续子段和=max(左儿子连续子段和最大值,右儿子连续子段和最大值,左儿子后缀+右儿子前缀)

代码如下:

#include<bits/stdc++.h>

using namespace std;

#define  mm(a,x) memset(a,x,sizeof a)
#define  mk make_pair
#define ll long long
#define pii pair<int,int>
#define inf 0x3f3f3f3f
#define lowbit(x) (x) & (-x)

const int N = 5e5 + 10;

int n,m;
int a[N];
struct Node{
    int l,r;
    int sum,lmax,rmax,tmax;
}tr[N << 2];

void pushup(Node &u,Node &l,Node &r){
    u.sum = l.sum + r.sum;
    u.lmax = max(l.lmax,l.sum + r.lmax);
    u.rmax = max(r.rmax,r.sum + l.rmax);
    u.tmax = max(max(l.tmax,r.tmax),l.rmax + r.lmax);
}

void pushup(int u){
    pushup(tr[u],tr[u << 1],tr[u << 1| 1]);
}

void build(int u,int l,int r){
    if(l == r) tr[u] = {l,r,a[l],a[l],a[l],a[l]};
    else{
        tr[u] = {l,r};
        int mid = (l + r) >> 1;
        build(u << 1,l,mid);build(u << 1 | 1,mid + 1,r);
        pushup(u);        
    }
}

void modify(int u,int x,int v){
    if(tr[u].l == x && tr[u].r == x) tr[u] = {x,x,v,v,v,v};
    else{
        int mid = (tr[u].l + tr[u].r) >> 1;
        if(x <= mid) modify(u << 1,x,v);
        else modify(u << 1 | 1,x,v);
        pushup(u); 
    }
}

Node query(int u,int l,int r){
    if(tr[u].l >= l && tr[u].r <= r) return tr[u];
    else{
        int mid = (tr[u].l + tr[u].r) >> 1;
        if(r <= mid) return query(u << 1,l,r); //直接返回
        else if(l > mid) return query(u << 1 | 1,l,r); //直接返回
        else{ //该计算父节点的最大连续子段和
            auto left = query(u << 1,l,r);
            auto right = query(u << 1 | 1,l,r);
            Node res;
            pushup(res,left,right);
            return res;
        }
    }
}

int main() {
    scanf("%d%d",&n,&m);
    for(int i = 1; i <= n; i ++ ) scanf("%d",&a[i]);
    build(1,1,n);
    int k,x,y;
    while(m -- ){
        scanf("%d %d %d",&k,&x,&y);
        if(k == 1){
            if(x > y) swap(x,y);
            printf("%d\n",query(1,x,y).tmax);
        }else{
            modify(1,x,y);
        }
    }
    return 0;
}