简介
快速沃尔什变换用于解决位运算卷积问题,此类问题形如如下形式:
对于给定的长度为 的数组 和 ,求一个长度为 的数组 ,满足
其中 是一种位运算,例如按位与,按位或,按位异或。
直接朴素计算的时间复杂度为 ,使用快速沃尔什变换可以在 的时间复杂的内完成。
具体来说,考虑类似于 FFT 的思想,分为 3 步:
将 和 数组进行沃尔什变换,得到数组 和 。
将数组 和 对位相乘,得到数组 。
将 进行沃尔什逆变换,得到数组 。
按位或变换
当运算为按位或运算时,可以设 ,根据 且 等价于 可以推导出:
给定 求 可以对下标按位考虑。上代码:
c++
void orTrans(LL *a,LL type){
for(int l=2;l<=n;l<<=1){ // 枚举考虑到第几个二进制位
int k=l>>1;
for(int i=0;i<n;i+=l){
for(int j=0;j<k;++j){ // 枚举该二进制位及更高位的值为 i+j
(a[i+j+k]+=a[i+j]*type)%=mod;
// 那么比枚举的位更低的位,若为 1,则可以合并原来为 0 的位
}
}
}
}
逆变换时带入 type 为 -1 即可。
按位与变换
同理或运算,类比可得,设 ,类似地得到变换方法:
c++
void andTrans(LL *a,LL type){
for(int l=2;l<=n;l<<=1){
int k=l>>1;
for(int i=0;i<n;i+=l){
for(int j=0;j<k;++j){
(a[i+j]+=a[i+j+k]*type)%=mod;
}
}
}
}
按位异或运算
定义运算 为 , 为按位异或运算。
设 ,可以验证,仍然满足上述条件。
具体来说, 运算具有性质:。
证明考虑 运算的意义: 与 二进制下同为 的位数的奇偶性。分类讨论:考虑某一个二进制位,
若 和 这一位均为 ,那么 与 异或后这一位为 ,故右式中 与 不会同时为 ,贡献为 。
考虑左式,若 这一位为 ,则两者均为 ,不产生贡献;若 这一位为 ,左式中 和 都会因为多了一个同为 的位数而奇偶性改变,异或结果为 ,也不改变。
其它情况分类讨论,同理可得。
c++
void xorTrans(LL *a,LL type){
for(int l=2;l<=n;l<<=1){
int k=l>>1;
for(int i=0;i<n;i+=l){
for(int j=0;j<k;++j){
(a[i+j]+=a[i+j+k])%=mod;
a[i+j+k]=(a[i+j]+mod-a[i+j+k]+mod-a[i+j+k])%mod;
(a[i+j]*=type)%=mod;
(a[i+j+k]*=type)%=mod;
if(a[i+j+k]<0)a[i+j+k]+=mod;
}
}
}
}
逆变换时 type=。
下面给出 洛谷模板题 的代码
Sample Code(C++)
c++
#include<bits/stdc++.h>
using namespace std;
using LL=long long ;
const int N=1.4e5,mod=998244353;
int n;
void orTrans(LL *a,LL type){
for(int l=2;l<=n;l<<=1){
int k=l>>1;
for(int i=0;i<n;i+=l){
for(int j=0;j<k;++j){
(a[i+j+k]+=a[i+j]*type)%=mod;
}
}
}
}
void andTrans(LL *a,LL type){
for(int l=2;l<=n;l<<=1){
int k=l>>1;
for(int i=0;i<n;i+=l){
for(int j=0;j<k;++j){
(a[i+j]+=a[i+j+k]*type)%=mod;
}
}
}
}
void xorTrans(LL *a,LL type){
for(int l=2;l<=n;l<<=1){
int k=l>>1;
for(int i=0;i<n;i+=l){
for(int j=0;j<k;++j){
(a[i+j]+=a[i+j+k])%=mod;
a[i+j+k]=(a[i+j]+mod-a[i+j+k]+mod-a[i+j+k])%mod;
(a[i+j]*=type)%=mod;
(a[i+j+k]*=type)%=mod;
if(a[i+j+k]<0)a[i+j+k]+=mod;
}
}
}
}
void cpy(LL *a,LL *b){
for(int i=0;i<n;++i)a[i]=b[i];
}
void print(LL *a){
for(int i=0;i<n;++i)printf("%lld ",a[i]);
putchar(10);
}
LL a[N],b[N];
int main(){
scanf("%d",&n);
n=1<<n;
for(int i=0;i<n;++i)scanf("%lld",&a[i]);
for(int i=0;i<n;++i)scanf("%lld",&b[i]);
static LL A[N],B[N];
cpy(A,a);
cpy(B,b);
orTrans(A,1);
orTrans(B,1);
for(int i=0;i<n;++i)A[i]=A[i]*B[i]%mod;
orTrans(A,mod-1);
print(A);
cpy(A,a);
cpy(B,b);
andTrans(A,1);
andTrans(B,1);
for(int i=0;i<n;++i)A[i]=A[i]*B[i]%mod;
andTrans(A,mod-1);
print(A);
cpy(A,a);
cpy(B,b);
xorTrans(A,1);
xorTrans(B,1);
for(int i=0;i<n;++i)A[i]=A[i]*B[i]%mod;
xorTrans(A,499122177);
print(A);
return 0;
}