主席树又叫可持久化线段树

  为了实现可持久化这一目的,主席树在建树或者更新的时候会建立多个历史版本,以便于在之后的查询可以随时回到某个历史版本
  建立历史版本的方法有多种,但是为了实现空间和时间上的最优,我们采用下面的方法建立历史版本
  假定我们现在有一个问题,给你一个长度n为的序列data,和一个序列aux。每一时刻按从左往右的顺序从data数组中拿一个数加入aux数组。求任意一个时刻序列aux中[a,b]范围内偶数的个数
  很容易想到用线段树维护区间[a,b]的偶数的个数。题目上又说可是任意一个时刻,所以我们要实现回到某个历史版本的功能。
  为了方便描述主席树的步骤,我们先对本篇博客所说的问题进行具体化;
   我们规定n为12,data里面的数分别为{3,8,13,15,5,8,12,9,3,6,21,22};有6个查询,每一行有三个数,分别代表时刻,左端点,右端点。

1,4,6
4,7,9
6,7,7
3,1,2
9,1,7
4,1,9

这是一个经典的主席树问题,一般的主席树(本片博客所说的主席树是指静态主席树)解题步骤分为三步:
  1.离散化:

首先我们将区间离散化,离散化的目的为了节省空间,也为后面的步骤优化了代码。那么data{3,8,13,15,5,8,12,9,3,6,21,22} -> hash{1,4,7,8,2,4,6,5,1,3,9,10}

  2.建树并建立历史版本:

刚开的时候是一颗空树,如图

第1时刻的时候,把3离散化后对应的值1加入线段树,因为3是奇数,所以更新的节点的值 + 0,如图

第2时刻的时候,把8离散化后对应的值4加入线段树,因为8是偶数,所以更新的节点的值 + 1,如图

第3时刻的时候,把13离散化后对应的值7加入线段树,因为13是奇数,所以更新的节点的值 + 0, 如图:

可以看出每次更新的时间复杂度与空间复杂度都是log(n)
这个结论跟容易推出来,因为每次向下更新的时候,只有向左更新或者向右更新。假如我们要向右边更新,我们就往右边建一个
新节点,并使当前节点的右指针指向这个新节点,左指针还指向本节点原来的左儿子

  3.查询

跟普通的线段树,这里就不再赘述

接下来我们来看一个例题,也是一个经典的主席树问题,求区间第k小

在讲解之前,我们思考一个问题,就是普通的线段树如何解决区间第k小的问题
例如区间[1,5] = 3。代表在区间1~5之间,出现了3个范围在[1,5]的数。
那么我们该如何查找呢。
例如我们要查找长度为len且最小值为x,最大值为y的数组中区间下标[a,b]之间第k小的数,
那么我们能求出在[a,b]中范围在[x,(x + y) / 2]的数有cnt个,我们比较cnt与k的关系,来判断下一个查找区间在左边还是右边如果cnt小于等于k,那么说明最终答案一定在[x,(x+y)/2]里面,如果cnt比k大,说明答案一定在[(x+y) / 2 + 1, y]里面。然后往下查找

整个查找过程用了二分的思想
另外,这个题其实查找的过程是比较难理解的,其他的都很好理解。为了能得到区间[a,b]的情况,我们需要利用前缀和的思想,也就是区间[a,b]的情况可以用[l,b] - [l,a-1]得到。

代码如下:

#include<stdio.h>
#include<string.h>
#include<algorithm>
#define ll long long 
#define mmset(a,b) memset(a,b,sizeof(a))
#define hash hash1
using namespace std;
const int N = 1e5 + 5;
int ls[N * 30],rs[N * 30], ts[N * 30], root[N];
int tree[N * 30];
int data[N], hash[N];
int n,m,tot;
void build(int l,int r, int& rt)
{
	rt = tot++;
	tree[rt] = 0;
	if(l == r)
	{
		return;
	}
	else
	{
		int m = (l + r) >> 1;
		build(l,m,ls[rt]);
		build(m+1,r,rs[rt]);
	}
	
}
void add(int p, int C, int l, int r, int& rt, int last)
{
	rt = tot++;
	ls[rt] = ls[last];
	rs[rt] = rs[last];
	tree[rt] = tree[last] + C;
	if(l == r)
	{
		return;
	}
	else
	{
		int m = (l + r) >> 1;
		if(p <= m)
		{
			add(p,C,l,m,ls[rt],ls[rt]);
		}
		if(p >= m + 1)
		{
			add(p,C,m + 1, r, rs[rt], rs[rt]);
		}
	}
	
}
int query(int x, int y, int l, int r, int k)
{
	if(l == r)
	{
		return l;
	}
	else
	{
		int m = (l + r) >> 1;
		int cnt = tree[ls[y]] - tree[ls[x]];
		if(k <= cnt)
		{
			return query(ls[x],ls[y],l,m,k);
		}
		else
		{
			return query(rs[x],rs[y],m+1,r,k - cnt);
		}
	}
	

}

/*
7 3
1 5 2 6 3 7 4
2 5 3
4 4 1
1 7 3
Sample Output
5
6
3
 */
int main()
{
	tot = 1;
	scanf("%d %d",&n,&m);
	for(int i = 1; i <= n; i++)
	{
		scanf("%d",&data[i]);
		hash[i] = data[i];
	}
	sort(hash + 1, hash + 1 + n);
	int len = unique(hash + 1, hash + 1 + n) -  (hash + 1);
	build(1,len,root[0]);
	for(int i = 1; i <= n; i++)
	{
		int a = lower_bound(hash + 1, hash + 1 + n, data[i]) - hash;
		add(a,1,1,len,root[i],root[i-1]);
	}
	for(int i = 1; i <= m; i++)
	{
		int a,b,k;
		scanf("%d %d %d",&a,&b,&k);
		int res = hash[query(root[a-1],root[b],1,len,k)];
		printf("%d\n",res);
	}
	return 0;
}