Contents

高效的 Python 切片赋值

引言

对于某些算法题,可能需要维护一个有序的数据结构,而且需要查询中间满足某个条件的数据(如果是最小或最大值的话使用堆结构即可),C++ 中有 set/multiset/map/multimap 容器可以使用,但是 Python 标准库中没有类似的数据结构,虽然在 LeetCode 中可以使用第三方库 sortedcontainers

尽管 Python 标准库中没有类似的结构,但是结合使用 bisect 库和 切片赋值 可以实现很高的效率。

用法

1
2
3
4
5
6
7
def insert(nums: List[int], k: int, val: int):
    # 在 nums 第 k 个位置处插入 val
    nums[k:k] = [val]

def delete(nums: List[int], k: int):
    # 删除 nums[k]
    nums[k : k + 1] = []

测试

假设 arr: List[int] 是一个随机数组,现在我们要将它变为一个有序数组,当然我们可以直接使用 list.sort 或者 sorted 进行排序,但是这里我们用排序任务来比较 list.insert(i, x)[i:i]=[x] 的效率。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
def test1(arr: List[int]) -> None:
    sorted_arr = []
    for x in arr:
        k = bisect_left(sorted_arr, x)
        sorted_arr[k:k] = [x]

def test2(arr: List[int]) -> None:
    sorted_arr = []
    for x in arr:
        k = bisect_left(sorted_arr, x)
        sorted_arr.insert(k, x)

在上面的代码中,test1 使用了切片赋值,test2 使用了 list.insert 方法。

我们用下面的代码来统计两种方法对 10 组随机生成的 100000 ($10^5$) 个数进行排序的执行用时。

1
2
3
4
5
def run(test: Callable[[List[int]], None], samples: List[int]) -> List[float]:
    res = []
    for sample in samples:
        res.append(timeit(lambda: test(sample), number=10))
    return res
完整测试代码
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from bisect import bisect_left
from copy import deepcopy
from random import randint
from timeit import timeit
from typing import Callable, List

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

sns.set_style("whitegrid")
plt.rcParams["figure.figsize"] = (12, 8)


def generate_samples(row: int, col: int) -> List[List[int]]:
    return [[randint(1, 100) for _ in range(col)] for _ in range(row)]


def test1(arr: List[int]) -> None:
    sorted_arr = []
    for x in arr:
        k = bisect_left(sorted_arr, x)
        sorted_arr[k:k] = [x]
    # assert all(b >= a for a, b in zip(sorted_arr[:-1], sorted_arr[1:]))


def test2(arr: List[int]) -> None:
    sorted_arr = []
    for x in arr:
        k = bisect_left(sorted_arr, x)
        sorted_arr.insert(k, x)
    # assert all(b >= a for a, b in zip(sorted_arr[:-1], sorted_arr[1:]))


def run(test: Callable[[List[int]], None], samples) -> List[float]:
    res = []
    for sample in samples:
        res.append(timeit(lambda: test(sample), number=10))
    return res

ndim = 100000
num_samples = 10
data = generate_samples(num_samples, ndim)
data_copy1 = deepcopy(data)
data_copy2 = deepcopy(data)

res1 = run(test1, data_copy1)
res2 = run(test2, data_copy2)

res = []
labels = []

for a, b in zip(res1, res2):
    res.append(a)
    res.append(b)
    labels.append("test1")
    labels.append("test2")

data = pd.DataFrame(data={"Data": list(range(len(res))), "Time": res, "Label": labels})
sns.barplot(x="Data", y="Time", hue="Label", data=data, dodge=False)
plt.savefig("result.png", bbox_inches="tight")

测试结果

下图是两种方法的执行用时统计:

/posts/python-slice-assignment/result.png
测试结果统计

可以看出 list.insert 的执行用时差不多是使用切片赋值技术的两倍。

原因分析

Stack Overflow: Why is slice assignment faster than list.insert?

例题

