「LeetCode」树状数组 - Binary Indexed Tree

2020-05-04
摘要: 本文介绍了树状数组及其应用。

树状数组 Binary Index Tree#

树状数组(Binary Index Tree - BIT)被用于解决动态的前缀和问题。也就是说我们完成两个主要操作:

  • 查询 $\sum_{k=1}^{n}num[k]$,$O(\log(n))$
  • 修改 $num[i]$,$O(\log(n))$

为了完成上述操作,我们要去维护我们的树状数组(而不是维护原数组,当然也可以同时维护),以保证它的性质。

原理#

定义#

我们对于原数组 $num$,根据如下规则构建一个数组 $A$。$A[i]$ 的值为 $num$ 数组中,从下标 $i$ 开始往前 $k$ 个数的和。其中,$k$ 被解释为 $i$ 的二进制表示中最后一个 $1$ 所在位的权。例如 $i=6$,其二进制表示为 $110$,则最后一个 $1$ 的权为 $2^1=2$,于是 $A[6]$ 的值为 $num[5] + num[6]$。(另一种说法是 $k=2^t$,$t$ 为 $i$ 二进制表示最后 $0$ 的个数。不难证明,它们是等价的)

查询#

对于一个要查询的下标,如 $1001011$,首先这个数的大小就是各个 $1$ 的权重之和。我们再做如下操作:每次抹去最后一个 $1$,并将其替换为 $0$,得到下一个数。于是针对前面提到的数,我们一共可以得到 $4$ 个数:$1001011,1001010,1001000,1000000$。根据上面提到的性质,这四个数中包含的个数取决于最后一位 $1$ 的权重,所以这四个数恰好组成了 $1001011$ 个数,而且它们是连续的,因此取这些下标处的数之和就是前 $1001011$ 个数的前缀和。

简单来说,我们只要每次替换最后一个 $1$ 为 $0$,再取数组中对应的值,相加即可。我们以 $1011(11)$ 为例,求 $num[1]+num[2]+…+num[11]$ 只要计算 $A[11]+A[10]+A[8]$ 即可:

$1011\ A[11]=num[11]$

$1010\ A[10]=num[10]+num[9]$

$1000\ A[8]=num[1]+num[2]+…+num[8]$

修改#

修改的操作与查询的操作是相反的。查询是在最后一位 $1$ 处减 $1$,而修改是在最后一位 $1$ 处加 $1$。

以修改 $011(3)$ 为例:

修改 $011\ A[3]$

修改 $100\ A[4]$

修改 $1000\ A[8]$

…​

直到超出 $n$ 为止。

构建#

我们利用修改的性质来构造一个树状数组。对一个数来说,在最后一位 $1$ 上加 $1$ 便是它的父节点。观察示意图,利用这个性质,只要 $O(n)$ 时间即可构建它。具体参考实现中的 init 函数。

实现#

注意,如果我们可以破坏原数组,其实可以使用 destructive 版本的 destructiveInit函数,它只需要接受原数组即可。

vector<long long> num(n+1, 0), A(n+1, 0);

// 返回末尾的 1
// x:        ....1000
// ~x:       ----0111
// ~x+1:     ----1000
// x&(~x+1): 00001000
int lowbit(int x) { return x & -x; }

// 构建,num 为原数组,A 为树状数组
void init(vector<long long>& num, vector<long long>& A) {
    int n = (int)num.size() - 1;
    for (int i = 1; i <= n; ++i) {
        A[i] += num[i];

        int j = i + lowbit(i);
        if (j <= n) {
            A[j] += A[i];
        }
    }
}

// 直接改造原数组
void destructiveInit(vector<long long>& A) {
    int n = (int)A.size() - 1;
    for (int i = 1; i <= n; ++i) {
        int j = i + lowbit(i);
        if(j <= n) {
            A[j] += A[i];
        }
    }
}

long long getSum(vector<long long>& A, int index) {
    long long ans = 0;
    while (index) {
        ans += A[index];
        index -= lowbit(index);
    }

    return ans;
}

void update(vector<long long>& A, int index, int value) {
    int n = (int)A.size() - 1;
    while (index <= n) {
        A[index] += value;
        index += lowbit(index);
    }
}

其他性质#

树状数组不仅可以用于动态查询前缀和,还有许多其他应用。

单点修改/区间查询#

区间 $[i,j]$ 的和就是 $\mathrm{sum}(j) - \mathrm{sum}(i-1) $,所以前缀和可以快速地计算出区间和。

