]> git.proxmox.com Git - libtpms.git/blobdiff - src/tpm2/crypto/openssl/CryptRsa.c
tpm2: Only call EVP_PKEY_CTX_set0_rsa_oaep_label when label != NULL (OSSL 3)
[libtpms.git] / src / tpm2 / crypto / openssl / CryptRsa.c
index 7e2b103d06c00fb06b673f6214a61dfe3ac474fe..4ed04384feb0f69c0d9affec47266946a0dc4336 100644 (file)
@@ -3,7 +3,7 @@
 /*             Implementation of cryptographic primitives for RSA              */
 /*                          Written by Ken Goldman                             */
 /*                    IBM Thomas J. Watson Research Center                     */
-/*            $Id: CryptRsa.c 1519 2019-11-15 20:43:51Z kgoldman $             */
+/*            $Id: CryptRsa.c 1658 2021-01-22 23:14:01Z kgoldman $             */
 /*                                                                             */
 /*  Licenses and Notices                                                       */
 /*                                                                             */
@@ -55,7 +55,7 @@
 /*    arising in any way out of use or reliance upon this specification or any         */
 /*    information herein.                                                      */
 /*                                                                             */
-/*  (c) Copyright IBM Corp. and others, 2016 - 2019                            */
+/*  (c) Copyright IBM Corp. and others, 2016 - 2021                            */
 /*                                                                             */
 /********************************************************************************/
 
