题意:
给出a、p、d、m 求a^x=d(mod p) 0<=x<=m 的解的个数
题解:
今天一整天的时间大部分都在调这题Orz BSGS什么的还是太不熟了
我们可以用BSGS拓展版求出最小解x 以及循环节开始的位置start start就是p被除(a,p)的次数
如果不存在解x 或x>m则输出0
如果x<start 那么输出1 因为如果解在循环节之前 在循环节中就不可能有x的解
如果start<=x<=m 答案就是x循环出现的次数=(m-x)/lon (lon是循环节长度)
这有些地方要注意的
因为题目的m<=2^63-1 而0<=x<=m 所以可能导致答案爆long long 要开unsigned long long
还有该mod的地方要mod啊 因为一个mod 我段异常了一个早上
然后因为对BSGS不熟 WA了一下午TAT 按AK大神的打法重打一遍才过
代码:
#include <cstdio>
#include <cstring>
#include <cmath>
typedef long long ll;
typedef unsigned long long ull;
const ll mo=,N=;
ll a,p,d,m,sum=,sq,hash[mo],hat[mo],pri[N],save[N],bo[N+],tot,add;
ll gcd(ll a,ll b){ return b ? gcd(b,a%b) : a; }
ll extgcd(ll &x,ll &y,ll a,ll b){
if (!b){
x=,y=;
return a;
}else{
ll res=extgcd(x,y,b,a%b),t=x;
x=y,y=t-a/b*y;
return res;
}
}
void push(ll x,ll y){
ll t=x%mo;
while (hash[t]>=){
if (hash[t]==x) return;
t=(t+)%mo;
}
hash[t]=x,hat[t]=y;
}
void makehash(){
for (ll i=,x=%p;i<sq;i++,x=x*a%p) push(x,i);
}
ll ha(ll x){
ll t=x%mo;
while (hash[t]>=){
if (hash[t]==x) return hat[t];
t=(t+)%mo;
}
return -;
}
ll mi(ll a,ll b){
ll res=;
for (;b;b>>=){
if (b&) res=res*a%p;
a=a*a%p;
}
return res;
}
ll makex(){
ll x,y,a1,b1=p,res;
for (ll i=;i<=sq;i++){
a1=mi(a,sq*i)*sum%p;
ll gc=extgcd(x,y,a1,b1),xx=b1/gc;
if (d%gc) continue;
x=(d/gc*x%xx+xx)%xx;
res=ha(x);
if (res>= && res<=p)
return i*sq+res+add;
}
return -;
}
void makepri(){
for (ll i=;i<=N;i++){
if (!bo[i]) pri[++pri[]]=i;
for (ll j=;j<=pri[] && i*pri[j]<=N;j++){
bo[i*pri[j]]=;
if (!(i%pri[j])) break;
}
}
}
bool check(ll t){ return mi(a,t)==; }
void makesave(ll x){
save[]=;
for (ll i=;i<=pri[] && x>;++i)
if (!(x%pri[i])){
save[++save[]]=pri[i];
while (!(x%pri[i])) x/=pri[i];
}
if (x>) save[++save[]]=x;
}
ll phi(ll t){
makesave(t);
ll res=t;
for (ll i=;i<=save[];i++)
res=res/save[i]*(save[i]-);
return res;
}
ll makelon(){
ll t=phi(p);
makesave(t);
for (ll i=;i<=save[];i++)
while (check(t/save[i]) && !(t%save[i])) t/=save[i];
return t;
}
int bsgs(){
int addx=,sd=d,sp=p;
for (int gc=gcd(a,p);gc>;gc=gcd(a,p)){
if (addx==sd) return add++;
if (d%gc) return -;
addx=addx*a%sp;
d/=gc,p/=gc;
sum=a/gc*sum%p;
++add;
}
sq=(ll)sqrt(p);
a%=p;
makehash();
return makex();
}
void work(){
ull res=;
add=,sum=;
memset(hash,-,sizeof(hash));
ll x=bsgs();
if (x==- || x>m){
puts("");
return ;
}else if (x<add){
puts("");
return;
}
ll lon=makelon();
res=(m-x)/lon+;
printf("%I64u\n",res);
}
int main(){
makepri();
while (scanf("%I64d%I64d%I64d%I64d",&a,&p,&d,&m)!=EOF){
++tot;
//printf("%d:",tot);
a%=p;
work();
}
}