区间修改/单点查询#

因为一个原数组 $A$ 的差分数组 $B$ ($B[i]=A[i]-A[i-1]$) 的前缀和(因此要维护的树状数组是从差分数组 $B$ 得到的)即为 $A[i]$。利用该性质,我们可以做到单点查询。而又有如下性质:

如果我们要对区间 $[i,j]$ 中的每个数都增加 $k$,只需要另 $B[i]=B[i]+k,B[j+1]=B[j+1]-k$ 即可。为什么?因为我们做单点查询的时候计算的是 $B$ 的前缀和,此时 $[1,i-1]$ 的单点查询结果未受到影响,$[i,j]$ 单点查询结果都增加 $k$,$[j+1,…,n]$ 单点查询结果未受到影响。(实际上我们维护的是 $B$ 产生的树状数组)

所以树状数组也可以快速完成区间修改的操作。

区间修改/区间查询#

在上述基础上,我们如果要求区间和:

$$\begin{array}{c}\sum_{i=1}^{r} A_{i} \\ =\sum_{i=1}^{r} \sum_{j=1}^{i} B_{j} \\ =\sum_{i=1}^{r} B_{i} \times(r-i+1) \\ =\sum_{i=1}^{r} B_{i} \times(r+1)-\sum_{i=1}^{r} B_{i} \times i\end{array}$$

于是我们只需要两个树状数组来维护 $\sum B_{i}, \sum B_{i} \times i$ 即可。

// 初始化需要两个数组
void init(vector<long long>& num, vector<long long>& A1, vector<long long>& A2) {
    int n = (int)num.size() - 1;
    for (int i = 1; i <= n; ++i) {
        A1[i] += num[i];
        A2[i] += num[i] * i;
        
        int j = i + lowbit(i);
        if(j <= n) {
            A1[j] += A1[i];
            A2[j] += A2[i] ;
        }
    }
}

// A[b,c] 都加 d
update(A1, b, d);
update(A1, c+1, -d);
update(A2, b, d*b);
update(A2, c+1, -d*(c+1));

// A[b,c] 的和
((c+1) * getSum(A1, c) - getSum(A2, c)) - (b * getSum(A1, b-1) - getSum(A2, b-1))

二维树状数组#

我们把前缀和扩展至二维,$\mathrm{sum}(x,y)$ 代表前 $x$ 行 $y$ 列的和。于是也可以根据前缀和计算区间和,如右图所示。

单点修改/区间查询#
long long getSum(vector<vector<long long>>& A, int x, int y) {
    long long ans = 0;
    while(x) {
        int _y = y;
        while(_y) {
            ans += A[x][_y];
            _y -= lowbit(_y);
        }
        x -= lowbit(x);
    }
    
    return ans;
}

void update(vector<vector<long long>>& A, int x, int y, long long value) {
    int n = (int)A.size() - 1, m = (int)A[0].size() - 1;
    while(x <= n) {
        int _y = y;
        while(_y <= m) {
            A[x][_y] += value;
            _y += lowbit(_y);
        }
        x += lowbit(x);
    }
}
区间修改/单点查询#

我们对一维数组进行差分,是因为前缀和等于原数组的值。我们利用以下公式对二维数组进行差分:

$d[i][j] = a[i]][j] - a[i-1][j] - a[i][j-1]+a[i-1][j-1]$

这样求和就是 $a[i][j]$。

那我们如何处理区间修改呢?如果我们要给中间的 $3*3$ 矩阵增加 $x$,只需要对差分数组做如下操作:

0  0  0  0  0
0 +x  0  0 -x
0  0  0  0  0
0  0  0  0  0
0 -x  0  0 +x

不难验证,这样的效果是:

0  0  0  0  0
0  x  x  x  0
0  x  x  x  0
0  x  x  x  0
0  0  0  0  0

具体应用参阅后文例题。

区间修改/区间查询#

$\sum_{i=1}^{x}\sum_{j=1}^{y}\sum_{k=1}^{i}\sum_{h=1}^{j}d[h][k]$ 代表了点 $(x,y)$,它具有 $O(n^4)$ 的复杂度。利用树状数组可以优化到 $O(\log^2n)$。我们绘制一个简单的示意图,统计一下 $d[h][k]$ 出现的次数即可(打开最内侧的两个求和)。我们可以知道 $d[h][k]$ 出现了 $ (x - h + 1)*(y - k + 1)$ 次。所以我们可以写成: $\sum_{i=1}^{x}\sum_{j=1}^{y}d[i][j] * (x + 1 - i) * (y + 1 - j)$。我们拆开就得到了:

