概述
\(FWT\)是用來處理集合卷積的問題,也就是求解\(f(n)\sum\limits_{i|j=n}f(i)f(j)\)型別的問題,其中或運算可以改為\(\otimes,\&\),
尋找點值
因為總是看不下去那么長的推導,所以每次都是看到一半,然后就在加上自己的一點理解,簡單推導一下吧(背過結論就行)
以或運算為例,為什么說是集合卷積呢,因為或運算等價于求集合并,也就是求\(f(n)\sum\limits_{i\cup j=n}f_1(i)f_2(j)\) ,
那么我們類似于\(FFT\),先將他轉化為點值,然后進行乘法運算后,在轉換回來,
如何轉化為點值呢,或者說他的點值長什么樣子呢,
我們求一個\(g(n)=\sum\limits_{s\subseteq n}f(n)\),然后我們求一下\(g(n)\)的乘積,
我們令\(g_1\)表示\(f_1\)轉化后的結果,\(g_2\)表示\(f_2\)轉化或的結果,\(g\)表示卷積\(f\)轉化后的結果,
也就是\(g_1(n)g_2(n)=\sum\limits_{s_1\subseteq n}f_1(s_1)\sum\limits_{s_2\subseteq n}f_2(s_2)=\sum\limits_{s_1,s2\subseteq n}f_1(s_1)f_2(s_2)=\sum\limits_{s_1\cup s_2\subseteq n}f_1(s_1)f_2(s_2)=g(n)\)
所以\(g(n)=\sum\limits_{s\subseteq n}f(n)\)就是我們要的點值運算式!
相互轉化
有了點值之后我們還需要在點值與多項式之間相互轉化,那么應該怎么轉化呢,
其實很簡單,觀察\(g(n)=\sum\limits_{s\subseteq n}f(n)\),這個式子,其實就是一個高維前綴和嘛,,
然后轉化回去同樣的來個高維差分就ok了,
高維前綴和代碼如下:
void fwt_or(int *a,int xs) {
for(int i = 0;i < n;++i)
for(int j = 0;j < (1 << n);++j)
if(!((j >> i) & 1))
a[j | (1 << i)] += a[j];
}
對于另外兩種運算
對于\(\otimes\)和\(\&\),與\(|\)類似,也有不同之處,因為\(\&\)表示的是集合交,所以他是列舉超集和而不是子集和,
至于\(\otimes\),背板子吧我也不會推導啊qwq
板子
板子里面,\(xs=1\)時表示\(FWT\),即將多項式轉化為點值,\(xs=-1\)時表示\(IFWT\),即將點值轉化回多項式,
或運算
void fwt_or(int *a,int xs) {
for(int i = 0;i < n;++i)
for(int j = 0;j < (1 << n);++j)
if(!((j >> i) & 1))
a[j | (1 << i)] += xs * a[j];
}
and運算
void fwt_and(int *a,int xs) {
for(int i = 0;i < n;++i)
for(int j = 0;j < (1 << n);++j)
if(!((j >> i) & 1))
a[j] += xs * a[j | (1 << i)];
}
異或運算
void fwt_xor(int *a,int xs) {
for(int i = 0;i < n;++i) {
for(int j = 0;j < (1 << n);++j) {
if(!((j >> i) & 1)) {
int l = a[j],r = a[j | (1 << i)];
a[j] = l + r;a[j] %= mod;
a[j | (1 << i)] = l - r;a[j | (1 << i)] %= mod;
}
}
}
if(xs == -1) {
int inv = qm(1 << n,mod - 2);
for(int i = 0;i < (1 << n);++i)
a[i] = 1ll * a[i] * inv % mod;
}
}
小技巧
在很多題目中,需要進行多次\(FWT\)運算,我們不需要每次將陣列來回轉化,只要先將多項式轉化為點值,然后對點值進行快速冪運算,最后在轉化回去就行,
模板題
luogu4717
/*
* @Author: wxyww
* @Date: 2020-04-26 08:03:27
* @Last Modified time: 2020-04-26 08:43:59
*/
#include<cstdio>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<queue>
#include<vector>
#include<ctime>
using namespace std;
typedef long long ll;
const int N = 1 << 20,mod = 998244353;
ll read() {
ll x = 0,f = 1;char c = getchar();
while(c < '0' || c > '9') {
if(c == '-') f = -1; c = getchar();
}
while(c >= '0' && c <= '9') {
x = x * 10 + c - '0'; c = getchar();
}
return x * f;
}
int A[N],B[N],n;
void fwt_and(int *a,int xs) {
for(int i = 0;i < n;++i) {
for(int j = 0;j < (1 << n);++j) {
if(!((j >> i) & 1)) {
a[j] += xs * a[j | (1 << i)];
a[j] %= mod;
}
}
}
}
void fwt_or(int *a,int xs) {
for(int i = 0;i < n;++i) {
for(int j = 0;j < (1 << n);++j) {
if(!((j >> i) & 1)) {
a[j | (1 << i)] += xs * a[j];
a[j | (1 << i)] %= mod;
}
}
}
}
ll qm(ll x,ll y) {
ll ret = 1;
for(;y;y >>= 1,x = x * x % mod)
if(y & 1) ret = ret * x % mod;
return ret;
}
void fwt_xor(int *a,int xs) {
for(int i = 0;i < n;++i) {
for(int j = 0;j < (1 << n);++j) {
if(!((j >> i) & 1)) {
int l = a[j],r = a[j | (1 << i)];
a[j] = l + r;a[j] %= mod;
a[j | (1 << i)] = l - r;a[j | (1 << i)] %= mod;
}
}
}
if(xs == -1) {
int inv = qm(1 << n,mod - 2);
for(int i = 0;i < (1 << n);++i) {
a[i] = 1ll * a[i] * inv % mod;
}
}
}
int tmp1[N],tmp2[N];
int main() {
n = read();
for(int i = 0;i < (1 << n);++i) A[i] = read();
for(int i = 0;i < (1 << n);++i) B[i] = read();
memcpy(tmp1,A,sizeof(tmp1));
memcpy(tmp2,B,sizeof(tmp2));
fwt_or(tmp1,1);fwt_or(tmp2,1);
for(int i = 0;i < (1 << n);++i) tmp1[i] = 1ll * tmp1[i] * tmp2[i] % mod;
fwt_or(tmp1,-1);
for(int i = 0;i < (1 << n);++i) printf("%d ",(tmp1[i] + mod) % mod);puts("");
memcpy(tmp1,A,sizeof(tmp1));
memcpy(tmp2,B,sizeof(tmp2));
fwt_and(tmp1,1);fwt_and(tmp2,1);
for(int i = 0;i < (1 << n);++i) tmp1[i] = 1ll * tmp1[i] * tmp2[i] % mod;
fwt_and(tmp1,-1);
for(int i = 0;i < (1 << n);++i) printf("%d ",(tmp1[i] + mod) % mod);puts("");
memcpy(tmp1,A,sizeof(tmp1));
memcpy(tmp2,B,sizeof(tmp2));
fwt_xor(tmp1,1);fwt_xor(tmp2,1);
for(int i = 0;i < (1 << n);++i) tmp1[i] = 1ll * tmp1[i] * tmp2[i] % mod;
fwt_xor(tmp1,-1);
for(int i = 0;i < (1 << n);++i) printf("%d ",(tmp1[i] + mod) % mod);puts("");
return 0;
}
轉載請註明出處,本文鏈接:https://www.uj5u.com/qita/55353.html
標籤:其他
上一篇:HashMap淺析
