0%

树状数组

解决什么问题?

单点更新,计算前缀和。

原理

灵神讲解

树状数组图示

理解与总结

结合灵神文章和视频,尝试总结一下。

在有单点更新的计算前缀和的问题当中,为了平衡「查询」和「更新」的时间复杂度,由每个数$i$的二进制长度为 $O(log\ i)$ 得到启发,考虑将每个区间 $[1,i]$ 划分成几个关键区间,每个区间的长度为 $i$ 的二进制中的 $1$ 们所代表的值。

同时,由于从 $[1,1]$ 到 $[1,i]$ 右端点每增加$1$的一步都只会新增加一个关键区间,所以一共有 $i$ 个关键区间。因此,tree[i] 就是代表的是右端点为 $i$ 的关键区间的元素和。

如何查询 $[1,i]$ 的元素和

$[1,i]$ 区间和元素和,可以看成是 $[1,i-lowbit(i)]$ (去掉 $lowbit(i)$ 后的区间)和 $[i-lowbit(i)+1,i]$ 两个区间的和。

其中,$[1,i-lowbit(i)]$ 是一个规模更小的子问题。

因此算法如下:

  1. 初始化元素和 $s = 0$。
  2. tree[i] 加入到 $s$ 中。
  3. 将 $i$ 去掉 $lowbit[i]$ ,即更新为 $i-lowbit$,重复上一步,直到$i$为 $0$。

如何更新单点$i$

更新 $i$ 不能只更新$i$,而是要更新含有 $i$ 的所有关键区间。

哪些结点所代表的区间包含 $i$ 呢?右端点为 $i$ 的区间的父节点们(结合灵神文章的图)。而每个结点 $i$ 和它的父节点之间的右端点相差的是 $lowbit(i)$。

因此更新算法如下:

  1. 更新tree[i]
  2. 将$i$加上 $lowbit[i]$,重复第一步,直到越界。

代码

lowbit 相关

lowbit:

1
lowbit = s & -s

解释:-s = ~s + 1,即 s 的负数是所有位按位取反末尾加一,得到的结果和原来的 s 相比除了最右边的 1 一样,其他都相反,那再和 s 按位与,就得到了 lowbit

减去 lowbit

1
i &= i - 1  # i -= i & -i 的另一种写法

下标问题

在代码中,tree 开 $n+1$ 位置,查询和更新的时候,认为它是下标从 $1$ 开始,即只用 $[1,n]$ 部分。

而且,tree 下标是必须从 $1$ 开始,因为当 $i=0$ 的时候,$lowbit(i)=0$,会发生死循环。

而外面(题目中)的数组的下标可以从 $0$ 也可以从$1$开始,如果从 $0$ 开始,则需要 $+1$,从 $1$ 开始就不用了。

具体的调用的时候,要看好到底是「原数组」的下标还是「tree」的下标。如果传入的是原数组的下标并且是从 $0$ 开始的,那就得 $+1$。

——上面说的好像有点绕,总结来说就是:$tree$ 的下标从 $1$ 开始,注意与题目中给的数组的下标转换。

倍增查找

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// 树状数组倍增寻第 K 大: 找到前缀和恰好为 k 的下标
inline int find_kth(int k) {
int pos = 0;
int sum = 0;
// 2^22 = 4194304, 2^23 = 8388608. N 最大 5e6,从 22 开始足够
for (int i = 22; i >= 0; --i) {
int nxt = pos + (1 << i);
if (nxt <= max_N && sum + tree[nxt] < k) {
sum += tree[nxt];
pos = nxt;
}
}
return pos + 1;
}

以下解释来自 Gemini (题目是 洛谷 - P4901 排队):

这个倍增逻辑的核心在于:利用树状数组天然的“2的幂次”结构,直接在树上“跳跃”搜索。

普通的树状数组求和是把下标拆成 $2$ 的幂次相加,而倍增是反过来,通过凑 $2$ 的幂次来找下标。

1. 为什么 tree[nxt] 刚好就是一段区间的和?

树状数组的定义是:tree[x] 存储的是区间 (x - lowbit(x), x] 的和。

当我们站在 pos 位置,尝试往后跳 $2^i$ 步到 nxt = pos + (1 << i) 时:

如果 pos 是 $2^i$ 的倍数(或者从 0 开始跳),那么 tree[nxt] 存储的恰好就是 (pos, nxt] 这段长度为 $2^i$ 的区间和

2. 详细执行流程(以找第 $k=7$ 个人为例):