$$\begin{array}{c}(x+1) *(y+1) * \sum_{i=1}^{x} \sum_{j=1}^{y} d[i][j] \\ -(y+1) * \sum_{i=1}^{x} \sum_{j=1}^{y} d[i][j] * i \\ -(x+1) * \sum_{i=1}^{x} \sum_{j=1}^{y} d[i][j] * j \\ +\sum_{i=1}^{x} \sum_{j=1}^{y} d[i][j] * i * j\end{array}$$

于是我们需要分别维护 $d[i][j], d[i][j] * i, d[i][j] * j, d[i][j] * i * j$。具体应用参阅后文例题。

区间最值问题#

据说更推荐线段树来实现,还没学线段树。相对来说,树状数组功能比线段树少,但更简单,常数更小。

区间第 k 大问题#

因为我们可以通过知道计算前缀和的方式,动态计算出小于 $x$ 的数的个数。如果要找第 $k$ 大的数,我们目标就是找到第一个满足 getSum(n) - getSum(x) < k 的位置 $x$(特殊地,对于无重复数字的数组来说,getSum(n) - getSum(x) == k-1),也就是对于 getSum(n) - getSum(x)upper_bound(..., k)(对于降序数组来说,upper_bound(..., k) 代表第一个小于 $k$ 的数)。如果利用二分查找,复杂度为 $O(\log^2(N))$。

注意,如上图所示,我们如果要找到第三大的数,需要找到 getSum(n) - getsum(x) < 3 的第一个位置。其中 3, 4, 5, 6 都满足 getSum(n) - getsum(x) == 2,所以我们限定取找到的第一个位置。也就是 upper_bound(..., k)

我们可以手动写一个 upper_bound 函数,参考:这篇文章。应用见例题 KiKi’s K-Number

逆序对#

树状数组还能解决求逆序对的问题。之前我们是用归并排序来解决这个问题的。

思路如下,对于序列 $x_1x_2,\cdots,x_ix_{i+1},\cdots x_n$ 来说,我们如果要求以 $x_i$ 为左边界的逆序对,只需要知道 $x_i$ 右侧比它小的数($0,1,2,\cdots,x_i-1$)有几个。我们可以通过从右向左遍历同时计数并求前缀和来完成这个过程。遍历到 $x_i$ 时,执行 cnt += getSum(x[i]-1),然后再执行 update(x[i], 1),最后 cnt 就是逆序对的个数。总的时间复杂度$O(N\log(N))$,与归并排序相同。

由于数据可能存在负数,不能用于计数。于是我们可以利用离散化,将数据范围映射到正整数,同时节省计数数组的空间。

具体见例题 计算右侧小于当前元素的个数Cows

例题#

模版题 树状数组 1:单点修改,区间查询#

树状数组 1:单点修改,区间查询

注意:也可以使用 destructiveInit 作为构造函数。

#include <iostream>
#include <vector>

using namespace std;

int lowbit(int x) {
    return x & -x;
}

void init(vector<long long>& num, vector<long long>& A) {
    int n = (int)num.size() - 1;
    for (int i = 1; i <= n; ++i) {
        A[i] += num[i];
        
        int j = i + lowbit(i);
        if(j <= n) {
            A[j] += A[i];
        }
    }
}

long long getSum(vector<long long>& A, int index) {
    long long ans = 0;
    while(index) {
        ans += A[index];
        index -= lowbit(index);
    }
    
    return ans;
}

void update(vector<long long>& A, int index, int value) {
    int n = (int)A.size() - 1;
    while(index <= n) {
        A[index] += value;
        index += lowbit(index);
    }
}

int main()
{
    ios::sync_with_stdio(false);
    int n, q;
    cin >> n >> q;
    vector<long long> num(n+1, 0), A(n+1, 0);
    
    for(int i = 1; i <= n; ++i)  {
        cin >> num[i];
    }
    
    init(num, A);
    
    for(int i = 1; i <= q; ++i)  {
        int a, b, c;
        cin >> a >> b >> c;
        
        if (a == 1) {
            update(A, b, c);
        } else {
            cout << getSum(A, c) - getSum(A, b-1) << endl;
        }
    }
    
    return 0;
}

模版题 树状数组 2:区间修改,单点查询#

