树状数组最经典的应用
树状数组最经典的应用
更简单一点的理解请参考《算法笔记》-树状数组(BIT) 相关章节
-
树状数组的模板
int lowbit(int x){ return x&-x; } int getSum(x){ int sum = 0; for(int i = x ; i > 0 ; i-=lowbit(i)){ sum += C[i]; } return sum } // 将第x个数加上v void update(int x,int v){ for(int i = x ; i < maxn ; i+=lowbit(i)){ C[i] += v; } }
-
问题描述
给定一个有N个正整数的序列A(N <= 10^5,A[i] <= 10^5),对序列中的每个数,求出序列中它左边比它小的数的个数。例如对序列{2,5,1,3,4},A[1]等于2,在A[1]左边比A[1]小的数有0个;A[2] = 5,在A[2]左边比A[2]小的数有1个,即2;
-
代码如下:
#include <cstdio> #include <cstring> const int maxn = 10010; #define lowbit(i) ((i)&(-i)) // lowbit写成宏定义的形式,注意括号 int C[maxn]; // 树状数组 void update(int x,int v){ for(int i = x ; i < maxn ; i+=lowbit(i)){ C[i] += v; } } int getSum(int x){ int sum = 0; for(int i = x ; i > 0 ; i -= lowbit(i)){ sum += C[i]; } return sum; } int main(){ int n,x; scanf("%d",&n); memset(C,0,sizeof(C)); // 树状数组初始值为0 for(int i = 0 ; i < n ; i++){ scanf("%d",&x); update(x,1); // x出现的次数加1 printf("%d\n",getSum(x-1)); // 查询当前小于x的个数 } return 0; }
-
运行结果
-
离散化代码
#include <cstdio> #include <cstring> #include <algorithm> using namespace std; const int maxn = 100010; #define lowbit ((x)&(-x)) struct Node{ int val; // 序列元素的值 int pos; // 原始序号 }temp[maxn]; // temp数组临时存放输入数据 int A[maxn]; // 离散化后的原始数组 int C[maxn]; // 树状数组 void update(int x,int v){ for(int i = x ; i < maxn ; i += lowbit(i)){ C[i] += v; } } int getSum(int x){ int sum = 0; for(int i = x ; i > 0 ; i -= lowbit(i)){ sum += C[i]; } return sum; } bool cmp(Node a,Node b){ return a.val < b.val; } int main(){ int n; scanf("%d",&n); memset(C,0,sizeof(C)); for(int i = 0 ; i < n; i++){ scanf("%d",&temp[i].val); temp[i].pos = i; } // 离散化 sort(temp,temp+n,cmp); for(int i = 0 ; i < n ; i++){ if(i == 0 || temp[i].val != temp[i-1].val){ A[temp[i].pos] = i+1; // 注意:这里必须从1开始 }else{ A[temp[i].pos] = A[temp[i-1].pos]; } } // 正式进入更新求和操作 for(int i = 0 ; i < n ; i++){ update(A[i],1); // A[i]出现的次数加1 printf("%d\n",getSum(A[i]-1)); } return 0; }
-
二维树状数组的代码实现
int C[maxn][maxn]; void update(int x,int y,int v){ for(int i = x ; i < maxn ; i += lowbit(i)){ for(int j = y ; j < maxn ; j += lowbit(j)){ C[i][j] += v; } } } int getSum(int x,int y){ int sum = 0; for(int i = x ; i > 0 ; i -= lowbit(i)){ for(int j = y ; j > 0 ; j -= lowbit(j)){ sum += C[i][j]; } } return sum; }
-
区间更新和单点查询
-
设计函数getSum(x),返回A[x]
-
设计函数update(x,v),将A[1]~A[x]的每个数都加上一个数v。
-
代码如下:
int getSum(int x){ int sum = 0; for(int i = x ; i < maxn ; i += lowbit(i)){ sum += c[i]; } return sum; } void update(int x,int v){ for(int i = x ; i > 0 ; i -= lowbit(i)){ c[i] += v; } }
-