所有可能子集的异或和

问题描述

对于长度为 $n$ 的数组 $nums$,分别计算 $nums$ 的每个子集内元素异或的结果,然后返回结果之和。

示例

1
2
3
4
5
输入:a = [1, 2, 3]
输出:
a[0] + a[1] + a[2] +
a[0] ^ a[1] + a[0] ^ a[2] + a[1] ^ a[2] +
a[0] ^ a[1] ^ a[2] = 12

问题分析

首先想到的就是枚举所有子集,然后计算结果。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
int subset_xor_sum(vector<int> &nums) {
    int n = nums.size(), N = 1 << n;
    int ans = 0;
    for (int i = 0; i < N; ++i) {
        int t = 0;
        for (int j = 0; j < n; ++j) {
            if (i >> j & 1) {
                t ^= nums[j];
            }
        }
        ans += t;
    }
    return ans;
}

枚举子集的时候其实有重复计算的过程(子集的子集),可以通过动态规划来减少重复计算。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
int subset_xor_sum(vector<int> &nums) {
    vector<int> S{0};
    for (int &x : nums) {
        int n = S.size();
        for (int i = 0; i < n; ++i) {
            S.emplace_back(S[i] ^ x);
        }
    }
    return accumulate(S.begin(), S.end(), 0);
}
1
2
3
4
5
6
7
8
int subset_xor_sum(vector<int> &nums) {
    int N = 1 << nums.size();
    vector<int> dp(N);
    for (int i = 1; i < N; ++i) {
        dp[i] = nums[log2(i & -i)] ^ dp[i & (i - 1)];
    }
    return accumulate(dp.begin(), dp.end(), 0);
}

这些方法都比较常规,更神奇的是利用位运算可以使得时间复杂度降到 $\mathcal{O}(n)$。

单独考虑二进制表示中每一位的情况,

假设 $nums$ 中第 $i$ 位上为 $1$ 的数有 $a$,为 $0$ 的数有 $b$ 个,显然 $a+b=n$。

我们知道想要某一位异或的结果为 $1$,那么必须要有 奇数 个 $1$ 才行,$0$ 的个数没有影响。

所以必须在 $a$ 个 $1$ 中选出奇数个 $1$,那么在 $a$ 个数中选出奇数个数有多少种组合呢?

用 $C_n^m$ 表示组合数,即从 $n$ 个数中选出 $m$ 个数的组合个数。

那么在 $a$ 个数中选出奇数个数的个数为 $C_a^1+C_a^3+C_a^5+...$。

我们知道二项式定理:

$$ (x+y)^n = C_n^0 x^0y^n + C_n^1 x^1y^{n-1}+\dots+ C_n^n x^ny^0 $$

令 $x=1,y=1$,代入公式中,

$$ (1+1)^n=2^n=C_n^1+C_n^2+\dots+C_n^n $$

令 $x=1,y=-1$,代入公式中,

  • 如果 $n$ 为偶数,则有:
$$ (1+(-1))^n=0=C_n^0-C_n^1+C_n^2-C_n^3+... $$

即选出 奇数个数的组合个数选出偶数个数的组合 个数相同。

同理,$n$ 为奇数时结论一致。

所以在 $n$ 个数中选出奇数个数的组合个数为 $2^{n-1}$。

所以当第 $i$ 位上有 $a$ 个 $1$,$b$ 个 $0$ 时,对结果的贡献为:

$$ 2^i \times 2^{a-1} \times 2 ^ b = 2^i \times 2 ^ {n-1} $$

所以我们知道了只要 $a \neq 0$,那么第 $i$ 位的贡献就是 $2^i \times 2 ^ {n-1}$。

1
2
3
int subset_xor_sum(vector<int> &nums) {
    return accumulate(nums.begin(), nums.end(), 0, bit_or<int>()) << ((int)nums.size() - 1);
}

参考