/*
 * 超大整数运算算法——为RSA加密算法提供运算工具
 * (Arbitrary-precision integer arithmetic: arithmetic primitives for RSA encryption)
 *
 * 时间:2021-04-10 23:54:46
 */

/* program: Large integer operations
 * Made by:  Daiyyr
 * date:  2013/07/09

* This software is licensed under the terms of the GNU General Public
 * License version 2, as published by the Free Software Foundation, and
 * may be copied, distributed, and modified under those terms.
 */

//开源代码,引用请务必遵守GNU规则

//未完待续,尚缺 乘方、除法、取模运算

#include <string.h>
#include <stdlib.h>
#include <stdio.h>
#include <math.h>

//#define DEBUG

/*
 * Debug tracing.
 *
 * LOG() prints a printf-style message followed by a newline when DEBUG is
 * defined, and compiles to an empty statement otherwise (its arguments are
 * not evaluated).  DBG() prints the current function name and line.
 * LOGE() always prints.
 *
 * The macros use the standard C99 __VA_ARGS__ form, so a bare format string
 * such as LOG("carry2") works without the GNU `args...` extension.  Every
 * expansion is a single do/while(0) statement, so the macros are safe inside
 * an unbraced if/else.
 */
#ifdef DEBUG
    #define LOG(...)  do { printf(__VA_ARGS__); printf("\n"); } while (0)
    #define DBG()     do { printf("[%s]:%d => \n", __func__, __LINE__); } while (0)
#else
    #define LOG(...)  do { } while (0)
    #define DBG()     do { } while (0)
#endif

#define LOGE(...)  do { printf(__VA_ARGS__); printf("\n"); } while (0)

#define LENGTH 1000   /* maximum number of decimal digits of a long number */

/*
 * Signed arbitrary-length decimal integer stored as packed BCD.
 *
 * Each byte of `array` holds two decimal digits: the high nibble is the more
 * significant digit and the low nibble the less significant one
 * (charnumber2string prints array[i]>>4 before array[i]&15).
 * array[0] holds the most significant digit pair.
 */
struct charnumber{
    char sign;          //0 for positive; 1 (callers test it as non-zero) for negative
    int length;         //number of bytes in array, i.e. two decimal digits per byte; interger2charnumber(0) yields length 0
    unsigned char* array;//low index of this array is for high digit of the number
    unsigned char* pointerforfree; //start of the malloc'd digit block; `array` may be advanced past it (e.g. plusCharNumber does sum->array++), so free this pointer, not `array`
};

/*
 * Decode one packed-BCD byte (two decimal digits) into its numeric value,
 * e.g. 0x42 -> 42.  A byte below 10 has a zero high nibble, so it comes
 * back unchanged.  Nibbles above 9 are not rejected: 0xFF decodes as 165.
 */
int char2interger(unsigned char onebyte){
    int tens = onebyte >> 4;
    int units = onebyte & 0x0F;

    return tens * 10 + units;
}

/*
 * Convert a native signed long into a newly allocated charnumber.
 *
 * The digits are stored as packed BCD with the most significant pair in
 * array[0] (e.g. 12345 -> {0x01, 0x23, 0x45}).  Zero is represented with
 * length 0, i.e. no digit bytes.
 *
 * Returns NULL if an allocation fails.  Otherwise the caller owns the
 * result and must free(result->pointerforfree) and then free(result).
 */
struct charnumber* interger2charnumber(signed long int number){
    struct charnumber *target;
    unsigned long magnitude, rest;
    int nbytes = 0, i;

    target = malloc(sizeof *target);
    if (target == NULL)
        return NULL;
    target->sign = number < 0;
    /* Negate in unsigned arithmetic: -LONG_MIN overflows a signed long. */
    magnitude = number < 0 ? 0UL - (unsigned long)number : (unsigned long)number;

    /* One byte per pair of decimal digits. */
    for (rest = magnitude; rest != 0; rest /= 100)
        nbytes++;
    target->length = nbytes;
    /* malloc(0) may legitimately return NULL, so always request at least one byte. */
    target->array = malloc(nbytes > 0 ? (size_t)nbytes : 1);
    if (target->array == NULL) {
        free(target);
        return NULL;
    }
    target->pointerforfree = target->array;

    /* Fill from the least significant end, so no reversal pass is needed. */
    for (i = nbytes - 1; i >= 0; i--) {
        unsigned pair = (unsigned)(magnitude % 100);
        target->array[i] = (unsigned char)((pair / 10) << 4 | pair % 10);
        magnitude /= 100;
    }
    LOG("~~length:%d", target->length);
    return target;
}
 