@@ -328,14 +328,14 @@ OaepEncode(
     CryptRandomGenerate(hLen, mySeed);
     DRBG_Generate(rand, mySeed, (UINT16)hLen);
     // mask = MGF1 (seed, nSize  hLen  1)
-    CryptMGF1(dbSize, mask, hashAlg, hLen, seed);
+    CryptMGF_KDF(dbSize, mask, hashAlg, hLen, seed, 0);
     // Create the masked db
     pm = mask;
     for(i = dbSize; i > 0; i--)
        *pp++ ^= *pm++;
     pp = &padded->buffer[hLen + 1];
     // Run the masked data through MGF1
-    if(CryptMGF1(hLen, &padded->buffer[1], hashAlg, dbSize, pp) != (unsigned)hLen)
+    if(CryptMGF_KDF(hLen, &padded->buffer[1], hashAlg, dbSize, pp, 0) != (unsigned)hLen)
        ERROR_RETURN(TPM_RC_VALUE);
     // Now XOR the seed to create masked seed
     pp = &padded->buffer[1];
@@ -377,8 +377,8 @@ OaepDecode(
        ERROR_RETURN(TPM_RC_VALUE);
     // Use the hash size to determine what to put through MGF1 in order
     // to recover the seedMask
-    CryptMGF1(hLen, seedMask, hashAlg, padded->size - hLen - 1,
-             &padded->buffer[hLen + 1]);
+    CryptMGF_KDF(hLen, seedMask, hashAlg, padded->size - hLen - 1,
+                &padded->buffer[hLen + 1], 0);
     // Recover the seed into seedMask
     pAssert(hLen <= sizeof(seedMask));
     pp = &padded->buffer[1];
@@ -386,7 +386,7 @@ OaepDecode(
     for(i = hLen; i > 0; i--)
        *pm++ ^= *pp++;
     // Use the seed to generate the data mask
-    CryptMGF1(padded->size - hLen - 1, mask, hashAlg, hLen, seedMask);
+    CryptMGF_KDF(padded->size - hLen - 1, mask, hashAlg, hLen, seedMask, 0);
     // Use the mask generated from seed to recover the padded data
     pp = &padded->buffer[hLen + 1];
     pm = mask;
@@ -482,7 +482,7 @@ RSAES_Decode(
     // Make sure that pSize has not gone over the end and that there are at least 8
     // bytes of pad data.
     fail = (pSize > coded->size) | fail;
-    fail = ((pSize - 2) < 8) | fail;
+    fail = ((pSize - 2) <= 8) | fail;
     if((message->size < (UINT16)(coded->size - pSize)) || fail)
        return TPM_RC_VALUE;
     message->size = coded->size - pSize;
@@ -511,6 +511,7 @@ CryptRsaPssSaltSize(
        saltSize = 0;
     return saltSize;
 }
+
 #if !USE_OPENSSL_FUNCTIONS_RSA         // libtpms added
 /* 10.2.17.4.9 PssEncode() */
 /* This function creates an encoded block of data that is the size of modulus. The function uses the
@@ -553,7 +554,7 @@ PssEncode(
     CryptDigestUpdate(&hashState, saltSize, salt);
     CryptHashEnd(&hashState, hLen, &pOut[out->size - hLen - 1]);
     // Create a mask
-    if(CryptMGF1(mLen, pOut, hashAlg, hLen, &pOut[mLen]) != mLen)
+    if(CryptMGF_KDF(mLen, pOut, hashAlg, hLen, &pOut[mLen], 0) != mLen)
        FAIL(FATAL_ERROR_INTERNAL);
     // Since this implementation uses key sizes that are all even multiples of
     // 8, just need to make sure that the most significant bit is CLEAR
@@ -609,7 +610,7 @@ PssDecode(
     // Use the hLen bytes at the end of the buffer to generate a mask
     // Doesn't start at the end which is a flag byte
     mLen = eIn->size - hLen - 1;
-    CryptMGF1(mLen, mask, hashAlg, hLen, &pe[mLen]);
+    CryptMGF_KDF(mLen, mask, hashAlg, hLen, &pe[mLen], 0);
     // Clear the MSO of the mask to make it consistent with the encoding.
     mask[0] &= 0x7F;
     pAssert(mLen <= sizeof(mask));
@@ -814,6 +815,7 @@ RSASSA_Decode(
     return retVal;
 }
 #endif                                 // libtpms added
+
 /* 10.2.17.4.13 CryptRsaSelectScheme() */
 /* This function is used by TPM2_RSA_Decrypt() and TPM2_RSA_Encrypt().  It sets up the rules to
    select a scheme between input and object default. This function assume the RSA object is
@@ -878,7 +880,7 @@ CryptRsaLoadPrivateExponent(
     TPM_RC          retVal = TPM_RC_SUCCESS;
     if(!rsaKey->attributes.privateExp)
        {
-           TEST(ALG_NULL_VALUE);
+           TEST(TPM_ALG_NULL);
            // Make sure that the bigNum used for the exponent is properly initialized
            RsaInitializeExponent(&rsaKey->privateExponent);
            // Find the second prime by division
@@ -936,7 +938,7 @@ CryptRsaEncrypt(
     TEST(scheme->scheme);
     switch(scheme->scheme)
        {
-         case ALG_NULL_VALUE:  // 'raw' encryption
+         case TPM_ALG_NULL:  // 'raw' encryption
              {
                  INT32            i;
                  INT32            dSize = dIn->size;
@@ -955,10 +957,10 @@ CryptRsaEncrypt(
                  // the modulus. If it is, then RSAEP() will catch it.
              }
              break;
-         case ALG_RSAES_VALUE:
+         case TPM_ALG_RSAES:
            retVal = RSAES_PKCS1v1_5Encode(&cOut->b, dIn, rand);
            break;
-         case ALG_OAEP_VALUE:
+         case TPM_ALG_OAEP:
            retVal = OaepEncode(&cOut->b, scheme->details.oaep.hashAlg, label, dIn,
                                rand);
            break;
@@ -1006,15 +1008,15 @@ CryptRsaDecrypt(
            // Remove padding
            switch(scheme->scheme)
                {
-                 case ALG_NULL_VALUE:
+                 case TPM_ALG_NULL:
                    if(dOut->size < cIn->size)
                        return TPM_RC_VALUE;
                    MemoryCopy2B(dOut, cIn, dOut->size);
                    break;
-                 case ALG_RSAES_VALUE:
+                 case TPM_ALG_RSAES:
                    retVal = RSAES_Decode(dOut, cIn);
                    break;
-                 case ALG_OAEP_VALUE:
+                 case TPM_ALG_OAEP:
                    retVal = OaepDecode(dOut, scheme->details.oaep.hashAlg, label, cIn);
                    break;
                  default:
@@ -1049,14 +1051,14 @@ CryptRsaSign(
     TEST(sigOut->sigAlg);
     switch(sigOut->sigAlg)
        {
-         case ALG_NULL_VALUE:
+         case TPM_ALG_NULL:
            sigOut->signature.rsapss.sig.t.size = 0;
            return TPM_RC_SUCCESS;
-         case ALG_RSAPSS_VALUE:
+         case TPM_ALG_RSAPSS:
            retVal = PssEncode(&sigOut->signature.rsapss.sig.b,
                               sigOut->signature.rsapss.hash, &hIn->b, rand);
            break;
-         case ALG_RSASSA_VALUE:
+         case TPM_ALG_RSASSA:
            retVal = RSASSA_Encode(&sigOut->signature.rsassa.sig.b,
                                   sigOut->signature.rsassa.hash, &hIn->b);
            break;
@@ -1090,8 +1092,8 @@ CryptRsaValidateSignature(
     pAssert(key != NULL && sig != NULL && digest != NULL);
     switch(sig->sigAlg)
        {
-         case ALG_RSAPSS_VALUE:
-         case ALG_RSASSA_VALUE:
+         case TPM_ALG_RSAPSS:
+         case TPM_ALG_RSASSA:
            break;
          default:
            return TPM_RC_SCHEME;
@@ -1106,11 +1108,11 @@ CryptRsaValidateSignature(
        {
            switch(sig->sigAlg)
                {
-                 case ALG_RSAPSS_VALUE:
+                 case TPM_ALG_RSAPSS:
                    retVal = PssDecode(sig->signature.any.hashAlg, &digest->b,
                                       &sig->signature.rsassa.sig.b);
                    break;
-                 case ALG_RSASSA_VALUE:
+                 case TPM_ALG_RSASSA:
                    retVal = RSASSA_Decode(sig->signature.any.hashAlg, &digest->b,
                                           &sig->signature.rsassa.sig.b);
                    break;
@@ -1180,7 +1182,7 @@ CryptRsaGenerateKey(
        return TPM_RC_SUCCESS;
 #endif
     // Make sure that key generation has been tested
-    TEST(ALG_NULL_VALUE);
+    TEST(TPM_ALG_NULL);
 #if USE_OPENSSL_FUNCTIONS_RSA          // libtpms added begin
     if (rand == NULL)
         return OpenSSLCryptRsaGenerateKey(rsaKey, e, keySizeInBits);
@@ -1311,7 +1313,7 @@ CryptRsaEncrypt(
 
     switch(scheme->scheme)
        {
-          case ALG_NULL_VALUE:  // 'raw' encryption
+          case TPM_ALG_NULL:  // 'raw' encryption
            {
                INT32                 i;
                INT32                 dSize = dIn->size;
@@ -1334,11 +1336,11 @@ CryptRsaEncrypt(
             if (EVP_PKEY_CTX_set_rsa_padding(ctx, RSA_NO_PADDING) <= 0)
                 ERROR_RETURN(TPM_RC_FAILURE);
             break;
-          case ALG_RSAES_VALUE:
+          case TPM_ALG_RSAES:
             if (EVP_PKEY_CTX_set_rsa_padding(ctx, RSA_PKCS1_PADDING) <= 0)
                 ERROR_RETURN(TPM_RC_FAILURE);
             break;
-          case ALG_OAEP_VALUE:
+          case TPM_ALG_OAEP:
             digestname = GetDigestNameByHashAlg(scheme->details.oaep.hashAlg);
             if (digestname == NULL)
                 ERROR_RETURN(TPM_RC_VALUE);
@@ -1397,6 +1399,7 @@ CryptRsaDecrypt(
     const char            *digestname;
     size_t                 outlen;
     unsigned char         *tmp = NULL;
+    unsigned char          buffer[MAX_RSA_KEY_BYTES];
 
     // Make sure that the necessary parameters are provided
     pAssert(cIn != NULL && dOut != NULL && key != NULL);
@@ -1416,15 +1419,15 @@ CryptRsaDecrypt(
 
     switch(scheme->scheme)
        {
-         case ALG_NULL_VALUE:  // 'raw' encryption
+         case TPM_ALG_NULL:  // 'raw' encryption
             if (EVP_PKEY_CTX_set_rsa_padding(ctx, RSA_NO_PADDING) <= 0)
                 ERROR_RETURN(TPM_RC_FAILURE);
             break;
-         case ALG_RSAES_VALUE:
+         case TPM_ALG_RSAES:
             if (EVP_PKEY_CTX_set_rsa_padding(ctx, RSA_PKCS1_PADDING) <= 0)
                 ERROR_RETURN(TPM_RC_FAILURE);
             break;
-         case ALG_OAEP_VALUE:
+         case TPM_ALG_OAEP:
             digestname = GetDigestNameByHashAlg(scheme->details.oaep.hashAlg);
             if (digestname == NULL)
                 ERROR_RETURN(TPM_RC_VALUE);
@@ -1440,22 +1443,27 @@ CryptRsaDecrypt(
                 if (tmp == NULL)
                     ERROR_RETURN(TPM_RC_FAILURE);
                 memcpy(tmp, label->buffer, label->size);
-            }
 
-            if (EVP_PKEY_CTX_set0_rsa_oaep_label(ctx, tmp, label->size) <= 0)
-                ERROR_RETURN(TPM_RC_FAILURE);
-            tmp = NULL;
+                if (EVP_PKEY_CTX_set0_rsa_oaep_label(ctx, tmp, label->size) <= 0)
+                    ERROR_RETURN(TPM_RC_FAILURE);
+                tmp = NULL;
+            }
             break;
          default:
             ERROR_RETURN(TPM_RC_SCHEME);
             break;
        }
 
-    outlen = cIn->size;
-    if (EVP_PKEY_decrypt(ctx, dOut->buffer, &outlen,
+    /* cannot use cOut->buffer */
+    outlen = sizeof(buffer);
+    if (EVP_PKEY_decrypt(ctx, buffer, &outlen,
                          cIn->buffer, cIn->size) <= 0)
         ERROR_RETURN(TPM_RC_FAILURE);
 
+    if (outlen > dOut->size)
+        ERROR_RETURN(TPM_RC_FAILURE);
+
+    memcpy(dOut->buffer, buffer, outlen);
     dOut->size = outlen;
 
     retVal = TPM_RC_SUCCESS;
@@ -1496,14 +1504,14 @@ CryptRsaSign(
 
     switch(sigOut->sigAlg)
          {
-          case ALG_NULL_VALUE:
+          case TPM_ALG_NULL:
             sigOut->signature.rsapss.sig.t.size = 0;
             return TPM_RC_SUCCESS;
-          case ALG_RSAPSS_VALUE:
+          case TPM_ALG_RSAPSS:
             padding = RSA_PKCS1_PSS_PADDING;
             hashAlg = sigOut->signature.rsapss.hash;
             break;
-          case ALG_RSASSA_VALUE:
+          case TPM_ALG_RSASSA:
             padding = RSA_PKCS1_PADDING;
             hashAlg = sigOut->signature.rsassa.hash;
             break;
@@ -1532,6 +1540,16 @@ CryptRsaSign(
         EVP_PKEY_CTX_set_signature_md(ctx, md) <= 0)
         ERROR_RETURN(TPM_RC_FAILURE);
 
+    /* careful with PSS padding: Use salt length = hash length (-1) if
+     *   length(digest) + length(hash-to-sign) + 2 <= modSize
+     * otherwise use the max. possible salt length, which is the default (-2)
+     * test case: 1024 bit key PSS signing sha512 hash
+     */
+    if (padding == RSA_PKCS1_PSS_PADDING &&
+        EVP_MD_size(md) + hIn->b.size + 2 <= modSize && /* OSSL: RSA_padding_add_PKCS1_PSS_mgf1 */
+        EVP_PKEY_CTX_set_rsa_pss_saltlen(ctx, -1) <= 0)
+        ERROR_RETURN(TPM_RC_FAILURE);
+
     outlen = sigOut->signature.rsapss.sig.t.size;
     if (EVP_PKEY_sign(ctx,
                       sigOut->signature.rsapss.sig.t.buffer, &outlen,
@@ -1566,10 +1584,10 @@ CryptRsaValidateSignature(
     pAssert(key != NULL && sig != NULL && digest != NULL);
     switch(sig->sigAlg)
        {
-         case ALG_RSAPSS_VALUE:
+         case TPM_ALG_RSAPSS:
            padding = RSA_PKCS1_PSS_PADDING;
            break;
-         case ALG_RSASSA_VALUE:
+         case TPM_ALG_RSASSA:
            padding = RSA_PKCS1_PADDING;
            break;
          default: