模板如下:
扩展版本:
求解 a^k=b %p 的最小非负整数 k。a 的幂模 p 至多只有 p 个不同的取值,之后必然开始循环,所以若有解,最小的 k 一定小于 p;在 [0,p) 内找不到就是无解
***********************
gcd(a,p)=1时
设 k=m*i+v,其中 m=ceil(sqrt(p)),0<=i,v<m
则 a^v=b*(a^-m)^i %p
打表 map:for v=0~m-1 存入 (a^v,v)(同一个值只保留最小的 v)
for i=0~m-1
check b*(a^-m)^i 是否在表中,若存在对应的 v,则 k=m*i+v
****************************
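gcd(a,p)=t!=1 时(下面把方程记为 A^x=B %C)
先暴力枚举较小的 x;然后每次取 t=gcd(A,C),若 t 不整除 B 则无解,否则 B、C 同除以 t,并提出一个因子 A/t;重复 c 次直到 gcd(A,C)=1(每一步的 t 都要重新计算,下面简写为 t^c)
(A/t)^c*A^(x-c)=B/(t^c) %C/(t^c) c max
A^(x-c)=B/(t^c)*(A/t)^-c %C/(t^c),此时 A 与新的模数互质,按 gcd=1 的情形求解,答案再加上 c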
// Number of slots in the open-addressing table used by insert()/find().
// The solver stores about ceil(sqrt(modulus)) keys, so NN must stay strictly
// larger than that or probing for a free slot can never terminate.
const int NN = 99991 ;
// Hash[s][0] : exponent stored in slot s (-1 marks an empty slot)
// Hash[s][1] : key stored in slot s (a residue)
// Kept 64-bit wide so that residues above INT_MAX are stored without
// truncation; with an int table such keys would never compare equal again.
long long Hash[NN][2] ;
void insert(int id , LL vv){
LL v = vv % NN ;
while( Hash[v][0]!=-1 && Hash[v][1]!=vv){
v++ ; if(v == NN) v-=NN ;
}
if(Hash[v][0]==-1 ){
Hash[v][1] = vv ; Hash[v][0] = id ;
}
}
// Look up key vv in the global open-addressing table.
// Returns the exponent stored with the key, or -1 when the probe sequence
// reaches an empty slot without meeting the key.
int find(LL vv){
    for(LL slot = vv % NN ; Hash[slot][0] != -1 ; slot = (slot + 1 == NN) ? 0 : slot + 1){
        if(Hash[slot][1] == vv){
            return Hash[slot][0] ;
        }
    }
    return -1 ;
}
// Extended Euclidean algorithm.
// On return x and y satisfy a*x + b*y == gcd(a, b); for b == 0 the result is
// x = 1, y = 0. Written iteratively: (curX, curY) are the coefficients of the
// current dividend and (nextX, nextY) those of the current divisor.
void ex_gcd(LL a , LL b , LL& x , LL& y){
    LL curX = 1 , nextX = 0 ;
    LL curY = 0 , nextY = 1 ;
    while(b != 0){
        const LL q = a / b ;
        LL carry = a % b ;
        a = b ;
        b = carry ;
        carry = curX - q * nextX ;
        curX = nextX ;
        nextX = carry ;
        carry = curY - q * nextY ;
        curY = nextY ;
        nextY = carry ;
    }
    x = curX ;
    y = curY ;
}
// Extended baby-step giant-step: discrete logarithm for an arbitrary modulus.
//
// Returns the smallest x >= 0 with A^x == B (mod C), or -1 when no such x
// exists. A non-positive modulus is rejected with -1.
//
// A and B may be any integers; they are reduced into [0, C) first. Without
// that reduction the brute-force prefix below compares residues against an
// unreduced B and can miss small solutions.
//
// Limits (not checked here):
//   * products of two residues must fit in LL
//   * ceil(sqrt(C)) must be smaller than NN, otherwise the table fills up
//
// Uses __gcd, which is a GCC/libstdc++ extension, and memn, which is defined
// outside this block (presumably it fills Hash with -1 -- confirm).
LL baby_step(LL A, LL B , LL C){
    if(C <= 0) return -1 ;
    if(C == 1) return 0 ;            // everything is congruent mod 1, so x = 0
    A %= C ; if(A < 0) A += C ;
    B %= C ; if(B < 0) B += C ;

    LL D = 1 , d = 0 ;               // invariant after reduction: D*A^(x-d) == B (mod C)
    if(__gcd(A,C)!=1){
        // The reduction below only finds solutions with x >= d, so small
        // exponents are tried directly first. d is at most log2(C), which
        // stays below 50 for every modulus within the limits above.
        LL ans = 1 ;
        for(LL i=0;i<=50;i++){
            if(ans == B) return i ;
            ans = ans * A % C ;
        }
        // Strip common factors: each round divides B and C by gcd(A, C) and
        // moves one factor A/gcd into D.
        LL tmp ;
        while( (tmp=__gcd(A,C)) != 1 ){
            if(B % tmp) return -1 ;  // the gcd must divide B for any x >= 1
            d++ ;
            B/=tmp ;
            C/=tmp ;
            D = D*A/tmp%C ;
        }
    }

    // Baby steps: remember A^j mod C for j in [0, M).
    memn(Hash);
    LL M = ceil( sqrt(C*1.0) ) ;
    LL rr = 1 % C ;                  // 1 % C keeps the C == 1 case consistent
    for(int i=0;i<M;i++){
        insert(i, rr) ;
        rr = rr * A % C ;
    }                                // here rr == A^M mod C

    // Giant steps: D currently equals D0 * A^(i*M). Solve D*z == B (mod C)
    // through the modular inverse of D and look z up among the baby steps.
    LL jj,x,y;
    for(int i=0;i<M;i++){
        ex_gcd(D, C , x, y) ;
        jj = find( (x * B % C+C)%C ) ;
        if(jj != -1){
            return i*M+jj+d;
        }
        D = D * rr % C ;
    }
    return -1 ;
}
返回 -1 表示无解;时间复杂度约为 O(sqrt(p))
=---------------------------------------
普通版本
//POJ 2417
//baby_step giant_step
// a^x = b (mod n) n为素数,a,b < n
// 求解上式 0 <= x < n的解
#include <cmath>
#include <cstdio>
#include <cstring>
#define MOD 76543
using namespace std;

// Chained hash table backed by fixed-size pools.
//   head[k] : index of the first node in bucket k, -1 when the bucket is empty
//   nxt[i]  : index of the node following node i in its bucket, -1 at the end
//   hs[i]   : key stored in node i
//   id[i]   : value stored in node i
//   top     : index of the next unused node
// The link array is called nxt rather than next: with `using namespace std;`
// in effect, a global named next is ambiguous with std::next whenever a
// standard header pulls in the iterator utilities.
// Callers reset the table by filling head with -1 and setting top; at most
// MOD - top nodes can be inserted after a reset (not checked).
int hs[MOD], head[MOD], nxt[MOD], id[MOD], top;

// Prepend the pair (x -> y) to its bucket. Duplicate keys are allowed; the
// most recently inserted one is found first.
void insert(int x, int y)
{
    int k = x % MOD;
    if (k < 0)
        k += MOD;   // keep the bucket index valid for negative keys
    hs[top] = x;
    id[top] = y;
    nxt[top] = head[k];
    head[k] = top++;
}

// Return the value of the most recently inserted node with key x, or -1 when
// the key is absent.
int find(int x)
{
    int k = x % MOD;
    if (k < 0)
        k += MOD;
    for (int i = head[k]; i != -1; i = nxt[i])
        if (hs[i] == x)
            return id[i];
    return -1;
}
// Baby-step giant-step for a^x == b (mod n).
// Returns 0 when b == 1; otherwise the first exponent i - j found while i
// walks over the multiples of m = floor(sqrt(n)), or -1 once i has passed n
// without a match. The header comment of this program states n is prime and
// a, b < n.
// Resets the global hash table (head, top) on every call.
int BSGS(int a, int b, int n)
{
    memset(head, -1, sizeof(head));
    top = 1;
    if (b == 1)
        return 0;

    const int m = sqrt(n * 1.0);

    // Baby steps: store b * a^e for e in [0, m). Later insertions shadow
    // earlier ones, so find() reports the largest e for a repeated value.
    long long babyPow = 1;
    int e = 0;
    while (e < m)
    {
        insert(babyPow * b % n, e);
        babyPow = babyPow * a % n;
        ++e;
    }
    // babyPow now holds a^m mod n.

    // Giant steps: giant == a^exponent with exponent = m, 2m, 3m, ...
    long long giant = 1;
    long long exponent = m;
    for (;;)
    {
        giant = giant * babyPow % n;
        const int hit = find(giant);
        if (hit != -1)
            return exponent - hit;
        if (exponent > n)
            break;
        exponent += m;
    }
    return -1;
}
// Reads triples "n a b" from standard input until one cannot be read in full,
// and prints for each the exponent returned by BSGS(a, b, n), or
// "no solution" when it returns -1.
int main()
{
    int a, b, n;
    // Require all three conversions. Testing only for EOF would spin forever
    // on a non-numeric token (scanf returns 0 and consumes nothing) and would
    // run with stale values after a partial read.
    while (scanf("%d%d%d", &n, &a, &b) == 3)
    {
        const int ans = BSGS(a, b, n);
        if (ans == -1)
            printf("no solution\n");
        else
            printf("%d\n", ans);
    }
    return 0;
}