浅谈可持久化线段树--主席树
权值线段树
权值线段树和普通线段树不一样的地方就是在于 它的结点存储的是区间内数的个数
这个线段树的好处就在于我们可以根据 左子树 和 右子树 的大小从而进行 查找某个数的排名 或者 查找排名为rk的数
可持久化的含义
可持久数据结构主要指的是我们可以查询历史版本的情况并支持插入,利用使用之前历史版本的数据结构来减少对空间的消耗(能够对历史进行修改的是函数式)。
主席树的建树过程:
最开始的时候就是一个空树
然后我们再插入一个元素3
再加入一个元素 1
模版一 :求区间第K大
我们考虑查询。
例如我们插入: 1 5 2 6 3 7 4
要查询[2, 5]中第3大的数我们首先把第1棵线段树和第5棵拿出来。
根据前面说的插入操作 我们最终得到的线段树是这样的:
要查询[2, 5]中第3大的数我们首先把第1棵线段树和第5棵拿出来。
然后我们发现,将对应节点的数相减,刚刚好就是[2, 5]内某个范围内的数的个数。比如[1, 4]这个节点相减是2,就说明[2. 5]内有2个数是在1~4范围内(就是2, 3)。
所以对于一个区间[l, r],我们可以每次算出在[l, mid]范围内的数,如果数量>=k(k就是第k大),就往左子树走,否则就往右子树走。
1 #include <stdio.h> 2 #include <iostream> 3 #include <algorithm> 4 #include <string.h> 5 #include <vector> 6 #include <random> 7 8 const int maxn = 2e5 + 10; 9 10 int ls[maxn<<5],rs[maxn<<5],sum[maxn<<5],rt[maxn<<5]; 11 int cnt; 12 13 void init() { 14 memset(sum,0, sizeof(sum)); 15 cnt = 0; 16 } 17 int build(int l, int r){ 18 int root = ++ cnt; 19 if(l == r) return root; 20 int mid = (l + r) >> 1; 21 ls[root] = build(l, mid); 22 rs[root] = build(mid + 1, r); 23 return root; 24 } 25 int update(int k, int l, int r, int root){ 26 int id = ++ cnt; 27 ls[id] = ls[root]; rs[id] = rs[root]; sum[id] = sum[root] + 1; 28 if(l == r) return id; 29 int mid = (l + r) >> 1; 30 if(k <= mid) ls[id] = update(k, l, mid, ls[id]); 31 if(k > mid) rs[id] = update(k, mid + 1, r, rs[id]); 32 return id; 33 } 34 int query(int k, int u, int v, int l, int r) 35 { 36 int mid = (l + r) >> 1; 37 int x = sum[ls[v]] - sum[ls[u]]; 38 if(l == r) return l; 39 if(k <= x) return query(k, ls[u], ls[v], l, mid); 40 if(k > x) return query(k - x, rs[u], rs[v], mid + 1, r); 41 } 42 43 std::vector<int> v; 44 int getid(int x) { 45 return lower_bound(v.begin(),v.end(),x)-v.begin()+1; 46 } 47 48 int arr[maxn]; 49 int main() { 50 int n,m; 51 scanf("%d%d",&n,&m); 52 for (int i=1;i<=n;i++) { 53 scanf("%d",&arr[i]); 54 v.push_back(arr[i]); 55 } 56 std::sort(v.begin(),v.end()); 57 v.erase(std::unique(v.begin(),v.end()),v.end()); 58 int len = v.size(); 59 rt[0] = build(1,len); 60 for (int i=1;i<=n;i++) { 61 rt[i] = update(getid(arr[i]),1,len,rt[i-1]); 62 } 63 while (m--) { 64 int l,r,k; 65 scanf("%d%d%d",&l,&r,&k); 66 printf("%d\n",v[query(k,rt[l-1],rt[r],1,len)-1]); 67 } 68 }
模版二: 求区间内小于等于 k 的个数
我们只需要加上一个 ask 函数就好了
1 int ask(int k,int u,int v,int l,int r) { 2 if (l == r) { 3 return sum[v] - sum[u]; 4 } 5 int mid = (l + r ) >> 1; 6 if (k <= mid) 7 return ask(k,ls[u],ls[v],l,mid); 8 else { 9 int ret = 0; 10 ret += sum[ls[v]] - sum[ls[u]]; 11 ret += ask(k,rs[u],rs[v],mid+1,r); 12 return ret; 13 } 14 }
模版三: 求区间内不同数的个数
1 #include <stdio.h> 2 #include <iostream> 3 #include <algorithm> 4 #include <string.h> 5 #include <vector> 6 #include <map> 7 #include <random> 8 9 const int maxn = 3e5 + 10; 10 11 int arr[maxn],nxt[maxn]; 12 int ls[maxn<<5],rs[maxn<<5],sum[maxn<<5],rt[maxn<<5]; 13 int cnt; 14 15 std::map<int,int> mp; 16 17 void init() { 18 memset(sum,0, sizeof(sum)); 19 cnt = 0; 20 } 21 int build(int l, int r){ 22 int root = ++ cnt; 23 if(l == r) return root; 24 int mid = (l + r) >> 1; 25 ls[root] = build(l, mid); 26 rs[root] = build(mid + 1, r); 27 return root; 28 } 29 int update(int k, int l, int r, int root){ 30 int id = ++ cnt; 31 ls[id] = ls[root]; rs[id] = rs[root]; sum[id] = sum[root] + 1; 32 if(l == r) return id; 33 int mid = (l + r) >> 1; 34 if(k <= mid) ls[id] = update(k, l, mid, ls[id]); 35 if(k > mid) rs[id] = update(k, mid + 1, r, rs[id]); 36 return id; 37 } 38 int query(int u,int v,int l,int r,int xx,int yy) { 39 if (xx <= l && yy>=r) 40 return sum[v] - sum[u]; 41 int mid = (l + r) >> 1; 42 int res = 0; 43 if (xx <= mid) { 44 res += query(ls[u],ls[v],l,mid,xx,yy); 45 } 46 if (yy > mid) { 47 res += query(rs[u],rs[v],mid+1,r,xx,yy); 48 } 49 return res; 50 } 51 52 53 std::vector<int> v; 54 int getid(int x) { 55 return lower_bound(v.begin(),v.end(),x)-v.begin()+1; 56 } 57 58 int main() { 59 int n; 60 scanf("%d",&n); 61 for (int i=1;i<=n;i++) { 62 scanf("%d",&arr[i]); 63 } 64 mp.clear(); 65 for (int i=n;i>=1;i--) { 66 if (mp[arr[i]] == 0) 67 nxt[i] = n+1; 68 else 69 nxt[i] = mp[arr[i]]; 70 mp[arr[i]] = i; 71 } 72 rt[0] = build(1,n+1); 73 for (int i=1;i<=n;i++) { 74 rt[i] = update(nxt[i],1,n+1,rt[i-1]); 75 } 76 int m,x,y; 77 scanf("%d",&m); 78 while (m--) { 79 scanf("%d%d",&x,&y); 80 printf("%d\n",query(rt[x-1],rt[y],1,n+1,y+1,n+1)); 81 } 82 return 0; 83 }
模版四: 求树上点权第k大
1 #include<cstdio> 2 #include <stdio.h> 3 #include <iostream> 4 #include <algorithm> 5 #include <string> 6 #include <string.h> 7 #include <vector> 8 #include <map> 9 #include <random> 10 11 const int maxt = 2e7 + 5; 12 const int maxn = 1e5 + 5; 13 14 int n, _n, m, a[maxn], b[maxn]; 15 16 struct Edge { 17 int net,to; 18 }e[maxn<<1]; 19 20 int head[maxn],ecnt = -1; 21 22 void add_edge(int x,int y) { 23 e[++ecnt] = (Edge){head[x],y}; 24 head[x] = ecnt; 25 } 26 27 int ls[maxt],rs[maxt],sum[maxt],rt[maxt]; 28 int cnt; 29 30 void init() { 31 memset(sum,0, sizeof(sum)); 32 cnt = 0; 33 } 34 void update(int old,int &now,int l,int r,int id) { 35 now = ++cnt; 36 ls[now] = ls[old]; rs[now] = rs[old]; sum[now] = sum[old] + 1; 37 if (l == r) 38 return; 39 int mid = (l + r) >> 1; 40 if (id <= mid) 41 update(ls[old],ls[now],l,mid,id); 42 else 43 update(rs[old],rs[now],mid+1,r,id); 44 } 45 int query(int f,int z,int x,int y,int l,int r,int k) { 46 if (l == r) 47 return l; 48 int mid = (l + r ) >> 1,Sum = sum[ls[x]] + sum[ls[y]] - sum[ls[z]] - sum[ls[f]]; 49 if (k <= Sum) { 50 return query(ls[f],ls[z],ls[x],ls[y],l,mid,k); 51 } 52 else 53 return query(rs[f],rs[z],rs[x],rs[y],mid+1,r,k-Sum); 54 } 55 56 int dep[maxn],fa[21][maxn]; 57 void dfs(int now,int _f) { 58 update(rt[_f],rt[now],1,_n,a[now]); 59 for(int i = 1; (1 << i) <= dep[now]; ++i) 60 fa[i][now] = fa[i - 1][fa[i - 1][now]]; 61 for(int i = head[now], v; i != -1; i = e[i].net) 62 { 63 if((v = e[i].to) == _f) continue; 64 dep[v] = dep[now] + 1; 65 fa[0][v] = now; 66 dfs(v, now); 67 } 68 } 69 70 int lca(int x, int y) 71 { 72 if(dep[x] < dep[y]) 73 std::swap(x, y); 74 for(int i = 20; i >= 0; --i) 75 if(dep[x] - (1 << i) >= dep[y]) x = fa[i][x]; 76 if(x == y) return x; 77 for(int i = 20; i >= 0; --i) 78 if(fa[i][x] != fa[i][y]) x = fa[i][x], y = fa[i][y]; 79 return fa[0][x]; 80 } 81 82 int main() 83 { 84 init(); 85 memset(head, -1, sizeof(head)); 86 scanf("%d%d",&n,&m); 87 for(int i = 1; i <= n; ++i) { 88 scanf("%d", &a[i]); 89 b[i] = a[i]; 90 } 91 std::sort(b + 1, b + n + 1); 92 _n = (std::unique(b + 1, b + n + 1) - b - 1); 93 for(int i = 1; i <= n; ++i) 94 a[i] = std::lower_bound(b + 1, b + _n + 1, a[i]) - b; 95 for(int i = 1; i < n; ++i) 96 { 97 int x,y; 98 scanf("%d%d",&x,&y); 99 add_edge(x, y); add_edge(y, x); 100 } 101 dfs(1, 0); 102 for(int i = 1; i <= m; ++i) 103 { 104 int x,y,k; 105 scanf("%d%d%d",&x,&y,&k); 106 int z = lca(x, y); 107 printf("%d\n",b[query(rt[fa[0][z]], rt[z], rt[x], rt[y], 1, _n, k)]); 108 } 109 return 0; 110 }