blob: 762dfe92d747839a35aaa424b4110df583dfde3f [file] [log] [blame]
// SPDX-License-Identifier: BSD-2-Clause
/*
* Copyright (c) 2014, STMicroelectronics International N.V.
*/
#include "mpa.h"
/*************************************************************
*
* HELPERS
*
*************************************************************/
/*------------------------------------------------------------
*
* These functions have ARM assembler implementations
*
*/
#if !defined(USE_ARM_ASM)
/* --------------------------------------------------------------------
 * Function:   __mpa_montgomery_mul_add
 * Calculates dest = dest + src * w, one word at a time.
 *
 * dest must be big enough to hold the result. Every digit of dest that
 * takes part in the sum is read, including digits at or above
 * dest->size, so those are expected to be zero on entry.
 * dest->size is only ever raised here, never trimmed.
 * Nothing is touched when w == 0.
 */
void __mpa_montgomery_mul_add(mpanum dest, mpanum src, mpa_word_t w)
{
#if defined(MPA_SUPPORT_DWORD_T)
	mpa_dword_t acc;
	mpa_word_t carry = 0;
	int32_t pos = 0;

	if (w == 0)
		return;

	/* multiply-accumulate across every digit of src */
	for (; pos < src->size; pos++) {
		acc = (mpa_dword_t)dest->d[pos] +
		      (mpa_dword_t)src->d[pos] * (mpa_dword_t)w +
		      (mpa_dword_t)carry;
		dest->d[pos] = (mpa_word_t)acc;
		carry = (mpa_word_t)(acc >> WORD_SIZE);
	}

	/* ripple the remaining carry into the higher digits of dest */
	for (; carry != 0; pos++) {
		acc = (mpa_dword_t)dest->d[pos] + (mpa_dword_t)carry;
		dest->d[pos] = (mpa_word_t)acc;
		carry = (mpa_word_t)(acc >> WORD_SIZE);
	}

	/* pos is now one past the highest digit written */
	if (pos > dest->size)
		dest->size = pos;
#else
#error write non-dword code for __mpa_montgomery_mul_add
#endif
}
/* --------------------------------------------------------------------
 * Function:   __mpa_montgomery_sub_ack
 * Calculates dest = dest - src, in place.
 * Assumption: both operands are non-negative and dest >= src.
 *
 * dest->size is only re-trimmed when the subtraction reached the top
 * digit of dest.
 */
void __mpa_montgomery_sub_ack(mpanum dest, mpanum src)
{
#if defined(MPA_SUPPORT_DWORD_T)
	mpa_dword_t diff;
	mpa_word_t borrow = 0;
	int32_t pos;

	/* subtract src digit by digit, carrying the borrow along */
	for (pos = 0; pos < src->size; pos++) {
		diff = (mpa_dword_t)dest->d[pos] -
		       (mpa_dword_t)src->d[pos] -
		       (mpa_dword_t)borrow;
		dest->d[pos] = (mpa_word_t)diff;
		/*
		 * After an underflow the upper half of diff is all ones;
		 * negating that word turns it into a borrow of 1.
		 */
		borrow = 0 - (mpa_word_t)(diff >> WORD_SIZE);
	}

	/*
	 * src is exhausted: keep subtracting the outstanding borrow from
	 * the higher digits of dest until it is absorbed.
	 */
	for (; borrow != 0; pos++) {
		diff = (mpa_dword_t)dest->d[pos] - (mpa_dword_t)borrow;
		dest->d[pos] = (mpa_word_t)diff;
		borrow = 0 - (mpa_word_t)(diff >> WORD_SIZE);
	}

	if (pos >= dest->size) {
		/* the top digit was touched: drop leading zero digits */
		while (dest->size > 0 && dest->d[--pos] == 0)
			dest->size--;
	}
#else
#error write non-dword code for __mpa_montgomery_sub_ack
#endif
}
#endif /* USE_ARM_ASM */
/*------------------------------------------------------------
 *
 *  __mpa_montgomery_mul
 *
 *  Word-by-word Montgomery multiplication of op1 and op2 modulo n.
 *  For every word of n the loop adds op1[idx] * op2 and u * n to dest
 *  and then shifts dest right by one word. n_inv is expected to be the
 *  value produced by mpa_compute_fmm_context for n, which makes the
 *  lowest word of dest zero before each shift.
 *
 *  NOTE:
 *    Dest need to be able to hold one more word than the size of n
 *
 */