/*
 * Render a charnumber as a NUL-terminated decimal string with an explicit
 * sign, e.g. "+123" or "-45".  Leading zero digits are dropped, and a zero
 * magnitude is rendered as "+0".
 *
 * Returns NULL if the allocation fails.  Otherwise the result is a fresh
 * malloc'd block, and the caller may free() the returned pointer directly.
 */
char* charnumber2string(struct charnumber* charnumber){
    size_t ndigits = 2 * (size_t)charnumber->length;
    /* sign + at most ndigits digits + NUL; the extra byte also covers "+0" when ndigits == 0. */
    char *text = malloc(ndigits + 3);
    size_t pos = 1, k;

    if (text == NULL)
        return NULL;
    LOG("string signed:%d", charnumber->sign);
    text[0] = charnumber->sign ? '-' : '+';
    for (k = 0; k < ndigits; k++) {
        unsigned char byte = charnumber->array[k / 2];
        int digit = (k % 2 == 0) ? byte >> 4 : byte & 0x0F;

        if (pos == 1 && digit == 0)
            continue;           /* still inside the leading zeros */
        text[pos++] = (char)('0' + digit);
    }
    if (pos == 1) {             /* every digit was zero */
        text[0] = '+';
        text[pos++] = '0';
    }
    text[pos] = '\0';
    return text;
}

/*
 * Parse an optionally signed decimal string ("123", "+123", "-123") into a
 * newly allocated packed-BCD charnumber.
 *
 * Leading zeros are dropped ("007" -> 7).  An all-zero input becomes the
 * zero value (length 0, sign 0).  The input string is not modified.
 *
 * Returns NULL, after printing an error, when the input is NULL, empty,
 * a bare sign, or contains a non-digit.  Also returns NULL if an allocation
 * fails.  The caller owns the result.
 */
struct charnumber* string2charnumber(char* string){
    struct charnumber *number;
    const char *digits;
    size_t ndigits, nbytes, k;
    char sign = 0;

    if(string == NULL || !(string[0]=='-' || string[0]=='+' || (string[0]>='0' && string[0]<='9'))){
        printf("error: in string2charnumber, invalid input number!\n");
        return NULL;
    }
    digits = string;
    if (*digits == '-' || *digits == '+') {
        sign = (*digits == '-');
        digits++;               /* a '+' is skipped too, never encoded as a digit */
    }
    ndigits = strlen(digits);
    if (ndigits == 0) {         /* a bare sign has no digits */
        printf("error! in string2charnumber, invalid input number!\n");
        return NULL;
    }
    for (k = 0; k < ndigits; k++) {
        if (digits[k] < '0' || digits[k] > '9') {
            printf("error! in string2charnumber, invalid input number!\n");
            return NULL;
        }
    }
    /* Normalize so the byte length reflects the magnitude (arithmetic compares lengths). */
    while (ndigits > 0 && *digits == '0') {
        digits++;
        ndigits--;
    }
    nbytes = (ndigits + 1) / 2;

    number = malloc(sizeof *number);
    if (number == NULL)
        return NULL;
    /* Zero-filled so nibbles can be OR-ed in; at least one byte since malloc(0) may return NULL. */
    number->array = calloc(nbytes > 0 ? nbytes : 1, 1);
    if (number->array == NULL) {
        free(number);
        return NULL;
    }
    number->pointerforfree = number->array;
    number->length = (int)nbytes;
    number->sign = nbytes > 0 ? sign : 0;

    /* k counts digits from the least significant end: even k -> low nibble, odd k -> high nibble. */
    for (k = 0; k < ndigits; k++) {
        unsigned char d = (unsigned char)(digits[ndigits - 1 - k] - '0');
        number->array[nbytes - 1 - k / 2] |= (k % 2) ? (unsigned char)(d << 4) : d;
    }
    return number;
}

struct charnumber* minusCharNumber(struct charnumber* charnumber1, struct charnumber* charnumber2);