树状数组 2:区间修改,单点查询

注意:也可以使用 destructiveInit 作为构造函数。

#include <iostream>
#include <vector>

using namespace std;

/* 基本函数同上题,不再给出 */
  
int main()
{
    ios::sync_with_stdio(false);
    int n, q;
    cin >> n >> q;
    vector<long long> num(n+1, 0), diffNum(n+1, 0), A(n+1, 0);
    
    for(int i = 1; i <= n; ++i)  {
        cin >> num[i];
        diffNum[i] = num[i] - num[i-1];
    }
    
    init(diffNum, A);
    
    for(int i = 1; i <= q; ++i)  {
        int a, b, c, d;
        cin >> a >> b;
        
        if (a == 1) {
            cin >> c >> d;
            update(A, b, d);
            update(A, c+1, -d);
        } else {
            cout << getSum(A, b) << endl;
        }
    }
    
    return 0;
}

模版题 树状数组 3:区间修改,区间查询#

树状数组 3:区间修改,区间查询

// 只有 init 和 update 有区别,其他同上。
void init(vector<long long>& num, vector<long long>& A1, vector<long long>& A2) {
    int n = (int)num.size() - 1;
    for (int i = 1; i <= n; ++i) {
        A1[i] += num[i];
        A2[i] += num[i] * i;
        
        int j = i + lowbit(i);
        if(j <= n) {
            A1[j] += A1[i];
            A2[j] += A2[i] ;
        }
    }
}

// long long 才能过。因为 A2 更新的值可能超过 int
void update(vector<long long>& A, int index, long long value) {
    int n = (int)A.size() - 1;
    while(index <= n) {
        A[index] += value;
        index += lowbit(index);
    }
}

int main()
{
    ios::sync_with_stdio(false);
    int n, q;
    cin >> n >> q;
    vector<long long> diffNum(n+1, 0), A1(n+1, 0), A2(n+1, 0);
    
    long long t1 = 0, t2 = 0;
    for(int i = 1; i <= n; ++i)  {
        cin >> t2;
        diffNum[i] = t2 - t1;
        t1 = t2;
    }
    
    init(diffNum, A1, A2);
    
    for(int i = 1; i <= q; ++i)  {
        int a, b, c;
        long long d;
        cin >> a >> b >> c;
        
        if (a == 1) {
            cin >> d;
            update(A1, b, d);
            update(A1, c+1, -d);
            update(A2, b, d*b);
            update(A2, c+1, -d*(c+1));
        } else {
            cout <<  ((c+1) * getSum(A1, c) - getSum(A2, c)) - (b * getSum(A1, b-1) - getSum(A2, b-1)) << endl;
        }
    }
    
    return 0;
}

模版题 二维树状数组 1:单点修改,区间查询#

二维树状数组 1:单点修改,区间查询

int main()
{
    ios::sync_with_stdio(false);
    int n, m;
    cin >> n >> m;
    
    vector<vector<long long>> A(n+1, vector<long long>(m+1, 0));
    
    int type;
    while (cin >> type) {
        if (type == 1) {
            int x, y, k;
            cin >> x >> y >> k;
            update(A, x, y, k);
        } else {
            int a, b, c, d;
            cin >> a >> b >> c >> d;
            cout << getSum(A, c, d) - getSum(A, c, b-1) - getSum(A, a-1, d) + getSum(A, a-1, b-1) << endl;
        }
    }
    
    return 0;
}

模版题 二维树状数组 2:区间修改,单点查询#

二维树状数组 2:区间修改,单点查询

我们需要维护 $d[i][j] = a[i]][j] - a[i-1][j] - a[i][j-1]+a[i-1][j-1]$。

int main()
{
    ios::sync_with_stdio(false);
    int n, m;
    cin >> n >> m;
    
    vector<vector<long long>> A(n+1, vector<long long>(m+1, 0));
    
    int type;
    while (cin >> type) {
        if (type == 1) {
            int a, b, c, d, k;
            cin >> a >> b >> c >> d >> k;
            
            update(A, a, b, k);
            update(A, c+1, b, -k);
            update(A, a, d+1, -k);
            update(A, c+1, d+1, k);
        } else {
            int x, y;
            cin >> x >> y;
            cout << getSum(A, x, y) << endl;
        }
    }
    
    return 0;
}

模版题 二维树状数组 3:区间修改,区间查询#

二维树状数组 3:区间修改,区间查询

