在找题解的时候发现了一篇很不错的树状数组详解,放在这里与大家共享 文章地址

树状数组:

树状数组是一个查询和修改复杂度都为log(n)的数据结构,假设数组a[1..n],

 

用lowbit函数维护了一个树的结构

那么查询a[1]+...+a[n]的时间是log级别的,而且是一个在线的数据结构,

  支持随时修改某个元素的值,复杂度也为log级别。

  来观察这个图:

  令这棵树的结点编号为C1,C2...Cn。令每个结点的值为这棵树的值的总和,那么容易发现:

  C1 = A1

  C2 = A1 + A2

  C3 = A3

  C4 = A1 + A2 + A3 + A4

  C5 = A5

  C6 = A5 + A6

  C7 = A7

  C8 = A1 + A2 + A3 + A4 + A5 + A6 + A7 + A8

  ...

  C16 = A1 + A2 + A3 + A4 + A5 + A6 + A7 + A8 + A9 + A10 + A11 + A12 + A13 + A14 + A15 + A16

  这里有一个有趣的性质:

  设节点编号为x,那么这个节点管辖的区间为2^k(其中k为x二进制末尾0的个数)个元素。因为这个区间最后一个元素必然为Ax,

  所以很明显:Cn = A(n – 2^k + 1) + ... + An

  算这个2^k有一个快捷的办法,定义一个函数如下即可:

  int lowbit(int x){

  return x&(x^(x–1));

  }

  当想要查询一个SUM(n)(求a[n]的和),可以依据如下算法即可:

  step1: 令sum = 0,转第二步;

  step2: 假如n <= 0,算法结束,返回sum值,否则sum = sum + Cn,转第三步;

  step3: 令n = n – lowbit(n),转第二步。

  可以看出,这个算法就是将这一个个区间的和全部加起来,为什么是效率是log(n)的呢?以下给出证明:

  n = n – lowbit(n)这一步实际上等价于将n的二进制的最后一个1减去。而n的二进制里最多有log(n)个1,所以查询效率是log(n)的。

  那么修改呢,修改一个节点,必须修改其所有祖先,最坏情况下为修改第一个元素,最多有log(n)的祖先。

  所以修改算法如下(给某个结点i加上x):

  step1: 当i > n时,算法结束,否则转第二步;

  step2: Ci = Ci + x, i = i + lowbit(i)转第一步。

  i = i +lowbit(i)这个过程实际上也只是一个把末尾1补为0的过程。

  对于数组求和来说树状数组简直太快了!

 

代码:

#include <iostream>  
#include<cstring>
#include <algorithm>  
using namespace std;  
int n, b[100005], val[100005];  
struct ac  
{  
    int x, y, id;  
}a[100005];  
bool cmp(ac a,ac b)   //对数据进行排序  
{  
    if(a.y != b.y) 
		return a.y > b.y;    //先由大到小排y  
    return a.x < b.x;   //y相同时,x由小到大排  
}   
int lowbit(int i)  
{  
    return i&(-i);  
}   
void update(int i, int x)  
{  
    while(i <= 100005)  
    {  
        b[i] += x;  
        i += lowbit(i);  
    }  
}   
int sum(int i)  
{  
    int sum = 0;  
    while(i > 0)  
    {  
        sum += b[i];  
        i -= lowbit(i);  
    }  
    return sum;  
}  
  
int main()  
{  
    int i;  
    while(cin>>n,n)  
    {  
        memset(b, 0, sizeof(b));  
        memset(val, 0, sizeof(val));  
        for(i = 0; i < n; i++)  
        {  
            scanf("%d %d", &a[i].x, &a[i].y);  
            a[i].id = i;  
            a[i].x++; 				//x与y有可能为0,所以都++
			a[i].y++;   
        }  
        sort(a, a+n, cmp);  
        val[a[0].id] = sum(a[0].x); //val[i]表示在i点有多少大于i的范围的点 
        update(a[0].x, 1);
        for(i = 1; i < n; i++)  
        {  
            if(a[i].x == a[i-1].x && a[i].y == a[i-1].y)  //两区间相等时,该点的val等于上一个点 
                val[a[i].id] = val[a[i-1].id];       
            else val[a[i].id] = sum(a[i].x);  
            update(a[i].x, 1);       //更新该点x值  
        }  
        cout<<val[i]; 
        for(i = 1; i < n; i++)  
        {  
            cout<<" "<<val[i];
        }  
        cout<<endl; 
    } 
    return 0;  
}