假设我们从 pos = 0, sum = 0 开始,从高位向低位看:

  • 尝试跳 $2^{22}, 2^{21} \dots$: 发现跳过去后,sum + tree[nxt] 远大于 $7$,所以不跳。
  • 尝试跳 $2^2 = 4$: 假设 tree[4](即前 4 个人的状态)等于 $4$。因为 $4 < 7$,我们跳过去
    • pos = 4, sum = 4
  • 尝试跳 $2^1 = 2$: 检查 nxt = 4 + 2 = 6。假设 tree[6] 存的是第 5 和第 6 个人的状态,和为 $2$。此时 sum + tree[6] = 4 + 2 = 6。因为 $6 < 7$,我们跳过去
    • pos = 6, sum = 6
  • 尝试跳 $2^0 = 1$: 检查 nxt = 6 + 1 = 7。假设 tree[7] 存的是第 7 个人的状态。如果第 7 个人还在,tree[7]=1,此时 sum + tree[7] = 7。因为 $7$ 不满足 < k (即 < 7),所以不跳
  • 循环结束: 得到的 pos = 6 是满足“前缀和严格小于 $k$”的最大下标。
  • 最终结果: pos + 1 = 7 就是我们要找的第 $k$ 个人。

3. 为什么要 sum + tree[nxt] < k 而不是 <= k

  • 如果用 < k:最后停在最后一个和小于 $k$ 的位置,结果是 pos + 1
  • 如果用 <= k:最后停在最后一个和等于 $k$ 的位置。但因为我们要找的是第一个达到 $k$ 的位置,所以 < k 的逻辑更稳健,尤其当数组中有 0 时(比如有人被抽走了)。

4. 复杂度

普通的二分是 $\mathcal{O}(\log^2 N)$(二分套树状数组查询)。

这种倍增直接利用 tree[] 数组的值,不需要重复计算前缀和,所以是严格的 $\mathcal{O}(\log N)$

模板

从空开始

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# 模板来源 https://leetcode.cn/circle/discuss/mOr1u6/
class FenwickTree:
def __init__(self, n: int):
self.tree = [0] * (n + 1) # 使用下标 1 到 n

# a[i] 增加 val
# 1 <= i <= n
# 时间复杂度 O(log n)
def update(self, i: int, val: int) -> None:
t = self.tree
while i < len(t):
t[i] += val
i += i & -i

# 计算前缀和 a[1] + ... + a[i]
# 1 <= i <= n
# 时间复杂度 O(log n)
def pre(self, i: int) -> int:
t = self.tree
res = 0
while i > 0:
res += t[i]
i &= i - 1
return res

# 计算区间和 a[l] + ... + a[r]
# 1 <= l <= r <= n
# 时间复杂度 O(log n)
def query(self, l: int, r: int) -> int:
if r < l:
return 0
return self.pre(r) - self.pre(l - 1)

# 作者:灵茶山艾府
# 链接:https://leetcode.cn/discuss/post/3583665/fen-xiang-gun-ti-dan-chang-yong-shu-ju-j-bvmv/
# 来源:力扣(LeetCode)
# 著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

记忆方法:

  1. 查询是,想一想那个动态规划/子问题的思路,那么 $i$ 的变化是越来越小的,所以while i > 0,$i$ 要减去 $lowbit(i)$。
  2. 而更新则相反。

有初始值

$O(nlog\ n)$做法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class NumArray:
__slots__ = 'nums', 'tree'

def __init__(self, nums: List[int]):
n = len(nums)
self.nums = [0] * n # 使 update 中算出的 delta = nums[i]
self.tree = [0] * (n + 1)
for i, x in enumerate(nums):
self.update(i, x)

def update(self, index: int, val: int) -> None:
delta = val - self.nums[index]
self.nums[index] = val
i = index + 1
while i < len(self.tree):
self.tree[i] += delta
i += i & -i

def prefixSum(self, i: int) -> int:
s = 0
while i:
s += self.tree[i]
i &= i - 1 # i -= i & -i 的另一种写法
return s

def sumRange(self, left: int, right: int) -> int:
return self.prefixSum(right + 1) - self.prefixSum(left)

$O(n)$的做法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class NumArray:
__slots__ = 'nums', 'tree'

def __init__(self, nums: List[int]):
n = len(nums)
tree = [0] * (n + 1)
for i, x in enumerate(nums, 1): # i 从 1 开始
tree[i] += x
nxt = i + (i & -i) # 下一个关键区间的右端点
if nxt <= n:
tree[nxt] += tree[i]
self.nums = nums
self.tree = tree

def update(self, index: int, val: int) -> None:
delta = val - self.nums[index]
self.nums[index] = val
i = index + 1
while i < len(self.tree):
self.tree[i] += delta
i += i & -i

def prefixSum(self, i: int) -> int:
s = 0
while i:
s += self.tree[i]
i &= i - 1 # i -= i & -i 的另一种写法
return s

def sumRange(self, left: int, right: int) -> int:
return self.prefixSum(right + 1) - self.prefixSum(left)

题单