ZKW线段树

应某迪要求,写一篇数据结构学习笔记。

实际上还没有学很多东西,只是一些基础的操作。

zkw线段树的学习资料,网上有很多,这里记录的只是自己的一些理解。

 

 

建树

1 inline void build(){
2     for(bit=1,n=read();bit<=n+1;bit<<=1);
3     for(int i=bit+1;i<=bit+n;++i) sum[i]=read();
4     for(int i=bit-1;i;--i) sum[i]=sum[i<<1]+sum[i<<1|1];
5 }

$zkw$线段树构造了一棵完美二叉树,只有最后一层叶子节点管辖的区间大小为1。

$zkw$线段树是基于位运算的,对于节点$p$,$p<<1$为它的左儿子,$p<<1|1$为它的右儿子。

因为是一棵完美二叉树,除掉叶子节点的部分一定为$2^k-1$的形式,将这个$2^k$记为$bit$,可以方便我们之后的操作。

其意义是,对于原序列的点$i$,可以直接得到对应线段树上的节点$i+bit$。

注意这里我们忽略了$bit$也就是$2^k$这一个节点,以后再提。

同时建树的一个细节是$bit$应当大于$n+1$,其原因也可以留到后面。

 

 

单点修改

1 inline void modify(int p,int val){
2     for(p+=bit;p;p>>=1) sum[p]+=val;
3 }

找到位置之后,直接修改一条祖先链。

 

 

区间修改

1 inline void modify(int l,int r,int val){
2     int lc=0,rc=0,len=1;
3     for(l+=bit-1,r+=bit+1;l^r^1;l>>=1,r>>=1,len<<=1){
4         sum[l]+=lc*val; sum[r]+=rc*val;
5         if(~l&1) sum[l^1]+=len*val,add[l^1]+=val,lc+=len;
6         if(r&1) sum[r^1]+=len*val,add[r^1]+=val,rc+=len;
7     }
8     for(;l;l>>=1,r>>=1) sum[l]+=lc*val,sum[r]+=rc*val;
9 }

$lc$:当前左指针包含的区间长度。$rc$:当前右指针包含的区间长度。$len$:当前翻到的节点层管辖区间的长度。

这里我们将$l$,$r$都作为开区间。所以分别加$bit-1$,$bit+1$处理。

因为操作是自下而上进行的,$zkw$线段树一般不维护懒标记,

因而我们用一个数组$add$进行标记永久化,表示这个区间的所有序列应该被加上这个值,显然这个值是不能下传的。

当$l$的最后一位为0,也就是说$l$指针为左儿子,那么l的右兄弟在当前修改的区间内。

同理$r$的左兄弟会在修改区间内。

当$l$,$r$两个指针已经成为兄弟,也就是说二者在二进制下只有最后一位不同,即异或值为1,那么全部的修改操作已经完成,可以结束。

然而祖先链上的$sum$值仍然需要修改。

这里可以解释,为什么$bit$应该大于$n+1$而不是$n$,为什么$bit+0$这个节点需要被空出来,因为我们需要开区间来进行操作。

然而似乎使$bit$仅保证大于$n$的打法是正确的,手玩确实没有错误。

 

 

单点查询

1 inline int query(int p){
2     int ans=0;
3     for(p+=bit,ans=sum[p],p>>=1;p;p>>=1) ans+=add[p];
4     return ans;
5 }

统计叶子节点的$sum$值,并不断加上祖先链的$add$标记即可。

应当注意的是不要加上叶子节点的$add$标记,这个标记是无意义的。

 

 

区间查询

 1 inline int query(int l,int r){
 2     int ans=0,lc=0,rc=0,len=1;
 3     for(l+=bit-1,r+=bit+1;l^r^1;l>>=1,r>>=1,len<<=1){
 4         ans+=add[l]*lc+add[r]*rc;
 5         if(~l&1) ans+=sum[l^1],lc+=len;
 6         if(r&1) ans+=sum[r^1],rc+=len;
 7     }
 8     for(;l;l>>=1,r>>=1) ans+=add[l]*lc+add[r]*rc;
 9     return ans;
10 }

