FFT&NTT

FFT

学习了FFT 法法塔 感受到了人和人之间是有不少差距的Orz。 首先要介绍一个单位根,说的玄妙,其实就把它看做一个角度好了,然后在坐标系里面以自己角的大小为步长,进行旋转,然后它会回到自己就完了。如果我口胡太严重了,就去看这篇博客好了,写的很好,但是我感觉它代码有问题。 首先介绍一下单位根($\omega_n^1$)的一些性质: $\omega_n^k = (\omega_n^1)^k$ $\omega_{2n}^{2k} = \omega_n^k$ $\omega_n^0 = 1$ $\omega_n^k*\omega_n^j = \omega_n^{j+k}$ 然后就差不多使用如下性质。然后再说FFT。FFT是用来求卷积的,所谓的卷积就是由两个函数生成第三个函数的运算。即$h(x) = \int\limits_{-\infty}^{\infty}f(i)*g(x-i)$,然后这玩意和傅里叶变化有很大关系,可以简化计算。注意由于计算机只能处理非连续的数据,所以需要把上面的积分符号换成$\sum$。FFT的过程如下 $F(x) = a_0+a_0x^1+…+a_nx^n$ 然后按照奇偶项拆开 $F(x) = (a_0+a_2x^2…)+x*(a_1x^1+…)$ 化为,然后发现这玩意可以分治 $F(x) = FL(x^2)+x*FR(x^2)$ 把x代为单位根: \(F(x) = FL(x^2)+x*FR(x^2)\\ F(w_n^k) = FL((w_{n}^k)^2)+w_n^kFR((w_{n}^k)^2)\\ F(w_n^k) = FL(w_{n/2}^k)+w_n^kFR(w_{n/2}^k)\\ solve:w_n^0...w_n^{n/2-1}\\ to: w_n^{n/2-1}...w_n^{n-1}\\\) 假设$k<n/2$,由于: \(w_n^k = w_n^{k\%n}\) 所以可以做如下转化 \(F(w_n^{k+n/2}) = FL((w_n^{k+n/2})^2)+w_n^{k+n/2}FR((w_n^{k+n/2})^2)\\ bacause\\ (w_n^k)^j = w_n^{kj}\\ therefore\\ =FL(w_n^{2k+n})+w_n^{k+n/2}FR(w_n^{2k+n})\\ w_n^k = w_n^{k\%n}\\ = FL(w_n^{2k})+w_n^{k+n/2}FR(w_n^{2k})\\ w_{2n}^{2k} = w_n^k\\ = FL(w_{n/2}^k)+w_n^{k+n/2}FR(w_{n/2}^k)\\ w_n^{k+n/2} = -w_n^{k}\\ therefore \\ FL(w_{n/2}^k)-w_n^{k}FR(w_{n/2}^k)\) 然后计算就行了。最后有一个玄学蝴蝶优化,大概就是不写递归的感觉,因为分治就像一个线段树,于是我把每个区间单独取出来处理一下,就免了递归,方便卡常 板子代码:

# include <iostream>
# include <cstring>
# include <cstdio>
# include <cmath>
using namespace std;
const int MAXN = (int)4e6+10;
const double PI = acos(-1);
struct complex{
	double x,y;
	complex(double a = 0.0,double b = 0.0):x(a),y(b){}
	complex operator+(complex b){return complex(x+b.x,y+b.y);}
	complex operator-(complex b){return complex(x-b.x,y-b.y);}
	complex operator*(complex b){return complex(x*b.x-y*b.y,x*b.y+y*b.x);}
}f[MAXN],g[MAXN];
int n,m,len;
int R[MAXN],lg;
void init(){for(int i = 0;i<len;i++)R[i] = (R[i>>1]>>1)|((i&1)<<(lg-1));}
void DFT(complex *f,int op){
	for(int i = 0;i<len;i++)if(i<R[i])swap(f[i],f[R[i]]);
	for(int j = 1;j<len;j<<=1){
		complex base = complex(cos(PI/j),op*sin(PI/j));
		for(int k = 0;k<len;k+=(j<<1)){
			complex t(1,0);
			for(int l = 0;l<j;l++,t = t*base){
				complex nx = f[k+l],ny = f[j+k+l]*t;
				f[k+l] = nx+ny;
				f[k+j+l] = nx-ny;
			}
		}
	}
}