我们需要分别维护 $d[i][j], d[i][j] * i, d[i][j] * j, d[i][j] * i * j$。


long long get4Sum(vector<vector<long long>>& A1,
                  vector<vector<long long>>& A2,
                  vector<vector<long long>>& A3,
                  vector<vector<long long>>& A4,
                  int x, int y) {
    return (x+1)*(y+1)*getSum(A1, x, y)
            - (y+1)*getSum(A2, x, y)
            - (x+1)*getSum(A3, x, y)
            + getSum(A4, x, y);
}

int main()
{
    ios::sync_with_stdio(false);
    int n, m;
    cin >> n >> m;
    
    vector<vector<long long>> A1(n+1, vector<long long>(m+1, 0)),
                              A2(n+1, vector<long long>(m+1, 0)),
                              A3(n+1, vector<long long>(m+1, 0)),
                              A4(n+1, vector<long long>(m+1, 0));
    
    int type;
    while (cin >> type) {
        if (type == 1) {
            int a, b, c, d, k;
            cin >> a >> b >> c >> d >> k;
            
            update(A1, a, b, k);
            update(A1, c+1, b, -k);
            update(A1, a, d+1, -k);
            update(A1, c+1, d+1, k);
            
            update(A2, a, b, k*a);
            update(A2, c+1, b, -k*(c+1));
            update(A2, a, d+1, -k*a);
            update(A2, c+1, d+1, k*(c+1));
            
            update(A3, a, b, k*b);
            update(A3, c+1, b, -k*b);
            update(A3, a, d+1, -k*(d+1));
            update(A3, c+1, d+1, k*(d+1));
            
            update(A4, a, b, k*a*b);
            update(A4, c+1, b, -k*(c+1)*b);
            update(A4, a, d+1, -k*a*(d+1));
            update(A4, c+1, d+1, k*(c+1)*(d+1));
        } else {
            int a, b, c, d;
            cin >> a >> b >> c >> d;
            
            cout << get4Sum(A1, A2, A3, A4, c, d)
                - get4Sum(A1, A2, A3, A4, c, b-1)
                - get4Sum(A1, A2, A3, A4, a-1, d)
                + get4Sum(A1, A2, A3, A4, a-1, b-1) << endl;
        }
    }
    
    return 0;
}

POJ 2352 Stars#

POJ 2352 Stars

仔细观察题意,因为 $y$ 是递增的,因此只需要考察所有 $x$ 比自己小的位置即可,不需要使用二维数组(ME)。注意我们令 $x$ 自增 $1$ 是因为坐标是从 $0$ 开始的。

#include <iostream>
#include <vector>

using namespace std;

int lowbit(int x) {
    return x & -x;
}

int getSum(vector<int>& A, int index) {
    int ans = 0;
    while(index) {
        ans += A[index];
        index -= lowbit(index);
    }
    
    return ans;
}

void update(vector<int>& A, int index, int value) {
    int n = (int)A.size() - 1;
    while(index <= n) {
        A[index] += value;
        index += lowbit(index);
    }
}

int main()
{
    ios::sync_with_stdio(false);
    int n;
    cin >> n;

    vector<int> A(32002, 0);
    vector<int> level(n, 0);

    while(n--) {
        int x, y;
        cin >> x >> y;
        
        x += 1;
        level[getSum(A, x)]++;
        update(A, x, 1);
    }
    
    for (int i = 0; i < level.size(); ++i) {
        cout << level[i] << endl;
    }
    
    return 0;
}

SDOI 2009 HH 的项链#

SDOI 2009 HH 的项链

题目描述#

HH 有一串由各种漂亮的贝壳组成的项链。HH 相信不同的贝壳会带来好运,所以每次散步完后,他都会随意取出一段贝壳,思考它们所表达的含义。HH 不断地收集新的贝壳,因此,他的项链变得越来越长。

有一天,他突然提出了一个问题:某一段贝壳中,包含了多少种不同的贝壳?这个问题很难回答…… 因为项链实在是太长了。于是,他只好求助睿智的你,来解决这个问题。

思路#

因为数据很大.. (不愧是比赛题,我只是做着玩儿的…)所以对时间的要求很高。

一开始是想到如果贝壳种类比较少,可以直接用一个64位(比如种类不超过 64 种,否则应该用 n 位)的 Bit Mask 来代表种类,此时只要把前缀和的求和操作变为求或操作,最后统计前缀或中 $1$ 的个数即可。但是.. 种类高达 $10^6$ 种。

