hdu4914 Linear recursive sequence

时间:2021-12-11 07:58:12

用矩阵求解线性递推式通项

用FFT加速多项式乘法(以此代替直接的矩阵乘法)

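首先把递推式 $f(n)=a\,f(n-p)+b\,f(n-q)$ 的求解转化为矩阵求幂 $A^n$;再利用 Cayley–Hamilton 定理——特征多项式 $\chi(\lambda)=\lambda^q-a\lambda^{q-p}-b$ 满足 $\chi(A)=0$——把求 $A^n$ 转化为求 $x^n \bmod \chi(x)=\sum_{i=0}^{q-1}c_i x^i$,于是 $f(n)=\sum_{i=0}^{q-1}c_i f(i)$,矩阵乘法就变成了多项式相乘再取模;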

最后用快速傅里叶变换(FFT)的迭代实现(以迭代取代递归,参见《算法导论》)完成其中的多项式乘法。

 #include <cstdio>
#include <cstring>
#include <algorithm>
#include <map>
#include <string>
#include <vector>
#include <set>
#include <cmath>
#include <ctime>
#pragma comment(linker, "/STACK:102400000,102400000")
using namespace std;
// ---------------------------------------------------------------------------
// HDU 4914 "Linear recursive sequence":
//   f(n) = a*f(n-p) + b*f(n-q) for n > 0,  f(k) = 1 for k <= 0,
// print f(n) modulo `mod`. Assumes 1 <= p < q (so q >= 2).
//
// Method: reduce x^n modulo the characteristic polynomial
//   chi(x) = x^q - a*x^(q-p) - b
// with binary exponentiation, multiplying polynomials via an iterative FFT.
// If x^n == sum_{i<q} c_i x^i (mod chi), then f(n) = sum_{i<q} c_i f(i).
//
// NOTE(review): the numeric literals of the original listing were lost; the
// values below (modulus 119, buffer bound, loop bounds) are reconstructions
// from the algorithm and the HDU 4914 statement -- verify against the judge.
// ---------------------------------------------------------------------------
const double pi = std::acos(-1.0);
// FFT buffer bound: must be >= the largest transform length, i.e. the smallest
// power of two > 2*(q-1). With q <= 1e4 (assumed problem limit) that is 32768.
const int maxn = 40000 + 10;
const int mod = 119;  // modulus from the HDU 4914 statement (assumed; see NOTE)

int n, a, b, p, q;  // current test case: index n, coefficients a, b, lags p < q
int f[maxn];        // f[0..q-1]: initial terms of the sequence (mod `mod`)
int g[maxn];        // scratch: coefficients of an unreduced product

// Minimal complex number: ii = real part, ij = imaginary part.
struct Complex {
    double ii, ij;
    Complex(double re = 0.0, double im = 0.0) : ii(re), ij(im) {}
    Complex operator+(const Complex &rhs) const {
        return Complex(ii + rhs.ii, ij + rhs.ij);
    }
    Complex operator-(const Complex &rhs) const {
        return Complex(ii - rhs.ii, ij - rhs.ij);
    }
    Complex operator*(const Complex &rhs) const {
        return Complex(ii * rhs.ii - ij * rhs.ij, ii * rhs.ij + ij * rhs.ii);
    }
};

Complex a1[maxn], a2[maxn];  // FFT work buffers for the two factors

// In-place iterative radix-2 FFT over src[0..len).
//   len: transform length; must be a power of two.
//   rev: +1 for the forward transform, -1 for the inverse. The inverse divides
//        only the real parts by len, because callers read only the real part.
void fft(Complex *src, int len, int rev) {
    // Bit-reversal permutation: j is i's bit-reversed counter, advanced by
    // propagating the "carry" from the top bit downwards.
    for (int i = 1, j = 0; i < len; i++) {
        for (int k = len >> 1; k > (j ^= k); k >>= 1) {}
        if (i < j) std::swap(src[i], src[j]);
    }
    // Butterflies for sub-transforms of size 2, 4, ..., len (bottom-up
    // iteration instead of recursion).
    for (int step = 2; step <= len; step <<= 1) {
        const int half = step / 2;
        // Primitive step-th root of unity: wn^step == 1.
        const Complex wn(std::cos(2 * pi * rev / step), std::sin(2 * pi * rev / step));
        for (int j = 0; j < len; j += step) {
            Complex w(1.0, 0.0);  // wn^0
            for (int k = j; k < j + half; k++) {
                const Complex t = w * src[k + half];
                src[k + half] = src[k] - t;
                src[k] = src[k] + t;
                w = w * wn;
            }
        }
    }
    if (rev == -1) {
        for (int i = 0; i < len; i++) src[i].ii /= len;
    }
}

// src1 <- (src1 * src2) mod chi(x), coefficients taken mod `mod`.
// src1 and src2 hold q coefficients each (degree < q) and may alias.
//   len: FFT length; a power of two > 2*(q-1) so the product does not wrap.
// Reads globals q, p, a, b; uses a1, a2 and g as scratch.
void multi(int *src1, int *src2, int len) {
    for (int i = 0; i < len; i++) {
        a1[i] = Complex();
        a2[i] = Complex();
        if (i < q) {
            a1[i].ii = src1[i];
            a2[i].ii = src2[i];
        }
    }
    fft(a1, len, 1);
    fft(a2, len, 1);
    for (int i = 0; i < len; i++) a1[i] = a1[i] * a2[i];
    fft(a1, len, -1);
    // The exact coefficients are integers. FFT round-off can land on either
    // side of them, so round to nearest: truncating after adding a tiny eps
    // would turn e.g. 5.9999999 into 5.
    for (int i = 0; i < len; i++) {
        g[i] = static_cast<int>(std::llround(a1[i].ii) % mod);
    }
    // Fold terms of degree >= q back down with x^q == a*x^(q-p) + b, i.e.
    // x^i == a*x^(i-p) + b*x^(i-q). Going from the top means terms pushed to
    // a degree i-p >= q are reduced later in the same loop.
    for (int i = 2 * q - 2; i >= q; i--) {
        g[i - q] = (g[i - q] + g[i] * b) % mod;
        g[i - p] = (g[i - p] + g[i] * a) % mod;
    }
    std::memcpy(src1, g, sizeof(int) * q);
}

int tmp[maxn];  // x^(2^k) mod chi(x), squared each round
int ans[maxn];  // accumulated x^n mod chi(x)

int main() {
    // One test case per "n a b p q" group; stop at end of (well-formed) input.
    while (std::scanf("%d%d%d%d%d", &n, &a, &b, &p, &q) == 5) {
        a = (a % mod + mod) % mod;
        b = (b % mod + mod) % mod;
        // Initial terms f(0..q-1). For 0 < i < q the index i-q is negative,
        // so f(i-q) = 1; if also i < p, then f(i-p) = 1 as well.
        f[0] = 1;
        for (int i = 1; i < q; i++) {
            f[i] = (i < p ? a + b : a * f[i - p] + b) % mod;
        }
        if (n < q) {
            // f(k) = 1 for every k <= 0; avoid indexing f with a negative n.
            std::printf("%d\n", n <= 0 ? 1 % mod : f[n]);
            continue;
        }
        // Local FFT length (a global named `size` would be ambiguous with
        // std::size in C++17 under `using namespace std`).
        int len = 1;
        while (len <= (q - 1) * 2) len <<= 1;
        std::memset(tmp, 0, sizeof tmp);
        std::memset(ans, 0, sizeof ans);
        ans[0] = 1;  // x^0
        tmp[1] = 1;  // x^1, already reduced since q >= 2
        // Binary exponentiation: ans = x^n mod chi(x).
        for (int e = n; e > 0; e >>= 1) {
            if (e & 1) multi(ans, tmp, len);
            multi(tmp, tmp, len);
        }
        int res = 0;
        for (int i = 0; i < q; i++) {
            res = (res + ans[i] * f[i]) % mod;
        }
        std::printf("%d\n", res);
    }
    return 0;
}