void __mpa_montgomery_mul(mpanum dest, mpanum op1, mpanum op2, mpanum n,
			  mpa_word_t n_inv)
{
	mpa_word_t u;
	mpa_usize_t idx;

	/* set dest to zero (with all unused digits to zero as well) */
	mpa_wipe(dest);

	for (idx = 0; idx < n->size; idx++) {
		/* multiplier for n, derived from the lowest words only */
		u = (dest->d[0] +
		     __mpanum_get_word(idx, op1) *
		     __mpanum_get_word(0, op2)) * n_inv;
		__mpa_montgomery_mul_add(dest, op2,
					 __mpanum_get_word(idx, op1));
		__mpa_montgomery_mul_add(dest, n, u);

		/*
		 * Shift right one mpa_word.
		 *
		 * dest can still have size 0 here: the C implementation of
		 * __mpa_montgomery_mul_add returns without touching dest
		 * when its multiplier is zero, which is the case for both
		 * calls when dest is zero and the current word of op1 is
		 * zero. Shifting zero is a no-op, and skipping it keeps
		 * size from dropping below zero and the clearing store
		 * from landing in front of dest->d[0].
		 */
		if (dest->size > 0) {
			dest->size--;
			for (mpa_usize_t i = 0; i < dest->size; i++)
				dest->d[i] = dest->d[i + 1];
			/* set unused digit to zero. */
			dest->d[dest->size] = 0;
		}
	}

	/* check if dest >= n, if so set dest = dest - n */
	if (__mpa_abs_cmp(dest, n) >= 0)
		__mpa_montgomery_sub_ack(dest, n);
}
/*************************************************************
*
* LIB FUNCTIONS
*
*************************************************************/
/*
 * mpa_compute_fmm_context
 *
 * Computes the values used for Montgomery multiplication modulo
 * modulus, with R = 1 << (__mpanum_size(modulus) * WORD_SIZE):
 *
 *   r_modn  = R mod modulus
 *   r2_modn = R^2 mod modulus
 *   *n_inv  = word n' taken from the inverse of the lowest word of
 *             modulus modulo 1 << WORD_SIZE (see the comment in the
 *             body for the sign convention)
 *
 * Returns 0 on success. Returns -1 when the gcd of the lowest word of
 * modulus and 1 << WORD_SIZE is not 1; *n_inv is then left unwritten,
 * while r_modn and r2_modn have already been computed.
 * Two temporaries are taken from pool and released before returning.
 */
int mpa_compute_fmm_context(const mpanum modulus,
			    mpanum r_modn,
			    mpanum r2_modn,
			    mpa_word_t *n_inv,
			    mpa_scratch_mem pool)
{
	/* backing store for a one-word mpanum living on the stack */
	uint32_t lsw_store[MPA_NUMBASE_METADATA_SIZE_IN_U32 + ASIZE_TO_U32(1)];
	mpanum n_lsw = (void *)lsw_store;
	mpanum inverse;
	mpanum gcd;
	mpa_usize_t nwords;
	int status = -1;

	/* r_modn = (1 << (nwords * WORD_SIZE)) mod modulus */
	nwords = __mpanum_size(modulus);
	mpa_set_word(r_modn, 1);
	mpa_shift_left(r_modn, r_modn, nwords * WORD_SIZE);
	mpa_mod(r_modn, r_modn, modulus, pool);

	/* r2_modn = r_modn^2 mod modulus */
	mpa_mul_mod(r2_modn, r_modn, r_modn, modulus, pool);

	/* Invert the lowest word of modulus modulo 1 << WORD_SIZE */
	n_lsw->alloc = 1;
	n_lsw->size = 1;
	n_lsw->d[0] = __mpanum_lsw(modulus);

	mpa_alloc_static_temp_var(&inverse, pool);
	mpa_alloc_static_temp_var(&gcd, pool);

	mpa_extended_gcd(gcd, inverse, NULL, n_lsw,
			 (const mpanum)&Const_1_LShift_Base, pool);

	if (mpa_cmp_short(gcd, 1) == 0) {
		/*
		 * The number wanted is n' (n_inv) with
		 *
		 *      R*r' - N*n' == 1
		 *
		 * whereas the extended gcd solves
		 *
		 *      R*r' + N*n' == 1
		 *
		 * A negative n' only needs its sign dropped. A positive
		 * n' first has R subtracted so that it ends up in the
		 * right residue class.
		 */
		if (__mpanum_sign(inverse) == MPA_POS_SIGN) {
			mpa_sub(inverse, inverse,
				(const mpanum)&Const_1_LShift_Base, pool);
		}

		/* keep the lowest word only; the sign is not carried over */
		*n_inv = __mpanum_lsw(inverse);
		status = 0;
	}

	mpa_free_static_temp_var(&gcd, pool);
	mpa_free_static_temp_var(&inverse, pool);

	return status;
}
/*------------------------------------------------------------
 *
 *  mpa_montgomery_mul
 *
 *  Wrapper around __mpa_montgomery_mul. That routine needs a
 *  destination one word larger than n, so the product is computed
 *  into a temporary taken from pool and then copied into dest.
 */
void mpa_montgomery_mul(mpanum dest,
			mpanum op1,
			mpanum op2,
			mpanum n, mpa_word_t n_inv, mpa_scratch_mem pool)
{
	mpanum product;

	/* work in scratch space instead of directly in dest */
	mpa_alloc_static_temp_var(&product, pool);
	__mpa_montgomery_mul(product, op1, op2, n, n_inv);

	/* hand the result over and give the temporary back */
	mpa_copy(dest, product);
	mpa_free_static_temp_var(&product, pool);
}