于是参考题解中的思路:

如果项链为:1 2 3 4 3 5

对于区间 $l, r$ 来说,在 $[0,r]$ 的范围内,我们只关心出现的次数的最右一次(这样能保证尽量落在 $[l,r]$ 内)。如果我们用一个数组 $A$ 来记录每个数最后出现的位置,比如当 $r=4$ 的时候,数组的内容为 1 1 1 1 0 0;比如当 $r=5$ 的时候,数组的内容为 1 1 0 1 1 0;$r=6$ 的时候,数组的内容为 1 1 0 1 1 1。

所以我们只需要对所有 query 的右边界进行重新排序,依次统计前缀和并记录下来即可(最后要按输入的顺序输出)。

在遍历的过程中,我们需要维护数组 $A$,把当前 cursor 的位设为 $1$,把与当前数重复的前一个数的位设为 0,于是我们需要记录上一个重复数出现的位置。

对于实现的细节,尝试了许多途径,因为 TLE,所以最后都采取用空间换时间的策略,也就是说,开数组。

#include <iostream>
#include <stdio.h>
#include <vector>
#include <algorithm>

using namespace std;

int lowbit(int x) {
    return x & -x;
}

int getSum(vector<int>& A, int index) {
    int ans = 0;
    while(index) {
        ans += A[index];
        index -= lowbit(index);
    }
    
    return ans;
}

void update(vector<int>& A, int index, int value) {
    int n = (int)A.size() - 1;
    while(index <= n) {
        A[index] += value;
        index += lowbit(index);
    }
}

struct query {
    int l;
    int r;
    int pos;
    
    query() {
        
    }
    
    query(int l, int r, int pos) {
        this->l = l;
        this->r = r;
        this->pos = pos;
    }
    
    bool operator <(const query& rhs) const{
        return r < rhs.r;
    }
};

int main()
{
    int n;
    scanf("%d", &n);
    
    vector<int> num(n+1, 0);
    vector<int> prev(n+1, 0); // 存储下标为 i 的数字上次出现的下标
    vector<int> prevv(1000001, 0); // 存储某个数上次出现的位置
    
    for (int i = 1; i <= n; ++i) {
        scanf("%d", &num[i]);
        prev[i] = prevv[num[i]];
        prevv[num[i]] = i;
    }
    
    int queryCount;
    scanf("%d", &queryCount);
    
    // 记录所有的 query
    vector<query> querys(queryCount);
    for (int i = 0; i < queryCount; ++i) {
        int l, r;
        scanf("%d%d", &l, &r);
        querys[i].l = l;
        querys[i].r = r;
        querys[i].pos = i;
    }
    
      // 将 query 按右边界排序
    sort(querys.begin(), querys.end());
    
   // 为了顺序输出,提前开好 output 数组
    vector<int> A(n+1, 0), output(queryCount, 0);
    int cursor = 0;
    for (int i = 0; i < queryCount; ++i) {
        while(cursor < querys[i].r) {
            cursor++;
            update(A, cursor, 1);
            
            if (prev[cursor]) {
                update(A, prev[cursor], -1);
            }
        }
        output[querys[i].pos] = getSum(A, querys[i].r) - getSum(A, querys[i].l - 1);
    }
    
    for (int i = 0; i < querys.size(); ++i) {
        printf("%d\n",output[i]);
    }
    
    return 0;
}

HEOI2012采花#

HEOI2012采花

思路#

题目与上一题十分类似,只不过区间内数出现两次才算数。一开始的思路是只要记录出现了至少两次数的最后一个位置,但是这样说错误的,比如序列 2, 2, 3,如果我们在第二次出现 2 的位置上 +1,变成 (0, 1, 0),当询问 $[2,3]$ 就出错了,其实应该在倒数第二次出现的位置 +1 变成 (1, 0, 0)。因为当我们确定 $r$ 的时候,我们在意的是 $l$ 与倒数第二次出现的数的位置关系。所以只要稍作修改,维护倒数第二个数的状况即可。这题数据更大,总有一两个数据点 TLE,开启 O2 优化可以过.. 但是好像比赛是不允许开的?.. 看了一下代码逻辑和别人的解法一样,咱也看不懂 dl 们花里胡哨的写法,为啥别人能过呢。

#pragma clang optimize on
#pragma GCC optimize(2)

#include <iostream>
#include <stdio.h>
#include <vector>
#include <algorithm>

using namespace std;