对于经典的逆序对问题,虽然使用归并排序等算法的时间复杂度为 $\mathcal{O}(n\log(n))$,但是在 Python 中使用切片赋值技术,代码十分简单:

1
2
3
4
5
6
7
8
class Solution:
    def reversePairs(self, nums: List[int]) -> int:
        ans, S = 0, []
        for x in nums:
            k = bisect_right(S, x)
            ans += len(S) - k
            S[k:k] = [x]
        return ans

其余各题代码:

295. 数据流的中位数

https://www.sotsog.cn/data-structures-and-algorithms/algorithms/mathematics/binary-exponentiation/

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
class MedianFinder:
    def __init__(self):
        self.S = []

    def addNum(self, num: int) -> None:
        k = bisect_left(self.S, num)
        self.S[k:k] = [num]

    def findMedian(self) -> float:
        n = len(self.S)
        return (self.S[n // 2] + self.S[(n - 1) // 2]) / 2
327. 区间和的个数

https://leetcode.cn/problems/count-of-range-sum/

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
class Solution:
    def countRangeSum(self, nums: List[int], lower: int, upper: int) -> int:
        n = len(nums)
        ans = 0
        S = [0]
        cumsum = 0
        for x in nums:
            cumsum += x
            ans += bisect_right(S, cumsum - lower) - bisect_left(S, cumsum - upper)
            k = bisect_left(S, cumsum)
            S[k:k] = [cumsum]
        return ans
480. 滑动窗口中位数

https://leetcode.cn/problems/sliding-window-median/

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
class Solution:
    def medianSlidingWindow(self, nums: List[int], k: int) -> List[float]:
        ans, S, n = [], [], len(nums)
        for i in range(n):
            j = bisect_left(S, nums[i])
            S[j:j] = [nums[i]]
            if len(S) == k:
                ans.append((S[k // 2] + S[(k - 1) // 2]) / 2)
                j = bisect_left(S, nums[i - k + 1])
                S[j : j + 1] = []
        return ans
493. 翻转对

https://leetcode.cn/problems/reverse-pairs/

1
2
3
4
5
6
7
8
9
class Solution:
    def reversePairs(self, nums: List[int]) -> int:
        ans, S = 0, []
        for x in nums:
            k = bisect_right(S, 2 * x)
            ans += len(S) - k
            k = bisect_left(S, x)
            S[k:k] = [x]
        return ans
1847. 最近的房间

https://leetcode.cn/problems/closest-room/

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class Solution:
    def closestRoom(self, rooms: List[List[int]], queries: List[List[int]]) -> List[int]:
        n = len(queries)
        ans = [-1] * n
        index = list(range(n))
        S = []

        index.sort(key=lambda i: queries[i][1], reverse=True)
        rooms.sort(key=lambda x: x[1])

        for i in index:
            preferred, min_size = queries[i]
            while rooms and rooms[-1][1] >= min_size:
                k = bisect_left(S, rooms[-1][0])
                S[k:k] = [rooms[-1][0]]
                rooms.pop()
            if S:
                k = bisect_left(S, preferred)
                if k != len(S):
                    ans[i] = S[k]
                if k != 0:
                    if ans[i] == -1 or abs(S[k - 1] - preferred) <= abs(ans[i] - preferred):
                        ans[i] = S[k - 1]
        return ans

执行用时统计:

问题 执行用时 击败
剑指 Offer 51. 数组中的逆序对 824ms 99.06%
295. 数据流的中位数 260ms 36.80%
327. 区间和的个数 2736ms 12.28%
480. 滑动窗口中位数 92ms 86.85%
493. 翻转对 1060ms 99.20%
1847. 最近的房间 332ms 100.00%

需要维护一个有序数组的题大多属于困难题,使用切片赋值的虽然可以用来 轻松 过题,但是也应该意识到尽管切片赋值技术的底层实现十分高效,但是时间复杂度仍然是线性的。