给定数组中所有数对的位或和

问题描述

对于长度为 $n$ 的数组 $a$,如何计算所有数对的或运算结果之和,即

$$ \sum_{i=1}^{n}\sum_{j=i+1}^n a_i|a_j, $$

如对于数组 [1,2,3],结果为 (1|2)+(1|3)+(2|3)=3+3+3=9

问题分析

最直接的方法就是遍历所有结果:

1
2
3
def pair_or_sum(nums: List[int]) -> int:
    n = len(nums)
    return sum(nums[i] | nums[j] for i in range(n) for j in range(i + 1, n))

但是这样做的时间复杂度为 $\mathcal{O}(n^2)$,事实上存在 $\mathcal{O}(n)$ 的方法。

注意到 $a|b$ 的结果不会减少 $b$,而是在 $b$ 的基础上,加上二进制表示中所有 $a$ 为 1 而 $b$ 为 0 的位的位权。

如 $a|b=1|2=(01)_2|(10)_2=b+(1)\times(1\ll0)=3$。

所以对于第 $i$ 位,假设该位为 0 和为 1 的数量分别为 $b_0,b_1$,那么该位对结果的贡献为:

$$ (1\ll i)\times\left(\frac{b_1\times(b_1-1)}{2}+b_0 \times b_1\right) $$

$1\ll i$ 表示左移 $i$ 位,即第 $i$ 位的权重。

1
2
3
4
5
6
7
8
def pair_or_sum(nums: List[int]) -> int:
    ans = 0
    for i in range(31):
        bit = [0, 0]
        for x in nums:
            bit[x >> i & 1] += 1
        ans += (1 << i) * ((bit[1] - 1) * bit[1] // 2 + bit[0] * bit[1])
    return ans