int lowbit(int x) {
    return x & -x;
}

int getSum(vector<int>& A, int index) {
    int ans = 0;
    while(index) {
        ans += A[index];
        index -= lowbit(index);
    }
    
    return ans;
}

void update(vector<int>& A, int index, int value) {
    int n = (int)A.size() - 1;
    while(index <= n) {
        A[index] += value;
        index += lowbit(index);
    }
}

struct query {
    int l;
    int r;
    int pos;
    
    query() {
        
    }
    
    query(int l, int r, int pos) {
        this->l = l;
        this->r = r;
        this->pos = pos;
    }
    
    bool operator <(const query& rhs) const{
        return r < rhs.r;
    }
};

int main()
{
    int n, m, c;
    scanf("%d%d%d", &n, &m, &c);
    
    vector<int> prev(n+1, 0); // 存储下标为 i 的数字上次出现的下标
    vector<int> prevv(m+1, 0); // 存储某个数上次出现的位置
    
    int num;
    for(int i = 1; i <= n; ++i) {
        scanf("%d", &num);
        prev[i] = prevv[num];
        prevv[num] = i;
    }
    
    vector<query> querys(c);
    for(int i = 0; i < c; ++i) {
        scanf("%d%d", &querys[i].l, &querys[i].r);
        querys[i].pos = i;
    }
    
    sort(querys.begin(), querys.end());
    
    vector<int> A(n+1, 0), output(c, 0);
    int cursor = 0;
    for(int i = 0; i < c; ++i) {
        int l = querys[i].l, r = querys[i].r, pos = querys[i].pos;
        while(cursor < r) {
            cursor++;
            
            int prevIndex = prev[cursor];
            if(prevIndex) {
                update(A, prevIndex, 1);
                int prevPrevIndex = prev[prevIndex];
                if(A[prevPrevIndex])  update(A, prevPrevIndex, -1);
            }
        }
        
        output[pos] = getSum(A, r) - getSum(A, l - 1);
    }
    
    for (int i = 0; i < c; ++i) {
        printf("%d\n", output[i]);
    }
    
    return 0;
}

KiKi’s K-Number#

KiKi’s K-Number

注意.. 出题人把 Element 拼成了 Elment… 另外,题目中的 a, k 是指比 $a$ 大的第 $k$ 个数,而不是比 $a$ 大的第 $k$ 大的数。

我们可以简单地把题意转化为求第 total-getSum(A)+1 大的数。注意,删除的时候,也要通过 find(A, a, 0) 来判断这个数存不存在。

#pragma clang optimize on
#pragma GCC optimize(2)

#include <stdio.h>
#include <vector>
#include <algorithm>
#include <stack>

using namespace std;

int lowbit(int x) {
    return x & -x;
}

int getSum(vector<int>& A, int index) {
    int ans = 0;
    while(index) {
        ans += A[index];
        index -= lowbit(index);
    }
    
    return ans;
}

void update(vector<int>& A, int index, int value) {
    int n = (int)A.size() - 1;
    while(index <= n) {
        A[index] += value;
        index += lowbit(index);
    }
}

int my_upper_bound(vector<int>& num, int k) {
    int l = 0, r = (int)num.size(), n = (int)num.size()-1, total = getSum(num, n);
    while (l < r) {
        int mid = (l + r) / 2;
        if (total - getSum(num, mid) >= k) {
            l = mid + 1;
        } else {
            r = mid;
        }
    }
    return l;
}

// 比 a 大的第 k 个数
int find(vector<int>& num, int a, int k) {
    int sum_a = getSum(num, a), n = (int)num.size()-1, total = getSum(num, n);
    int index = my_upper_bound(num, total - sum_a - k + 1);
    return index;
}

int main()
{
    int n, MAXN = 100000;
    while(scanf("%d", &n) != EOF) {
        vector<int> A(MAXN + 1, 0);
        int total = 0;
        for(int i = 0; i < n; ++i) {
            int type;
            scanf("%d", &type);
            if (type == 0) {
                int index;
                scanf("%d", &index);
                update(A, index, 1);
                total++;
            } else if (type == 1) {
                int a;
                scanf("%d", &a);
                
                int res = find(A, a, 0);
                if(res == a) {
                    update(A, a, -1);
                    total--;
                } else {
                    printf("No Elment!\n");
                }
            } else if (type == 2) {
                int a, k;
                scanf("%d%d", &a, &k);
                
                int res = find(A, a, k);
                if(res == MAXN + 1) {
                    printf("Not Find!\n");
                } else {
                    printf("%d\n", res);
                }
            }
        }
    }
    return 0;
}