struct charnumber* plusCharNumber(struct charnumber* charnumber1, struct charnumber* charnumber2){
    struct charnumber *temp1, *temp2, *sum;

if(!!charnumber1->sign != !!charnumber2->sign){
        temp1 = malloc(sizeof(struct charnumber));
        temp2 = malloc(sizeof(struct charnumber));
        memcpy(temp1, charnumber1, sizeof(struct charnumber));
        memcpy(temp2, charnumber2, sizeof(struct charnumber));
        temp2->sign = !temp2->sign;
        sum = minusCharNumber(temp1, temp2);
        free(temp1->pointerforfree), free(temp2->pointerforfree), free(temp1), free(temp2);
        return sum;
    }
    int i, carry = 0;

if(charnumber1->length >= charnumber2->length){
        temp1 = charnumber1;
        temp2 = charnumber2;
    }
    else{
        temp1 = charnumber2;
        temp2 = charnumber1;
    }
    sum = malloc(sizeof(struct charnumber));
    sum->sign = temp1->sign;
    LOG("signed:%d", sum->sign);
    sum->array = malloc(temp1->length+1);
    sum->pointerforfree = sum->array;
    for(i=temp1->length; i>temp1->length-temp2->length; i--){       //遍历较小数的位数
        sum->array[i]=(carry + temp1->array[i-1]&15) + (temp2->array[i-(temp1->length+1-temp2->length)]&15);
        carry = 0;
        if(sum->array[i]>9){
            sum->array[i] = 6 + sum->array[i];
        }
        sum->array[i]=sum->array[i] + (temp1->array[i-1]>>4<<4) + (temp2->array[i-(temp1->length+1-temp2->length)]>>4<<4);
        if(sum->array[i]>159)
            sum->array[i] -= 160, carry = 1;
        LOG("i: %d, 0x%x, 1:0x%x, 2:0x%x", i, sum->array[i], temp1->array[i-1], temp2->array[i-(temp1->length+1-temp2->length)]);
    }
    while(carry && i>=1){   //在遍历完较小数的位数后,处理可能的进位
        sum->array[i] = temp1->array[i-1] + 1;
        if((sum->array[i]&15) == 10){
            LOG("carry1.0, %x", sum->array[i]);
            sum->array[i]=(sum->array[i]&240) + 16;
            LOG("carry1.1, %x", sum->array[i]);
            if(sum->array[i] == 160){
                sum->array[i]=0;
                carry=1;
                LOG("carry2");
            }
            else
                carry=0;
        }
        else
            carry=0;
        i--;
    }
    if(carry){
        sum->array[0]=1;
        sum->length = temp1->length+1;
    }
    else{
        sum->array++;
        memcpy(sum->array, temp1->array, i);
        sum->length = temp1->length;
    }
    return sum;
}

struct charnumber* multiplyCharNumberAndInt(struct charnumber* charnumber, int n){
    long i;
    struct charnumber *result, *temp;
    result = interger2charnumber(0);
    for(i=0; i<n; i++){
        temp = result;
        result = plusCharNumber(result, charnumber);
        free(temp->pointerforfree);
        free(temp);
    }
    
    if(charnumber->sign ^ (n<0) )
        result->sign = 1;
    else
        result->sign = 0;
    
    return result;
}

/*
 * Integer power x**y by repeated multiplication.
 * Returns 1 when y <= 0.  There is no overflow check: a result that does
 * not fit in an int is undefined behavior, so keep x**y within int range.
 */
int mypow(int x, int y){
    int result = 1;
    int remaining;

    for (remaining = y; remaining > 0; remaining--)
        result *= x;
    return result;
}

struct charnumber* multiplyCharNumber(struct charnumber* charnumber1, struct charnumber* charnumber2){
    int i, orisign1, orisign2;
    struct charnumber* result, *temp1, *temp2, *tempforfree1, *tempforfree2, *tempforfree3;
    
    orisign1 = charnumber1->sign;
    orisign2 = charnumber2->sign;
    charnumber1->sign = 0;
    charnumber2->sign = 0;
    
    if(charnumber1->length >= charnumber2->length){
        temp1 = charnumber1;
        temp2 = charnumber2;
    }
    else{
        temp1 = charnumber2;
        temp2 = charnumber1;
    }
    result = interger2charnumber(0);
    LOG("t1length:%d, t2length:%d", temp1->length, temp2->length);
    for(i=0; i<temp2->length; i++){
        tempforfree1 = multiplyCharNumberAndInt(temp1, temp2->array[i]&15);
        tempforfree2 = multiplyCharNumberAndInt(tempforfree1, mypow(10, (temp2->length-i)*2-2));
        tempforfree3 = result;
        result = plusCharNumber(tempforfree3, tempforfree2);
        free(tempforfree1->pointerforfree), free(tempforfree2->pointerforfree), free(tempforfree3->pointerforfree), free(tempforfree1), free(tempforfree2), free(tempforfree3);
        LOG("multip1 i=%d, t1=0x%x, t2=0x%x, t3=0x%x, res=0x%x", i, tempforfree1->array[i], tempforfree2->array[i], tempforfree3->array[i], result->array[i]);
        tempforfree1 = multiplyCharNumberAndInt(temp1, temp2->array[i]>>4);
        tempforfree2 = multiplyCharNumberAndInt(tempforfree1, mypow(10, (temp2->length-i)*2-1));
        LOG("tempforfree1:%s, tempforfree2:%s",charnumber2string(tempforfree1), charnumber2string(tempforfree2));
        tempforfree3 = result;
        result = plusCharNumber(tempforfree3, tempforfree2);
        free(tempforfree1->pointerforfree), free(tempforfree2->pointerforfree), free(tempforfree3->pointerforfree), free(tempforfree1), free(tempforfree2), free(tempforfree3);
    }
    charnumber1->sign = orisign1;
    charnumber2->sign = orisign2;
    
