diff --git a/components/mbedtls/port/esp_bignum.c b/components/mbedtls/port/esp_bignum.c index 4dac2b510..644277ebe 100644 --- a/components/mbedtls/port/esp_bignum.c +++ b/components/mbedtls/port/esp_bignum.c @@ -361,6 +361,18 @@ int mbedtls_mpi_exp_mod( mbedtls_mpi* Z, const mbedtls_mpi* X, const mbedtls_mpi mbedtls_mpi *Rinv; /* points to _Rinv (if not NULL) othwerwise &RR_new */ mbedtls_mpi_uint Mprime; + if (mbedtls_mpi_cmp_int(M, 0) <= 0 || (M->p[0] & 1) == 0) { + return MBEDTLS_ERR_MPI_BAD_INPUT_DATA; + } + + if (mbedtls_mpi_cmp_int(Y, 0) < 0) { + return MBEDTLS_ERR_MPI_BAD_INPUT_DATA; + } + + if (mbedtls_mpi_cmp_int(Y, 0) == 0) { + return mbedtls_mpi_lset(Z, 1); + } + if (hw_words * 32 > 4096) { return MBEDTLS_ERR_MPI_NOT_ACCEPTABLE; } @@ -394,13 +406,24 @@ int mbedtls_mpi_exp_mod( mbedtls_mpi* Z, const mbedtls_mpi* X, const mbedtls_mpi start_op(RSA_START_MODEXP_REG); /* X ^ Y may actually be shorter than M, but unlikely when used for crypto */ - MBEDTLS_MPI_CHK( mbedtls_mpi_grow(Z, m_words) ); + if ((ret = mbedtls_mpi_grow(Z, m_words)) != 0) { + esp_mpi_release_hardware(); + goto cleanup; + } wait_op_complete(RSA_START_MODEXP_REG); mem_block_to_mpi(Z, RSA_MEM_Z_BLOCK_BASE, m_words); esp_mpi_release_hardware(); + // Compensate for negative X + if (X->s == -1 && (Y->p[0] & 1) != 0) { + Z->s = -1; + MBEDTLS_MPI_CHK(mbedtls_mpi_add_mpi(Z, M, Z)); + } else { + Z->s = 1; + } + cleanup: if (_Rinv == NULL) { mbedtls_mpi_free(&Rinv_new); diff --git a/components/mbedtls/test/test_mbedtls_mpi.c b/components/mbedtls/test/test_mbedtls_mpi.c index ade0fdb6b..084fe3484 100644 --- a/components/mbedtls/test/test_mbedtls_mpi.c +++ b/components/mbedtls/test/test_mbedtls_mpi.c @@ -11,6 +11,7 @@ #include "unity.h" #include "sdkconfig.h" +#define MBEDTLS_OK 0 /* Debugging function to print an MPI number to stdout. Happens to print output that can be copy-pasted directly into a Python shell. @@ -116,11 +117,14 @@ TEST_CASE("test MPI multiplication", "[bignum]") 4096); } -static void test_bignum_modexp(const char *z_str, const char *x_str, const char *y_str, const char *m_str) +static bool test_bignum_modexp(const char *z_str, const char *x_str, const char *y_str, const char *m_str, int ret_error) { mbedtls_mpi Z, X, Y, M; char z_buf[400] = { 0 }; size_t z_buf_len = 0; + bool fail = false; + + printf("%s = (%s ^ %s) mod %s ret=%d ... ", z_str, x_str, y_str, m_str, ret_error); mbedtls_mpi_init(&Z); mbedtls_mpi_init(&X); @@ -136,38 +140,72 @@ static void test_bignum_modexp(const char *z_str, const char *x_str, const char //mbedtls_mpi_printf("M", &M); /* Z = (X ^ Y) mod M */ - TEST_ASSERT_FALSE(mbedtls_mpi_exp_mod(&Z, &X, &Y, &M, NULL)); - - mbedtls_mpi_write_string(&Z, 16, z_buf, sizeof(z_buf)-1, &z_buf_len); - TEST_ASSERT_EQUAL_STRING_MESSAGE(z_str, z_buf, "mbedtls_mpi_exp_mod incorrect"); + if (ret_error != mbedtls_mpi_exp_mod(&Z, &X, &Y, &M, NULL)) { + fail = true; + } + if (ret_error == MBEDTLS_OK) { + mbedtls_mpi_write_string(&Z, 16, z_buf, sizeof(z_buf)-1, &z_buf_len); + if (memcmp(z_str, z_buf, strlen(z_str)) != 0) { + printf("\n Expected '%s' Was '%s' \n", z_str, z_buf); + fail = true; + } + } + mbedtls_mpi_free(&Z); mbedtls_mpi_free(&X); mbedtls_mpi_free(&Y); mbedtls_mpi_free(&M); + + if (fail == true) { + printf(" FAIL\n"); + } else { + printf(" PASS\n"); + } + return fail; } TEST_CASE("test MPI modexp", "[bignum]") { - test_bignum_modexp("01000000", "1000", "2", "FFFFFFFF"); - test_bignum_modexp("014B5A90", "1234", "2", "FFFFFFF"); - test_bignum_modexp("01234321", "1111", "2", "FFFFFFFF"); - test_bignum_modexp("02", "5", "1", "3"); - test_bignum_modexp("22", "55", "1", "33"); - test_bignum_modexp("0222", "555", "1", "333"); - test_bignum_modexp("2222", "5555", "1", "3333"); - test_bignum_modexp("11", "5555", "1", "33"); + bool test_error = false; + printf("Z = (X ^ Y) mod M \n"); + // test_bignum_modexp(Z, X, Y, M, ret_error); + test_error |= test_bignum_modexp("01000000", "1000", "2", "FFFFFFFF", MBEDTLS_OK); + test_error |= test_bignum_modexp("014B5A90", "1234", "2", "FFFFFFF", MBEDTLS_OK); + test_error |= test_bignum_modexp("01234321", "1111", "2", "FFFFFFFF", MBEDTLS_OK); + test_error |= test_bignum_modexp("02", "5", "1", "3", MBEDTLS_OK); + test_error |= test_bignum_modexp("22", "55", "1", "33", MBEDTLS_OK); + test_error |= test_bignum_modexp("0222", "555", "1", "333", MBEDTLS_OK); + test_error |= test_bignum_modexp("2222", "5555", "1", "3333", MBEDTLS_OK); + test_error |= test_bignum_modexp("11", "5555", "1", "33", MBEDTLS_OK); + test_error |= test_bignum_modexp("55", "1111", "1", "77", MBEDTLS_OK); + test_error |= test_bignum_modexp("88", "1111", "2", "BB", MBEDTLS_OK); + test_error |= test_bignum_modexp("01000000", "2", "128", "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", MBEDTLS_OK); + test_error |= test_bignum_modexp("0ABCDEF12345", "ABCDEF12345", "1", "FFFFFFFFFFFF", MBEDTLS_OK); + test_error |= test_bignum_modexp("0ABCDE", "ABCDE", "1", "FFFFF", MBEDTLS_OK); - test_bignum_modexp("55", "1111", "1", "77"); - test_bignum_modexp("88", "1111", "2", "BB"); + test_error |= test_bignum_modexp("04", "2", "2", "9", MBEDTLS_OK); + test_error |= test_bignum_modexp("04", "2", "-2", "9", MBEDTLS_ERR_MPI_BAD_INPUT_DATA); + test_error |= test_bignum_modexp("04", "2", "2", "-9", MBEDTLS_ERR_MPI_BAD_INPUT_DATA); + test_error |= test_bignum_modexp("04", "2", "-2", "-9", MBEDTLS_ERR_MPI_BAD_INPUT_DATA); - test_bignum_modexp("01000000", "2", "128", - "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF"); + test_error |= test_bignum_modexp("01", "2", "0", "9", MBEDTLS_OK); + test_error |= test_bignum_modexp("04", "2", "0", "0", MBEDTLS_ERR_MPI_BAD_INPUT_DATA); + test_error |= test_bignum_modexp("04", "2", "2", "0", MBEDTLS_ERR_MPI_BAD_INPUT_DATA); + test_error |= test_bignum_modexp("00", "0", "2", "9", MBEDTLS_OK); + test_error |= test_bignum_modexp("01", "0", "0", "9", MBEDTLS_OK); - /* failures below here... */ - test_bignum_modexp("0ABCDEF12345", "ABCDEF12345", "1", "FFFFFFFFFFFF"); - test_bignum_modexp("0ABCDE", "ABCDE", "1", "FFFFF"); + test_error |= test_bignum_modexp("04", "-2", "2", "9", MBEDTLS_OK); + test_error |= test_bignum_modexp("01", "-2", "0", "9", MBEDTLS_OK); + test_error |= test_bignum_modexp("07", "-2", "7", "9", MBEDTLS_OK); + test_error |= test_bignum_modexp("07", "-2", "1", "9", MBEDTLS_OK); + test_error |= test_bignum_modexp("02", "2", "1", "9", MBEDTLS_OK); + test_error |= test_bignum_modexp("01", "2", "0", "9", MBEDTLS_OK); - test_bignum_modexp("04", "2", "2", "9"); + test_error |= test_bignum_modexp("05", "5", "7", "7", MBEDTLS_OK); + test_error |= test_bignum_modexp("02", "-5", "7", "7", MBEDTLS_OK); + test_error |= test_bignum_modexp("01", "-5", "7", "3", MBEDTLS_OK); + + TEST_ASSERT_FALSE_MESSAGE(test_error, "mbedtls_mpi_exp_mod incorrect for some tests\n"); }