int main(){
	scanf("%d%d",&n,&m);
	for(int i = 0;i<=n;i++)scanf("%lf",&f[i].x);
	for(int i = 0;i<=m;i++)scanf("%lf",&g[i].x);
	m+=n;lg = 0;
	for(n = 1;n<=m;n<<=1)lg++;
	len = n;init();
	DFT(f,1);DFT(g,1);
	for(int i = 0;i<=n;i++)f[i] = f[i]*g[i];
	DFT(f,-1);
	for(int i = 0;i<=m;i++)printf("%d ",(int)(fabs(f[i].x)/n+0.5));
	return 0;
}

NTT

上面的做法因为要使用三角函数和复数,复数又要用到double,卡精度卡的厉害。于是考虑能不能优化。现引出原根的概念: 设$p = r*2^k+1$k足够大,设原根为$g$,则$g^{0~n-1}$为互不相同的整数。然后定义剩余系下的原根$\omega_n^1 = g^{(p-1)/n}$可以证明它是一个整数,且$g^{0~n-1}$为互不相同的整数。 然后这个原根拥有和FFT的复数一样的性质。用就行了。但是要注意一点,不同的模数对应不同的原根,如998244353对应的原根是3,3在模998244353意义下的的逆元是332748118。原根的求法如下: 原根的判定:如果$g^k\equiv 1(mod p)$中k的最小解为(mod-1),则g是一个原根。 求法就是枚举1-100(因为大部分原根都小于100),然后判断就行了。判断的过程就是枚举(mod-1)的因数,如果它满足上述条件,说明这个数不是模mod意义下的原根。 然后代码里面有一些细节,比如FFT里面的

complex base = comple(cos(PI/j),op*(PI/j))

来实现DFT和IDFT的过程,但是使用NTT就更简单了,直接计算原根的k次方或者原根的逆元的k次方来实现。具体代码就是:

long long base = fp(op==1?g:revg,mod-1/(j*2));

至于为什么j要乘以2,我也没搞清楚,这两天状态太差了。 代码实现如下:

# include <iostream>
# include <cstring>
# include <cstdio>
# include <cmath>
using namespace std;
const long long mod = 998244353;
const int MAXN = 4e6+10;
const long long G = 3;
int n,m,R[MAXN];
long long f[MAXN],g[MAXN],invG = 332748118;
inline int read() { 
    char c = getchar(); int x = 0, f = 1;
    while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
    while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    return x * f;
}
long long fp(long long b,long long p = mod-2){
	long long ans = 1;
	while(p){
		if(p&1)ans = ans*b%mod;
		b = b*b%mod;
		p>>=1;
	}
	return ans;
}
int len,lg;
void init(){
	for(int i = 0;i<len;i++)R[i] = (R[i>>1]>>1)|((i&1)<<(lg-1));
}
void NTT(long long *f,int op){
	for(int i = 0;i<len;i++)if(i<R[i])swap(f[i],f[R[i]]);
	for(int j = 1;j<len;j<<=1){
		long long base = fp(op==1?G:invG,(mod-1)/(j<<1));
		for(int k = 0;k<len;k+=(j<<1)){
			long long t = 1;
			for(int l = 0;l<j;l++,t = (t*base)%mod){
				int nx = f[k+l],ny = f[j+k+l]*t%mod;
				f[k+l] = (nx+ny)%mod;
				f[j+k+l] = (nx-ny+mod)%mod;
			}
		}
	}
}
int main(){
	scanf("%d%d",&n,&m);
	for(int i = 0;i<=n;i++)f[i] = read();
	for(int i = 0;i<=m;i++)g[i] = read();
	m+=n;
	for(len = 1;len<=m;len<<=1)lg++;init();
	NTT(f,1);NTT(g,1);
	for(int i = 0;i<len;i++)f[i] = f[i]*g[i]%mod;
	NTT(f,-1);
	long long rev = fp(len,mod-2);
	for(int i = 0;i<=m;i++)printf("%lld ",(f[i]*rev%mod+mod)%mod);
	return 0;
}