| // SPDX-License-Identifier: BSD-2-Clause |
| /* |
| * Copyright (c) 2018, Linaro Limited |
| */ |
| |
| #include <crypto/crypto.h> |
| #include <kernel/panic.h> |
| #include <mbedtls/bignum.h> |
| #include <mempool.h> |
| #include <stdlib.h> |
| #include <string.h> |
| #include <tomcrypt_private.h> |
| #include <tomcrypt_mp.h> |
| #include <util.h> |
| |
| #if defined(_CFG_CORE_LTC_PAGER) |
| #include <mm/core_mmu.h> |
| #include <mm/tee_pager.h> |
| #endif |
| |
| /* Size needed for xtest to pass reliably on both ARM32 and ARM64 */ |
| #define MPI_MEMPOOL_SIZE (42 * 1024) |
| |
| /* From mbedtls/library/bignum.c */ |
| #define ciL (sizeof(mbedtls_mpi_uint)) /* chars in limb */ |
| #define biL (ciL << 3) /* bits in limb */ |
| #define BITS_TO_LIMBS(i) ((i) / biL + ((i) % biL != 0)) |
| |
| #if defined(_CFG_CORE_LTC_PAGER) |
| /* allocate pageable_zi vmem for mp scratch memory pool */ |
| static struct mempool *get_mp_scratch_memory_pool(void) |
| { |
| size_t size; |
| void *data; |
| |
| size = ROUNDUP(MPI_MEMPOOL_SIZE, SMALL_PAGE_SIZE); |
| data = tee_pager_alloc(size); |
| if (!data) |
| panic(); |
| |
| return mempool_alloc_pool(data, size, tee_pager_release_phys); |
| } |
| #else /* _CFG_CORE_LTC_PAGER */ |
| static struct mempool *get_mp_scratch_memory_pool(void) |
| { |
| static uint8_t data[MPI_MEMPOOL_SIZE] __aligned(MEMPOOL_ALIGN); |
| |
| return mempool_alloc_pool(data, sizeof(data), NULL); |
| } |
| #endif |
| |
| void init_mp_tomcrypt(void) |
| { |
| struct mempool *p = get_mp_scratch_memory_pool(); |
| |
| if (!p) |
| panic(); |
| mbedtls_mpi_mempool = p; |
| assert(!mempool_default); |
| mempool_default = p; |
| } |
| |
| static int init(void **a) |
| { |
| mbedtls_mpi *bn = mempool_alloc(mbedtls_mpi_mempool, sizeof(*bn)); |
| |
| if (!bn) |
| return CRYPT_MEM; |
| |
| mbedtls_mpi_init_mempool(bn); |
| *a = bn; |
| return CRYPT_OK; |
| } |
| |
| static int init_size(int size_bits __unused, void **a) |
| { |
| return init(a); |
| } |
| |
| static void deinit(void *a) |
| { |
| mbedtls_mpi_free((mbedtls_mpi *)a); |
| mempool_free(mbedtls_mpi_mempool, a); |
| } |
| |
| static int neg(void *a, void *b) |
| { |
| if (mbedtls_mpi_copy(b, a)) |
| return CRYPT_MEM; |
| ((mbedtls_mpi *)b)->s *= -1; |
| return CRYPT_OK; |
| } |
| |
| static int copy(void *a, void *b) |
| { |
| if (mbedtls_mpi_copy(b, a)) |
| return CRYPT_MEM; |
| return CRYPT_OK; |
| } |
| |
| static int init_copy(void **a, void *b) |
| { |
| if (init(a) != CRYPT_OK) { |
| return CRYPT_MEM; |
| } |
| return copy(b, *a); |
| } |
| |
| /* ---- trivial ---- */ |
| static int set_int(void *a, ltc_mp_digit b) |
| { |
| uint32_t b32 = b; |
| |
| if (b32 != b) |
| return CRYPT_INVALID_ARG; |
| |
| mbedtls_mpi_uint p = b32; |
| mbedtls_mpi bn = { .s = 1, .n = 1, .p = &p }; |
| |
| if (mbedtls_mpi_copy(a, &bn)) |
| return CRYPT_MEM; |
| return CRYPT_OK; |
| } |
| |
| static unsigned long get_int(void *a) |
| { |
| mbedtls_mpi *bn = a; |
| |
| if (!bn->n) |
| return 0; |
| |
| return bn->p[bn->n - 1]; |
| } |
| |
| static ltc_mp_digit get_digit(void *a, int n) |
| { |
| mbedtls_mpi *bn = a; |
| |
| COMPILE_TIME_ASSERT(sizeof(ltc_mp_digit) >= sizeof(mbedtls_mpi_uint)); |
| |
| if (n < 0 || (size_t)n >= bn->n) |
| return 0; |
| |
| return bn->p[n]; |
| } |
| |
| static int get_digit_count(void *a) |
| { |
| return ROUNDUP(mbedtls_mpi_size(a), sizeof(mbedtls_mpi_uint)) / |
| sizeof(mbedtls_mpi_uint); |
| } |
| |
| static int compare(void *a, void *b) |
| { |
| int ret = mbedtls_mpi_cmp_mpi(a, b); |
| |
| if (ret < 0) |
| return LTC_MP_LT; |
| |
| if (ret > 0) |
| return LTC_MP_GT; |
| |
| return LTC_MP_EQ; |
| } |
| |
| static int compare_d(void *a, ltc_mp_digit b) |
| { |
| unsigned long v = b; |
| unsigned int shift = 31; |
| uint32_t mask = BIT(shift) - 1; |
| mbedtls_mpi bn; |
| |
| mbedtls_mpi_init_mempool(&bn); |
| while (true) { |
| mbedtls_mpi_add_int(&bn, &bn, v & mask); |
| v >>= shift; |
| if (!v) |
| break; |
| mbedtls_mpi_shift_l(&bn, shift); |
| } |
| |
| int ret = compare(a, &bn); |
| |
| mbedtls_mpi_free(&bn); |
| |
| return ret; |
| } |
| |
| static int count_bits(void *a) |
| { |
| return mbedtls_mpi_bitlen(a); |
| } |
| |
| static int count_lsb_bits(void *a) |
| { |
| return mbedtls_mpi_lsb(a); |
| } |
| |
| |
| static int twoexpt(void *a, int n) |
| { |
| if (mbedtls_mpi_set_bit(a, n, 1)) |
| return CRYPT_MEM; |
| |
| return CRYPT_OK; |
| } |
| |
| /* ---- conversions ---- */ |
| |
| /* read ascii string */ |
| static int read_radix(void *a, const char *b, int radix) |
| { |
| int res = mbedtls_mpi_read_string(a, radix, b); |
| |
| if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) |
| return CRYPT_MEM; |
| if (res) |
| return CRYPT_ERROR; |
| |
| return CRYPT_OK; |
| } |
| |
| /* write one */ |
| static int write_radix(void *a, char *b, int radix) |
| { |
| size_t ol = SIZE_MAX; |
| int res = mbedtls_mpi_write_string(a, radix, b, ol, &ol); |
| |
| if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) |
| return CRYPT_MEM; |
| if (res) |
| return CRYPT_ERROR; |
| |
| return CRYPT_OK; |
| } |
| |
| /* get size as unsigned char string */ |
| static unsigned long unsigned_size(void *a) |
| { |
| return mbedtls_mpi_size(a); |
| } |
| |
| /* store */ |
| static int unsigned_write(void *a, unsigned char *b) |
| { |
| int res = mbedtls_mpi_write_binary(a, b, unsigned_size(a)); |
| |
| if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) |
| return CRYPT_MEM; |
| if (res) |
| return CRYPT_ERROR; |
| |
| return CRYPT_OK; |
| } |
| |
| /* read */ |
| static int unsigned_read(void *a, unsigned char *b, unsigned long len) |
| { |
| int res = mbedtls_mpi_read_binary(a, b, len); |
| |
| if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) |
| return CRYPT_MEM; |
| if (res) |
| return CRYPT_ERROR; |
| |
| return CRYPT_OK; |
| } |
| |
| /* add */ |
| static int add(void *a, void *b, void *c) |
| { |
| if (mbedtls_mpi_add_mpi(c, a, b)) |
| return CRYPT_MEM; |
| |
| return CRYPT_OK; |
| } |
| |
| static int addi(void *a, ltc_mp_digit b, void *c) |
| { |
| uint32_t b32 = b; |
| |
| if (b32 != b) |
| return CRYPT_INVALID_ARG; |
| |
| mbedtls_mpi_uint p = b32; |
| mbedtls_mpi bn = { .s = 1, .n = 1, .p = &p }; |
| |
| return add(a, &bn, c); |
| } |
| |
| /* sub */ |
| static int sub(void *a, void *b, void *c) |
| { |
| if (mbedtls_mpi_sub_mpi(c, a, b)) |
| return CRYPT_MEM; |
| |
| return CRYPT_OK; |
| } |
| |
| static int subi(void *a, ltc_mp_digit b, void *c) |
| { |
| uint32_t b32 = b; |
| |
| if (b32 != b) |
| return CRYPT_INVALID_ARG; |
| |
| mbedtls_mpi_uint p = b32; |
| mbedtls_mpi bn = { .s = 1, .n = 1, .p = &p }; |
| |
| return sub(a, &bn, c); |
| } |
| |
| /* mul */ |
| static int mul(void *a, void *b, void *c) |
| { |
| if (mbedtls_mpi_mul_mpi(c, a, b)) |
| return CRYPT_MEM; |
| |
| return CRYPT_OK; |
| } |
| |
| static int muli(void *a, ltc_mp_digit b, void *c) |
| { |
| if (b > (unsigned long) UINT32_MAX) |
| return CRYPT_INVALID_ARG; |
| |
| if (mbedtls_mpi_mul_int(c, a, b)) |
| return CRYPT_MEM; |
| |
| return CRYPT_OK; |
| } |
| |
| /* sqr */ |
| static int sqr(void *a, void *b) |
| { |
| return mul(a, a, b); |
| } |
| |
| /* div */ |
| static int divide(void *a, void *b, void *c, void *d) |
| { |
| int res = mbedtls_mpi_div_mpi(c, d, a, b); |
| |
| if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) |
| return CRYPT_MEM; |
| if (res) |
| return CRYPT_ERROR; |
| |
| return CRYPT_OK; |
| } |
| |
| static int div_2(void *a, void *b) |
| { |
| if (mbedtls_mpi_copy(b, a)) |
| return CRYPT_MEM; |
| |
| if (mbedtls_mpi_shift_r(b, 1)) |
| return CRYPT_MEM; |
| |
| return CRYPT_OK; |
| } |
| |
| /* modi */ |
| static int modi(void *a, ltc_mp_digit b, ltc_mp_digit *c) |
| { |
| mbedtls_mpi bn_b; |
| mbedtls_mpi bn_c; |
| int res = 0; |
| |
| mbedtls_mpi_init_mempool(&bn_b); |
| mbedtls_mpi_init_mempool(&bn_c); |
| |
| res = set_int(&bn_b, b); |
| if (res) |
| return res; |
| |
| res = mbedtls_mpi_mod_mpi(&bn_c, &bn_b, a); |
| if (!res) |
| *c = get_int(&bn_c); |
| |
| mbedtls_mpi_free(&bn_b); |
| mbedtls_mpi_free(&bn_c); |
| |
| if (res) |
| return CRYPT_MEM; |
| |
| return CRYPT_OK; |
| } |
| |
| /* gcd */ |
| static int gcd(void *a, void *b, void *c) |
| { |
| if (mbedtls_mpi_gcd(c, a, b)) |
| return CRYPT_MEM; |
| |
| return CRYPT_OK; |
| } |
| |
| /* lcm */ |
| static int lcm(void *a, void *b, void *c) |
| { |
| int res = CRYPT_MEM; |
| mbedtls_mpi tmp; |
| |
| mbedtls_mpi_init_mempool(&tmp); |
| if (mbedtls_mpi_mul_mpi(&tmp, a, b)) |
| goto out; |
| |
| if (mbedtls_mpi_gcd(c, a, b)) |
| goto out; |
| |
| /* We use the following equality: gcd(a, b) * lcm(a, b) = a * b */ |
| res = divide(&tmp, c, c, NULL); |
| out: |
| mbedtls_mpi_free(&tmp); |
| return res; |
| } |
| |
| static int mod(void *a, void *b, void *c) |
| { |
| int res = mbedtls_mpi_mod_mpi(c, a, b); |
| |
| if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) |
| return CRYPT_MEM; |
| if (res) |
| return CRYPT_ERROR; |
| |
| return CRYPT_OK; |
| } |
| |
| static int mulmod(void *a, void *b, void *c, void *d) |
| { |
| int res; |
| mbedtls_mpi ta; |
| mbedtls_mpi tb; |
| |
| mbedtls_mpi_init_mempool(&ta); |
| mbedtls_mpi_init_mempool(&tb); |
| |
| res = mod(a, c, &ta); |
| if (res) |
| goto out; |
| res = mod(b, c, &tb); |
| if (res) |
| goto out; |
| res = mul(&ta, &tb, d); |
| if (res) |
| goto out; |
| res = mod(d, c, d); |
| out: |
| mbedtls_mpi_free(&ta); |
| mbedtls_mpi_free(&tb); |
| return res; |
| } |
| |
| static int sqrmod(void *a, void *b, void *c) |
| { |
| return mulmod(a, a, b, c); |
| } |
| |
| /* invmod */ |
| static int invmod(void *a, void *b, void *c) |
| { |
| int res = mbedtls_mpi_inv_mod(c, a, b); |
| |
| if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) |
| return CRYPT_MEM; |
| if (res) |
| return CRYPT_ERROR; |
| |
| return CRYPT_OK; |
| } |
| |
| |
| /* setup */ |
| static int montgomery_setup(void *a, void **b) |
| { |
| *b = malloc(sizeof(mbedtls_mpi_uint)); |
| if (!*b) |
| return CRYPT_MEM; |
| |
| mbedtls_mpi_montg_init(*b, a); |
| |
| return CRYPT_OK; |
| } |
| |
| /* get normalization value */ |
| static int montgomery_normalization(void *a, void *b) |
| { |
| size_t c = ROUNDUP(mbedtls_mpi_size(b), sizeof(mbedtls_mpi_uint)) * 8; |
| |
| if (mbedtls_mpi_lset(a, 1)) |
| return CRYPT_MEM; |
| if (mbedtls_mpi_shift_l(a, c)) |
| return CRYPT_MEM; |
| if (mbedtls_mpi_mod_mpi(a, a, b)) |
| return CRYPT_MEM; |
| |
| return CRYPT_OK; |
| } |
| |
| /* reduce */ |
| static int montgomery_reduce(void *a, void *b, void *c) |
| { |
| mbedtls_mpi A; |
| mbedtls_mpi *N = b; |
| mbedtls_mpi_uint *mm = c; |
| mbedtls_mpi T; |
| int ret = CRYPT_MEM; |
| |
| mbedtls_mpi_init_mempool(&T); |
| mbedtls_mpi_init_mempool(&A); |
| |
| if (mbedtls_mpi_grow(&T, (N->n + 1) * 2)) |
| goto out; |
| |
| if (mbedtls_mpi_cmp_mpi(a, N) > 0) { |
| if (mbedtls_mpi_mod_mpi(&A, a, N)) |
| goto out; |
| } else { |
| if (mbedtls_mpi_copy(&A, a)) |
| goto out; |
| } |
| |
| if (mbedtls_mpi_grow(&A, N->n + 1)) |
| goto out; |
| |
| if (mbedtls_mpi_montred(&A, N, *mm, &T)) |
| goto out; |
| |
| if (mbedtls_mpi_copy(a, &A)) |
| goto out; |
| |
| ret = CRYPT_OK; |
| out: |
| mbedtls_mpi_free(&A); |
| mbedtls_mpi_free(&T); |
| |
| return ret; |
| } |
| |
| /* clean up */ |
| static void montgomery_deinit(void *a) |
| { |
| free(a); |
| } |
| |
| /* |
| * This function calculates: |
| * d = a^b mod c |
| * |
| * @a: base |
| * @b: exponent |
| * @c: modulus |
| * @d: destination |
| */ |
| static int exptmod(void *a, void *b, void *c, void *d) |
| { |
| int res; |
| |
| if (d == a || d == b || d == c) { |
| mbedtls_mpi dest; |
| |
| mbedtls_mpi_init_mempool(&dest); |
| res = mbedtls_mpi_exp_mod(&dest, a, b, c, NULL); |
| if (!res) |
| res = mbedtls_mpi_copy(d, &dest); |
| mbedtls_mpi_free(&dest); |
| } else { |
| res = mbedtls_mpi_exp_mod(d, a, b, c, NULL); |
| } |
| |
| if (res) |
| return CRYPT_MEM; |
| else |
| return CRYPT_OK; |
| } |
| |
| static int rng_read(void *ignored __unused, unsigned char *buf, size_t blen) |
| { |
| if (crypto_rng_read(buf, blen)) |
| return MBEDTLS_ERR_MPI_FILE_IO_ERROR; |
| return 0; |
| } |
| |
| static int isprime(void *a, int b __unused, int *c) |
| { |
| int res = mbedtls_mpi_is_prime(a, rng_read, NULL); |
| |
| if (res == MBEDTLS_ERR_MPI_ALLOC_FAILED) |
| return CRYPT_MEM; |
| |
| if (res) |
| *c = LTC_MP_NO; |
| else |
| *c = LTC_MP_YES; |
| |
| return CRYPT_OK; |
| } |
| |
| static int mpa_rand(void *a, int size) |
| { |
| if (mbedtls_mpi_fill_random(a, size, rng_read, NULL)) |
| return CRYPT_MEM; |
| |
| return CRYPT_OK; |
| } |
| |
| ltc_math_descriptor ltc_mp = { |
| .name = "MPI", |
| .bits_per_digit = sizeof(mbedtls_mpi_uint) * 8, |
| |
| .init = &init, |
| .init_size = &init_size, |
| .init_copy = &init_copy, |
| .deinit = &deinit, |
| |
| .neg = &neg, |
| .copy = ©, |
| |
| .set_int = &set_int, |
| .get_int = &get_int, |
| .get_digit = &get_digit, |
| .get_digit_count = &get_digit_count, |
| .compare = &compare, |
| .compare_d = &compare_d, |
| .count_bits = &count_bits, |
| .count_lsb_bits = &count_lsb_bits, |
| .twoexpt = &twoexpt, |
| |
| .read_radix = &read_radix, |
| .write_radix = &write_radix, |
| .unsigned_size = &unsigned_size, |
| .unsigned_write = &unsigned_write, |
| .unsigned_read = &unsigned_read, |
| |
| .add = &add, |
| .addi = &addi, |
| .sub = &sub, |
| .subi = &subi, |
| .mul = &mul, |
| .muli = &muli, |
| .sqr = &sqr, |
| .mpdiv = ÷, |
| .div_2 = &div_2, |
| .modi = &modi, |
| .gcd = &gcd, |
| .lcm = &lcm, |
| |
| .mulmod = &mulmod, |
| .sqrmod = &sqrmod, |
| .invmod = &invmod, |
| |
| .montgomery_setup = &montgomery_setup, |
| .montgomery_normalization = &montgomery_normalization, |
| .montgomery_reduce = &montgomery_reduce, |
| .montgomery_deinit = &montgomery_deinit, |
| |
| .exptmod = &exptmod, |
| .isprime = &isprime, |
| |
| #ifdef LTC_MECC |
| #ifdef LTC_MECC_FP |
| .ecc_ptmul = <c_ecc_fp_mulmod, |
| #else |
| .ecc_ptmul = <c_ecc_mulmod, |
| #endif /* LTC_MECC_FP */ |
| .ecc_ptadd = <c_ecc_projective_add_point, |
| .ecc_ptdbl = <c_ecc_projective_dbl_point, |
| .ecc_map = <c_ecc_map, |
| #ifdef LTC_ECC_SHAMIR |
| #ifdef LTC_MECC_FP |
| .ecc_mul2add = <c_ecc_fp_mul2add, |
| #else |
| .ecc_mul2add = <c_ecc_mul2add, |
| #endif /* LTC_MECC_FP */ |
| #endif /* LTC_ECC_SHAMIR */ |
| #endif /* LTC_MECC */ |
| |
| #ifdef LTC_MRSA |
| .rsa_keygen = &rsa_make_key, |
| .rsa_me = &rsa_exptmod, |
| #endif |
| .rand = &mpa_rand, |
| |
| }; |
| |
| size_t crypto_bignum_num_bytes(struct bignum *a) |
| { |
| return mbedtls_mpi_size((mbedtls_mpi *)a); |
| } |
| |
| size_t crypto_bignum_num_bits(struct bignum *a) |
| { |
| return mbedtls_mpi_bitlen((mbedtls_mpi *)a); |
| } |
| |
| int32_t crypto_bignum_compare(struct bignum *a, struct bignum *b) |
| { |
| return mbedtls_mpi_cmp_mpi((mbedtls_mpi *)a, (mbedtls_mpi *)b); |
| } |
| |
| void crypto_bignum_bn2bin(const struct bignum *from, uint8_t *to) |
| { |
| mbedtls_mpi_write_binary((const mbedtls_mpi *)from, (void *)to, |
| mbedtls_mpi_size((const mbedtls_mpi *)from)); |
| } |
| |
| TEE_Result crypto_bignum_bin2bn(const uint8_t *from, size_t fromsize, |
| struct bignum *to) |
| { |
| if (mbedtls_mpi_read_binary((mbedtls_mpi *)to, (const void *)from, |
| fromsize)) |
| return TEE_ERROR_BAD_PARAMETERS; |
| return TEE_SUCCESS; |
| } |
| |
| void crypto_bignum_copy(struct bignum *to, const struct bignum *from) |
| { |
| mbedtls_mpi_copy((mbedtls_mpi *)to, (const mbedtls_mpi *)from); |
| } |
| |
| struct bignum *crypto_bignum_allocate(size_t size_bits) |
| { |
| mbedtls_mpi *bn = malloc(sizeof(*bn)); |
| |
| if (!bn) |
| return NULL; |
| |
| mbedtls_mpi_init(bn); |
| if (mbedtls_mpi_grow(bn, BITS_TO_LIMBS(size_bits))) { |
| free(bn); |
| return NULL; |
| } |
| |
| return (struct bignum *)bn; |
| } |
| |
| void crypto_bignum_free(struct bignum *s) |
| { |
| mbedtls_mpi_free((mbedtls_mpi *)s); |
| free(s); |
| } |
| |
| void crypto_bignum_clear(struct bignum *s) |
| { |
| mbedtls_mpi *bn = (mbedtls_mpi *)s; |
| |
| bn->s = 1; |
| if (bn->p) |
| memset(bn->p, 0, sizeof(*bn->p) * bn->n); |
| } |