区间查询的打法是类似于区间修改的。

首先将$l$,$r$设为开区间。

不断翻祖先链,记得加上兄弟节点整体的$sum$值和祖先链上部分的$add$标记就可以了。

应当注意的是循环中统计$add$标记和兄弟$sum$值的顺序不可交换,否则可能导致$lc$,$rc$变量维护的含义错误。

 

 

区间最值

思想大概是将儿子的最值不断差分到父亲身上。

因为不同的部分已经被差分掉,可以直接修改区间的最值。

应当注意的是,区间最值的求法与区间求和不同。

为了减少特判,原本的开区间被转化为闭区间。

但这样产生一个问题,如果查询区间长度为1会导致一些问题:左右端点永远不会成为兄弟,故导致了死循环。

所以要加一个单点查询的特判。

因为要维护区间最值,修改操作同时也要不断差分,于是打的麻烦了许多,代码可以参考下面。

 

 

基础操作

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 const int N=1e6+7;
 4 inline int read(register int x=0,register char ch=getchar(),register int f=0){
 5     while(!isdigit(ch)) f=ch=='-',ch=getchar();
 6     while(isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
 7     return f?-x:x;
 8 }
 9 int n,m,bit;
10 int sum[N<<2],add[N<<2],mn[N<<2],mx[N<<2];
11 inline void build(){
12     for(bit=1;bit<=n+1;bit<<=1);
13     for(int i=bit+1;i<=bit+n;++i) mx[i]=mn[i]=sum[i]=read();
14     for(int i=bit-1;i;--i){
15         sum[i]=sum[i<<1]+sum[i<<1|1];
16         mn[i]=min(mn[i<<1],mn[i<<1|1]); mn[i<<1]-=mn[i]; mn[i<<1|1]-=mn[i];
17         mx[i]=max(mx[i<<1],mx[i<<1|1]); mx[i<<1]-=mx[i]; mx[i<<1|1]-=mx[i];
18     }
19 }
20 inline int query(int p){
21     int ans=0;
22     for(p+=bit,ans=sum[p],p>>=1;p;p>>=1) ans+=add[p];
23     return ans;
24 }
25 inline int query(int l,int r){
26     int ans=0,lc=0,rc=0,len=1;
27     for(l+=bit-1,r+=bit+1;l^r^1;l>>=1,r>>=1,len<<=1){
28         ans+=add[l]*lc+add[r]*rc;
29         if(~l&1) ans+=sum[l^1],lc+=len;
30         if(r&1) ans+=sum[r^1],rc+=len;
31     }
32     for(;l;l>>=1,r>>=1) ans+=add[l]*lc+add[r]*rc;
33     return ans;
34 }
35 inline int query_min(int l,int r){
36     if(l==r) return query(l);
37     int lans=0,rans=0;
38     for(l+=bit,r+=bit;l^r^1;l>>=1,r>>=1){
39         lans+=mn[l]; rans+=mn[r];
40         if(~l&1) lans=min(lans,mn[l^1]);
41         if(r&1) rans=min(rans,mn[r^1]);
42     }
43     for(lans=min(lans+mn[l],rans+mn[r]),l>>=1;l;l>>=1) lans+=mn[l];
44     return lans;
45 }
46 inline int query_max(int l,int r){
47     if(l==r) return query(l);
48     int lans=0,rans=0;
49     for(l+=bit,r+=bit;l^r^1;l>>=1,r>>=1){
50         lans+=mx[l]; rans+=mx[r];
51         if(~l&1) lans=max(lans,mx[l^1]);
52         if(r&1) rans=max(rans,mx[r^1]);
53     }
54     for(lans=max(lans+mx[l],rans+mx[r]),l>>=1;l;l>>=1) lans+=mx[l];
55     return lans;
56 }
57 inline void modify(int l,int r,int val){
58     int lc=0,rc=0,len=1,x;
59     for(l+=bit-1,r+=bit+1;l^r^1;l>>=1,r>>=1,len<<=1){
60         sum[l]+=lc*val; sum[r]+=rc*val;
61         if(~l&1) sum[l^1]+=len*val,add[l^1]+=val,mn[l^1]+=val,lc+=len;
62         if(r&1) sum[r^1]+=len*val,add[r^1]+=val,mn[r^1]+=val,rc+=len;
63         x=min(mn[l],mn[l^1]); mn[l]-=x; mn[l^1]-=x; mn[l>>1]+=x;
64         x=min(mn[r],mn[r^1]); mn[r]-=x; mn[r^1]-=x; mn[r>>1]+=x;
65     }
66     for(;l;l>>=1,r>>=1){
67         sum[l]+=lc*val; sum[r]+=rc*val;
68         x=min(mn[l],mn[l^1]); mn[l]-=x; mn[l^1]-=x; mn[l>>1]+=x;
69         x=max(mx[l],mx[l^1]); mx[l]-=x; mx[l^1]-=x; mx[l>>1]+=x;
70     }
71 }
72 inline void modify(int p,int val){
73     int x;
74     for(p+=bit;p;p>>=1){
75         sum[p]+=val; mn[p]+=val; mx[p]+=val;
76         x=min(mn[p],mn[p^1]); mn[p]-=x; mn[p^1]-=x; mn[p>>1]+=x;
77         x=max(mx[p],mx[p^1]); mx[p]-=x; mx[p^1]-=x; mx[p>>1]+=x;
78     }
79 }
80 int main(){
81     n=read();
82     build();
83     return 0;
84 }

 

 

区间信息合并(山海经)

思想大概与区间查询最值一致。

为了减少特判将开区间转化为闭区间。

需要注意的是信息合并有左右的先后顺序。

所以左右指针的写法并不相同,最后将$l$,$r$扫过的信息合并就可以了。

 1 #include<bits/stdc++.h>
 2 const int N=100010;
 3 inline int read(register int x=0,register char ch=getchar(),bool f=0){
 4     while(!isdigit(ch)) f=ch=='-',ch=getchar();
 5     while(isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
 6     return f?-x:x;
 7 }
 8 struct Ans{
 9     int l,r,val;
10 };
11 struct Node{
12     Ans la,ra,mx,tot;
13 }s[N<<2];
14 int n,m,bit;
15 inline bool operator <(const Ans &a,const Ans &b){
16     return a.val<b.val||(a.val==b.val&&a.l>b.l)||(a.val==b.val&&a.l==b.l&&a.r>b.r);
17 }
18 inline Ans operator +(const Ans &a,const Ans &b){
19     return (Ans){a.l,b.r,a.val+b.val};
20 }
21 inline Node operator +(const Node &a,const Node &b){
22     return (Node){std::max(a.la,a.tot+b.la),std::max(b.ra,a.ra+b.tot),std::max(std::max(a.mx,b.mx),a.ra+b.la),a.tot+b.tot};
23 }
24 void build(){
25     for(bit=1;bit<=n+1;bit<<=1);
26     for(int i=bit+1;i<=bit+n;++i) s[i].tot=s[i].mx=s[i].la=s[i].ra=(Ans){i-bit,i-bit,read()};
27     for(int i=bit-1;i;--i) s[i]=s[i<<1]+s[i<<1|1];
28 }
29 Ans query(int r,int l){
30     if(l==r) return s[l+bit].mx;
31     Node L=s[l+bit],R=s[r+bit];
32     for(l+=bit,r+=bit;l^r^1;l>>=1,r>>=1){
33         if(~l&1) L=L+s[l^1];
34         if(r&1) R=s[r^1]+R;
35     }
36     return (L+R).mx;
37 }
38 void print(const Ans &x){
39     printf("%d %d %d\n",x.l,x.r,x.val);
40 }
41 int main(){
42     n=read(); m=read(); build();
43     for(int i=1;i<=m;++i) print(query(read(),read()));
44     return 0;
45 }