    if(charnumber1->sign ^ charnumber2->sign ){
        result->sign = 1;
    }
    else
        result->sign = 0;
    return result;
}

struct charnumber* minusCharNumber(struct charnumber* charnumber1, struct charnumber* charnumber2){    
    struct charnumber *temp1, *temp2, *difference;
    
    if(!!charnumber1->sign != !!charnumber2->sign){
        temp1 = malloc(sizeof(struct charnumber));
        temp2 = malloc(sizeof(struct charnumber));
        memcpy(temp1, charnumber1, sizeof(struct charnumber));
        memcpy(temp2, charnumber2, sizeof(struct charnumber));
        temp2->sign = !temp2->sign;
        difference = plusCharNumber(temp1, temp2);
        free(temp1->pointerforfree), free(temp2->pointerforfree), free(temp1), free(temp2);
        return difference;
    }

int i, borrow = 0;
    difference = malloc(sizeof(struct charnumber));
    if(charnumber1->length > charnumber2->length){
        temp1 = charnumber1;
        temp2 = charnumber2;
        difference->sign = temp1->sign;
    }
    else if(charnumber1->length < charnumber2->length){
        temp1 = charnumber2;
        temp2 = charnumber1;
        difference->sign = -temp1->sign;
    }
    else{
        for(i=0; i<charnumber1->length; i++){
            if(charnumber1->array[i] > charnumber2->array[i]){
                temp1 = charnumber1;
                temp2 = charnumber2;
                difference->sign = temp1->sign;
                break;
            }
            else if(charnumber1->array[i] < charnumber2->array[i]){
                temp1 = charnumber2;
                temp2 = charnumber1;
                difference->sign = -temp1->sign;
                break;
            }
        }
        if(i == charnumber1->length)
            return interger2charnumber(0);
    }
    difference->array = malloc(temp1->length);
    difference->pointerforfree = difference->array;
    for(i=temp1->length-1; i>=temp1->length-temp2->length; i--){       //遍历较小数的位数
        difference->array[i]=(borrow + temp1->array[i]&15) - (temp2->array[i-(temp1->length-temp2->length)]&15);
        if(difference->array[i]>10){
            difference->array[i] = difference->array[i]-6;
        }
        LOG("minus:i: %d, 0x%x, 1:0x%x, 2:0x%x", i, difference->array[i], temp1->array[i], temp2->array[i-(temp1->length-temp2->length)]);
        borrow = 0;

difference->array[i]=difference->array[i] + (temp1->array[i]>>4<<4) - (temp2->array[i-(temp1->length-temp2->length)]>>4<<4);
        if(difference->array[i]>160)
            difference->array[i] += 160, borrow = -1;
        LOG("minus:i: %d, 0x%x, 1:0x%x, 2:0x%x", i, difference->array[i], temp1->array[i], temp2->array[i-(temp1->length-temp2->length)]);
    }
    if(!borrow && temp1->length - temp2->length == 1)
        memcpy(difference->array, temp1->array, i+1);
    while(borrow && i>=0){   //在遍历完较小数的位数后,处理可能的借位
        difference->array[i] = temp1->array[i] - 1;
        if((difference->array[i]&15) == 15){
            LOG("carry1.0, %x", difference->array[i]);
            difference->array[i]=(difference->array[i]&240) - 7;
            LOG("carry1.1, %x", difference->array[i]);
            if(difference->array[i] == 249){//0xf9
                difference->array[i]=153;//0x99
                borrow=-1;
                LOG("carry2");
            }
            else
                borrow=0;
        }
        else
            borrow=0;
        i--;
    }
    LOG("BEFORE WHILE, i=%d", i);
    if(i>0){
        LOG("memcpy");
        memcpy(difference->array, temp1->array, i+1);
        difference->length = temp1->length;
    }    
    else{
        i=0;
        while(difference->array[0]==0){
            difference->array++;
            i++;
        }
        difference->length = temp1->length-i;
    }
    return difference;
}

//struct charnumber* powerCharNumber(struct charnumber* charnumber1, struct charnumber* pow){
//    struct charnumber* char1, sum=1;
//    int i;

//    char1 = malloc(struct charnumber);
//    char1->length = charnumber1->length;
//    char1->sign = charnumber1->sign;
//    char1->array = malloc(char1->length);
//    char1->pointerforfree = char1->array;
//    memcpy(char1->array, charnumber1->array, char1->length);
//    while(char1->length != 0){
//        if(char1->array[char1->length-1]&15 == 0){
//            if(char1->array[char1->length-1] == 0){
//                char1->array[char1->length-2]
//            }
//        }
//        sum=multiplyCharNumber(sum, charnumber1);
//        char1->length--;
//    }
//    return sum;
//}