0%

树状数组

解决什么问题?

单点更新,计算前缀和。

原理

灵神讲解

理解与总结

结合灵神文章和视频,尝试总结一下。

在有单点更新的计算前缀和的问题当中,为了平衡「查询」和「更新」的时间复杂度,由每个数$i$的二进制长度为$O(log\ i)$得到启发,考虑将每个区间$[1,i]$划分成几个关键区间,每个区间的长度为$i$的二进制中的$i$们所代表的值。

同时,由于从$[1,1]$到$[1,i]$右端点每增加$1$的一步都只会新增加一个关键区间,所以一共有$i$个关键区间。因此,tree[i]就是代表的是右端点为$i$的关键区间的元素和。

如何查询$[1,i]$的元素和

$[1,i]$区间和元素和,可以看成是$[1,i-lowbit(i)]$(去掉$lowbit(i)$后的区间)和$[i-lowbit(i)+1,i]$两个区间的和。

其中,$[1,i-lowbit(i)]$是一个规模更小的子问题。

因此算法如下:

  1. 初始化元素和$s = 0$。
  2. tree[i]加入到$s$中。
  3. 将$i$去掉$lowbit[i]$,即更新为$i-lowbit$,重复上一步,直到$i$为0。

如何更新单点$i$

更新$i$不能只更新$i$,而是要更新含有$i$的所有关键区间。

哪些结点所代表的区间包含$i$呢?右端点为$i$的区间的父节点们(结合灵神文章的图)。而每个结点$i$和它的父节点之间的右端点相差的是$lowbit(i)$。

因此更新算法如下:

  1. 更新tree[i]
  2. 将$i$加上$lowbit[i]$,重复第一步,知道越界。

代码

lowbit 相关

lowbit:

1
lowbit = s & -s

解释:-s = ~s + 1,即s的负数是所有位按位取反末尾加一,得到的结果和原来的s相比除了最右边的1一样,其他都相反,那再和s按位与,就得到了lowbit

减去lowbit

1
i &= i - 1  # i -= i & -i 的另一种写法

下标问题

在代码中,tree开$n+1$位置,询和更新的时候,认为它是下标从$1$开始,即只用$[1,n]$部分。

而且,tree下标是必须从$1$开始,因为当$i=0$的时候,$lowbit(i)=0$,会发生死循环。

而外面(题目中)的数组的下标可以从$0$也可以从$1$开始,如果从$0$开始,则需要$+1$,从$1$开始就不用了。

模板

这是从空开始不断建的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# 模板来源 https://leetcode.cn/circle/discuss/mOr1u6/
class FenwickTree:
def __init__(self, n: int):
self.tree = [0] * (n + 1) # 使用下标 1 到 n

# a[i] 增加 val
# 1 <= i <= n
# 时间复杂度 O(log n)
def update(self, i: int, val: int) -> None:
t = self.tree
while i < len(t):
t[i] += val
i += i & -i

# 计算前缀和 a[1] + ... + a[i]
# 1 <= i <= n
# 时间复杂度 O(log n)
def pre(self, i: int) -> int:
t = self.tree
res = 0
while i > 0:
res += t[i]
i &= i - 1
return res

# 计算区间和 a[l] + ... + a[r]
# 1 <= l <= r <= n
# 时间复杂度 O(log n)
def query(self, l: int, r: int) -> int:
if r < l:
return 0
return self.pre(r) - self.pre(l - 1)

# 作者:灵茶山艾府
# 链接:https://leetcode.cn/discuss/post/3583665/fen-xiang-gun-ti-dan-chang-yong-shu-ju-j-bvmv/
# 来源:力扣(LeetCode)
# 著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

这是有初始值的$O(nlog\ n)$的做法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class NumArray:
__slots__ = 'nums', 'tree'

def __init__(self, nums: List[int]):
n = len(nums)
self.nums = [0] * n # 使 update 中算出的 delta = nums[i]
self.tree = [0] * (n + 1)
for i, x in enumerate(nums):
self.update(i, x)

def update(self, index: int, val: int) -> None:
delta = val - self.nums[index]
self.nums[index] = val
i = index + 1
while i < len(self.tree):
self.tree[i] += delta
i += i & -i

def prefixSum(self, i: int) -> int:
s = 0
while i:
s += self.tree[i]
i &= i - 1 # i -= i & -i 的另一种写法
return s

def sumRange(self, left: int, right: int) -> int:
return self.prefixSum(right + 1) - self.prefixSum(left)

这是$O(n)$的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class NumArray:
__slots__ = 'nums', 'tree'

def __init__(self, nums: List[int]):
n = len(nums)
tree = [0] * (n + 1)
for i, x in enumerate(nums, 1): # i 从 1 开始
tree[i] += x
nxt = i + (i & -i) # 下一个关键区间的右端点
if nxt <= n:
tree[nxt] += tree[i]
self.nums = nums
self.tree = tree

def update(self, index: int, val: int) -> None:
delta = val - self.nums[index]
self.nums[index] = val
i = index + 1
while i < len(self.tree):
self.tree[i] += delta
i += i & -i

def prefixSum(self, i: int) -> int:
s = 0
while i:
s += self.tree[i]
i &= i - 1 # i -= i & -i 的另一种写法
return s

def sumRange(self, left: int, right: int) -> int:
return self.prefixSum(right + 1) - self.prefixSum(left)

题单