快速沃尔什变换

简介

快速沃尔什变换用于解决位运算卷积问题,此类问题形如如下形式:

对于给定的长度为 2n2^{n} 的数组 aabb,求一个长度为 2n2^{n} 的数组 cc,满足

ci=i=jkajbkc_{i}=\sum_{i=j\oplus k}a_{j}b_{k}

其中 \oplus 是一种位运算,例如按位与,按位或,按位异或。

直接朴素计算的时间复杂度为 O(4n)O(4^n),使用快速沃尔什变换可以在 O(n2n)O(n2^n) 的时间复杂的内完成。

具体来说,考虑类似于 FFT 的思想,分为 3 步:

  1. aabb 数组进行沃尔什变换,得到数组 AABB

  2. 将数组 AABB 对位相乘,得到数组 CC

  3. CC 进行沃尔什逆变换,得到数组 cc

按位或变换

当运算为按位或运算时,可以设 Ai=ij=iajA_i=\sum_{i\mid j=i}a_j,根据 ij=ii\mid j=iik=ii\mid k=i 等价于 (jk)i=i(j\mid k)\mid i=i 可以推导出:

AiBi=ij=iajik=ibk=ij=iik=iajbk=(jk)i=iajbk=Ci\begin{aligned} A_iB_i&=\sum_{i\mid j=i}a_j\sum_{i\mid k=i}b_k\\ &=\sum_{i\mid j=i}\sum_{i\mid k=i}a_jb_k\\ &=\sum_{(j\mid k)\mid i=i}a_j b_k\\ &=C_i \end{aligned}

给定 aaAA 可以对下标按位考虑。上代码:

c++
void orTrans(LL *a,LL type){
  for(int l=2;l<=n;l<<=1){ // 枚举考虑到第几个二进制位
    int k=l>>1;
    for(int i=0;i<n;i+=l){
      for(int j=0;j<k;++j){ // 枚举该二进制位及更高位的值为 i+j
        (a[i+j+k]+=a[i+j]*type)%=mod;
        // 那么比枚举的位更低的位,若为 1,则可以合并原来为 0 的位
      }
    }
  }
}

逆变换时带入 type 为 -1 即可。

按位与变换

同理或运算,类比可得,设 Ai=i&j=iajA_i=\sum_{i\& j=i}a_j,类似地得到变换方法:

c++
void andTrans(LL *a,LL type){
  for(int l=2;l<=n;l<<=1){
    int k=l>>1;
    for(int i=0;i<n;i+=l){
      for(int j=0;j<k;++j){
        (a[i+j]+=a[i+j+k]*type)%=mod;
      }
    }
  }
}

按位异或运算

定义运算 \circab=popcnt(a&b)mod2a\circ b=\text{popcnt}(a\& b)\bmod 2\oplus 为按位异或运算。

Ai=ij=0ajij=1ajA_i=\sum_{i\circ j=0}a_j-\sum_{i\circ j=1}a_j,可以验证,仍然满足上述条件。

具体来说,\circ 运算具有性质:(ij)(ik)=i(jk)(i\circ j)\oplus(i\circ k)=i\circ(j\oplus k)

证明考虑 \circ 运算的意义:aabb 二进制下同为 11 的位数的奇偶性。分类讨论:考虑某一个二进制位,

  • jjkk 这一位均为 11,那么 jjkk 异或后这一位为 00,故右式中 iijjj\oplus j 不会同时为 11,贡献为 00

    考虑左式,若 ii 这一位为 00,则两者均为 00,不产生贡献;若 ii 这一位为 11,左式中 iji\circ jiki\circ k 都会因为多了一个同为 11 的位数而奇偶性改变,异或结果为 00,也不改变。

其它情况分类讨论,同理可得。

c++
void xorTrans(LL *a,LL type){
  for(int l=2;l<=n;l<<=1){
    int k=l>>1;
    for(int i=0;i<n;i+=l){
      for(int j=0;j<k;++j){
        (a[i+j]+=a[i+j+k])%=mod;
        a[i+j+k]=(a[i+j]+mod-a[i+j+k]+mod-a[i+j+k])%mod;
        (a[i+j]*=type)%=mod;
        (a[i+j+k]*=type)%=mod;
        if(a[i+j+k]<0)a[i+j+k]+=mod;
      }
    }
  }
}

逆变换时 type=12\frac{1}{2}

下面给出 洛谷模板题 的代码

Sample Code(C++)
c++
#include<bits/stdc++.h>
using namespace std;
using LL=long long ;
const int N=1.4e5,mod=998244353;
int n;
void orTrans(LL *a,LL type){
  for(int l=2;l<=n;l<<=1){
    int k=l>>1;
    for(int i=0;i<n;i+=l){
      for(int j=0;j<k;++j){
        (a[i+j+k]+=a[i+j]*type)%=mod;
      }
    }
  }
}
void andTrans(LL *a,LL type){
  for(int l=2;l<=n;l<<=1){
    int k=l>>1;
    for(int i=0;i<n;i+=l){
      for(int j=0;j<k;++j){
        (a[i+j]+=a[i+j+k]*type)%=mod;
      }
    }
  }
}
void xorTrans(LL *a,LL type){
  for(int l=2;l<=n;l<<=1){
    int k=l>>1;
    for(int i=0;i<n;i+=l){
      for(int j=0;j<k;++j){
        (a[i+j]+=a[i+j+k])%=mod;
        a[i+j+k]=(a[i+j]+mod-a[i+j+k]+mod-a[i+j+k])%mod;
        (a[i+j]*=type)%=mod;
        (a[i+j+k]*=type)%=mod;
        if(a[i+j+k]<0)a[i+j+k]+=mod;
      }
    }
  }
}
void cpy(LL *a,LL *b){
  for(int i=0;i<n;++i)a[i]=b[i];
}
void print(LL *a){
  for(int i=0;i<n;++i)printf("%lld ",a[i]);
  putchar(10);
}
LL a[N],b[N];
int main(){
  scanf("%d",&n);
  n=1<<n;
  for(int i=0;i<n;++i)scanf("%lld",&a[i]);
  for(int i=0;i<n;++i)scanf("%lld",&b[i]);
  static LL A[N],B[N];
  cpy(A,a);
  cpy(B,b);
  orTrans(A,1);
  orTrans(B,1);
  for(int i=0;i<n;++i)A[i]=A[i]*B[i]%mod;
  orTrans(A,mod-1);
  print(A);

  cpy(A,a);
  cpy(B,b);
  andTrans(A,1);
  andTrans(B,1);
  for(int i=0;i<n;++i)A[i]=A[i]*B[i]%mod;
  andTrans(A,mod-1);
  print(A);

  cpy(A,a);
  cpy(B,b);
  xorTrans(A,1);
  xorTrans(B,1);
  for(int i=0;i<n;++i)A[i]=A[i]*B[i]%mod;
  xorTrans(A,499122177);
  print(A);

  return 0;
}
初三回忆录
多项式基本运算