315. 计算右侧小于当前元素的个数#

315. 计算右侧小于当前元素的个数

这是利用树状数组求逆序对的例子。

class Solution {
public:
    int lowbit(int x) {
        return x & -x;
    }

    int getSum(vector<int>& A, int index) {
        int ans = 0;
        while(index) {
            ans += A[index];
            index -= lowbit(index);
        }
        return ans;
    }

    void update(vector<int>& A, int index, int value) {
        int n = A.size() - 1;
        while(index <= n) {
            A[index] += value;
            index += lowbit(index);
        }
    }

    vector<int> countSmaller(vector<int>& nums) {
        int n = nums.size();
      
        vector<int> _nums(nums.begin(), nums.end()), A(n+1, 0), res(n, 0);
        
        // 离散化
        sort(_nums.begin(), _nums.end());
        unique(_nums.begin(), _nums.end());
        for(int i = 0; i < n; ++i) {
            nums[i] = lower_bound(_nums.begin(), _nums.end(), nums[i]) - _nums.begin() + 1;
        }
        
        for(int i = n-1; i >= 0; --i) {
            res[i] = getSum(A, nums[i]-1);
            update(A, nums[i], 1);
        }

        return res;
    }
};

POJ 2481 Cows#

POJ 2481 Cows

我们针对区间右端点进行降序排序,若区间右端点相同,则左端点以升序排序。这样以来,我们遍历排序好的数组,我们就能保证右端点大于等于 $\mathrm{intervals}[i].r$ 的区间都已经被遍历过(若右端点相等,因为左端点以升序排序,所有可能符合题意的区间也已经遍历过),只需要计算这些区间中 $l$ 不超过 $ \mathrm{intervals}[i].l$ 的个数即可(但是左右端点都相等需要特殊处理)。这个可以通过计数与求前缀和实现(树状数组)。另外针对右侧区间相等的情况,我们借助一个缓冲区(比如栈)来实现。当满足条件:s.top().r != intervals[i].r || s.top().l != intervals[i].l) 时,也就是说,右端点不相同时,或者右端点相同且左端点不相同时,再把缓冲区的数据更新至树状数组中。

#pragma clang optimize on
#pragma GCC optimize(2)

#include <stdio.h>
#include <vector>
#include <algorithm>
#include <stack>

using namespace std;

int lowbit(int x) {
    return x & -x;
}

int getSum(vector<int>& A, int index) {
    int ans = 0;
    while(index) {
        ans += A[index];
        index -= lowbit(index);
    }
    
    return ans;
}

void update(vector<int>& A, int index, int value) {
    int n = (int)A.size() - 1;
    while(index <= n) {
        A[index] += value;
        index += lowbit(index);
    }
}

struct interval {
    int l;
    int r;
    int pos;
    
    interval() {
        
    }
    
    interval(int l, int r, int pos) {
        this->l = l;
        this->r = r;
        this->pos = pos;
    }
    
    bool operator <(const interval& rhs) const{
        return (r > rhs.r) || (r == rhs.r && l < rhs.l);
    }
};

int main()
{
    int n;
    
    while(scanf("%d", &n)) {
        if(!n) return 0;
        
        vector<interval> intervals(n);
        for(int i = 0; i < n; ++i) {
            scanf("%d%d", &intervals[i].l, &intervals[i].r);
            intervals[i].l++;
            intervals[i].r++;
            intervals[i].pos = i;
        }
        getchar();
        
        // Sort by right-end
        sort(intervals.begin(), intervals.end());
        
        vector<int> A(100001, 0), output(n+1, 0);
        stack<interval> s;
        for (int i = 0; i < n; ++i) {
            if(!s.empty() && (s.top().r != intervals[i].r || s.top().l != intervals[i].l)) {
                while(!s.empty()) {
                    update(A, s.top().l, 1);
                    s.pop();
                }
            }
            s.push(intervals[i]);
            output[intervals[i].pos] = getSum(A, intervals[i].l);
        }
        
        printf("%d", output[0]);
        for (int i = 1; i < n; ++i) {
            printf(" %d", output[i]);
        }
        printf("\n");
    }
    return 0;
}

参考#

https://www.cnblogs.com/RabbitHu/p/BIT.html

http://www.cppblog.com/menjitianya/archive/2015/11/02/212171.html