From 1ccc320ee99651622ced9b33764d5e7890ca3f57 Mon Sep 17 00:00:00 2001 From: Michael Brown Date: Tue, 2 Dec 2025 13:12:25 +0000 Subject: [PATCH] [crypto] Construct asymmetric ciphered data using ASN.1 builders Signed-off-by: Michael Brown --- src/crypto/cms.c | 24 +++++++------- src/crypto/crypto_null.c | 10 +++--- src/crypto/rsa.c | 65 +++++++++++++++++++++---------------- src/include/ipxe/crypto.h | 34 +++++++++++--------- src/net/tls.c | 68 ++++++++++++++++++++++----------------- src/tests/pubkey_test.c | 64 +++++++++++++++++++----------------- src/tests/pubkey_test.h | 20 ++++++------ 7 files changed, 156 insertions(+), 129 deletions(-) diff --git a/src/crypto/cms.c b/src/crypto/cms.c index a3c03a9b4..7775e581b 100644 --- a/src/crypto/cms.c +++ b/src/crypto/cms.c @@ -917,29 +917,26 @@ static int cms_cipher_key ( struct cms_message *cms, struct pubkey_algorithm *pubkey = part->pubkey; const struct asn1_cursor *key = privkey_cursor ( private_key ); const struct asn1_cursor *value = &part->value; - size_t max_len = pubkey_max_len ( pubkey, key ); - uint8_t cipher_key[max_len]; - int len; + struct asn1_builder cipher_key = { NULL, 0 }; int rc; /* Decrypt cipher key */ - len = pubkey_decrypt ( pubkey, key, value->data, value->len, - cipher_key ); - if ( len < 0 ) { - rc = len; + if ( ( rc = pubkey_decrypt ( pubkey, key, value, + &cipher_key ) ) != 0 ) { DBGC ( cms, "CMS %p/%p could not decrypt cipher key: %s\n", cms, part, strerror ( rc ) ); DBGC_HDA ( cms, 0, value->data, value->len ); - return rc; + goto err_decrypt; } DBGC ( cms, "CMS %p/%p cipher key:\n", cms, part ); - DBGC_HDA ( cms, 0, cipher_key, len ); + DBGC_HDA ( cms, 0, cipher_key.data, cipher_key.len ); /* Set cipher key */ - if ( ( rc = cipher_setkey ( cipher, ctx, cipher_key, len ) ) != 0 ) { + if ( ( rc = cipher_setkey ( cipher, ctx, cipher_key.data, + cipher_key.len ) ) != 0 ) { DBGC ( cms, "CMS %p could not set cipher key: %s\n", cms, strerror ( rc ) ); - return rc; + goto err_setkey; } /* Set cipher initialization vector */ @@ -949,7 +946,10 @@ static int cms_cipher_key ( struct cms_message *cms, DBGC_HDA ( cms, 0, cms->iv.data, cms->iv.len ); } - return 0; + err_setkey: + err_decrypt: + free ( cipher_key.data ); + return rc; } /** diff --git a/src/crypto/crypto_null.c b/src/crypto/crypto_null.c index ee948e00d..e8f8cbde8 100644 --- a/src/crypto/crypto_null.c +++ b/src/crypto/crypto_null.c @@ -98,16 +98,14 @@ size_t pubkey_null_max_len ( const struct asn1_cursor *key __unused ) { } int pubkey_null_encrypt ( const struct asn1_cursor *key __unused, - const void *plaintext __unused, - size_t plaintext_len __unused, - void *ciphertext __unused ) { + const struct asn1_cursor *plaintext __unused, + struct asn1_builder *ciphertext __unused ) { return 0; } int pubkey_null_decrypt ( const struct asn1_cursor *key __unused, - const void *ciphertext __unused, - size_t ciphertext_len __unused, - void *plaintext __unused ) { + const struct asn1_cursor *ciphertext __unused, + struct asn1_builder *plaintext __unused ) { return 0; } diff --git a/src/crypto/rsa.c b/src/crypto/rsa.c index fd6a1ef39..18b2b1c14 100644 --- a/src/crypto/rsa.c +++ b/src/crypto/rsa.c @@ -338,12 +338,12 @@ static void rsa_cipher ( struct rsa_context *context, * * @v key Key * @v plaintext Plaintext - * @v plaintext_len Length of plaintext * @v ciphertext Ciphertext * @ret ciphertext_len Length of ciphertext, or negative error */ -static int rsa_encrypt ( const struct asn1_cursor *key, const void *plaintext, - size_t plaintext_len, void *ciphertext ) { +static int rsa_encrypt ( const struct asn1_cursor *key, + const struct asn1_cursor *plaintext, + struct asn1_builder *ciphertext ) { struct rsa_context context; void *temp; uint8_t *encoded; @@ -352,7 +352,7 @@ static int rsa_encrypt ( const struct asn1_cursor *key, const void *plaintext, int rc; DBGC ( &context, "RSA %p encrypting:\n", &context ); - DBGC_HDA ( &context, 0, plaintext, plaintext_len ); + DBGC_HDA ( &context, 0, plaintext->data, plaintext->len ); /* Initialise context */ if ( ( rc = rsa_init ( &context, key ) ) != 0 ) @@ -360,12 +360,12 @@ static int rsa_encrypt ( const struct asn1_cursor *key, const void *plaintext, /* Calculate lengths */ max_len = ( context.max_len - 11 ); - random_nz_len = ( max_len - plaintext_len + 8 ); + random_nz_len = ( max_len - plaintext->len + 8 ); /* Sanity check */ - if ( plaintext_len > max_len ) { + if ( plaintext->len > max_len ) { DBGC ( &context, "RSA %p plaintext too long (%zd bytes, max " - "%zd)\n", &context, plaintext_len, max_len ); + "%zd)\n", &context, plaintext->len, max_len ); rc = -ERANGE; goto err_sanity; } @@ -383,19 +383,24 @@ static int rsa_encrypt ( const struct asn1_cursor *key, const void *plaintext, goto err_random; } encoded[ 2 + random_nz_len ] = 0x00; - memcpy ( &encoded[ context.max_len - plaintext_len ], - plaintext, plaintext_len ); + memcpy ( &encoded[ context.max_len - plaintext->len ], + plaintext->data, plaintext->len ); + + /* Create space for ciphertext */ + if ( ( rc = asn1_grow ( ciphertext, context.max_len ) ) != 0 ) + goto err_grow; /* Encipher the encoded message */ - rsa_cipher ( &context, encoded, ciphertext ); + rsa_cipher ( &context, encoded, ciphertext->data ); DBGC ( &context, "RSA %p encrypted:\n", &context ); - DBGC_HDA ( &context, 0, ciphertext, context.max_len ); + DBGC_HDA ( &context, 0, ciphertext->data, context.max_len ); /* Free context */ rsa_free ( &context ); - return context.max_len; + return 0; + err_grow: err_random: err_sanity: rsa_free ( &context ); @@ -408,33 +413,33 @@ static int rsa_encrypt ( const struct asn1_cursor *key, const void *plaintext, * * @v key Key * @v ciphertext Ciphertext - * @v ciphertext_len Ciphertext length * @v plaintext Plaintext - * @ret plaintext_len Plaintext length, or negative error + * @ret rc Return status code */ -static int rsa_decrypt ( const struct asn1_cursor *key, const void *ciphertext, - size_t ciphertext_len, void *plaintext ) { +static int rsa_decrypt ( const struct asn1_cursor *key, + const struct asn1_cursor *ciphertext, + struct asn1_builder *plaintext ) { struct rsa_context context; void *temp; uint8_t *encoded; uint8_t *end; uint8_t *zero; uint8_t *start; - size_t plaintext_len; + size_t len; int rc; DBGC ( &context, "RSA %p decrypting:\n", &context ); - DBGC_HDA ( &context, 0, ciphertext, ciphertext_len ); + DBGC_HDA ( &context, 0, ciphertext->data, ciphertext->len ); /* Initialise context */ if ( ( rc = rsa_init ( &context, key ) ) != 0 ) goto err_init; /* Sanity check */ - if ( ciphertext_len != context.max_len ) { + if ( ciphertext->len != context.max_len ) { DBGC ( &context, "RSA %p ciphertext incorrect length (%zd " "bytes, should be %zd)\n", - &context, ciphertext_len, context.max_len ); + &context, ciphertext->len, context.max_len ); rc = -ERANGE; goto err_sanity; } @@ -444,7 +449,7 @@ static int rsa_decrypt ( const struct asn1_cursor *key, const void *ciphertext, */ temp = context.input0; encoded = temp; - rsa_cipher ( &context, ciphertext, encoded ); + rsa_cipher ( &context, ciphertext->data, encoded ); /* Parse the message */ end = ( encoded + context.max_len ); @@ -454,25 +459,31 @@ static int rsa_decrypt ( const struct asn1_cursor *key, const void *ciphertext, } zero = memchr ( &encoded[2], 0, ( end - &encoded[2] ) ); if ( ! zero ) { + DBGC ( &context, "RSA %p invalid decrypted message:\n", + &context ); + DBGC_HDA ( &context, 0, encoded, context.max_len ); rc = -EINVAL; goto err_invalid; } start = ( zero + 1 ); - plaintext_len = ( end - start ); + len = ( end - start ); + + /* Create space for plaintext */ + if ( ( rc = asn1_grow ( plaintext, len ) ) != 0 ) + goto err_grow; /* Copy out message */ - memcpy ( plaintext, start, plaintext_len ); + memcpy ( plaintext->data, start, len ); DBGC ( &context, "RSA %p decrypted:\n", &context ); - DBGC_HDA ( &context, 0, plaintext, plaintext_len ); + DBGC_HDA ( &context, 0, plaintext->data, len ); /* Free context */ rsa_free ( &context ); - return plaintext_len; + return 0; + err_grow: err_invalid: - DBGC ( &context, "RSA %p invalid decrypted message:\n", &context ); - DBGC_HDA ( &context, 0, encoded, context.max_len ); err_sanity: rsa_free ( &context ); err_init: diff --git a/src/include/ipxe/crypto.h b/src/include/ipxe/crypto.h index c457a74b1..68bd23048 100644 --- a/src/include/ipxe/crypto.h +++ b/src/include/ipxe/crypto.h @@ -131,22 +131,22 @@ struct pubkey_algorithm { * * @v key Key * @v plaintext Plaintext - * @v plaintext_len Length of plaintext * @v ciphertext Ciphertext - * @ret ciphertext_len Length of ciphertext, or negative error + * @ret rc Return status code */ - int ( * encrypt ) ( const struct asn1_cursor *key, const void *data, - size_t len, void *out ); + int ( * encrypt ) ( const struct asn1_cursor *key, + const struct asn1_cursor *plaintext, + struct asn1_builder *ciphertext ); /** Decrypt * * @v key Key * @v ciphertext Ciphertext - * @v ciphertext_len Ciphertext length * @v plaintext Plaintext - * @ret plaintext_len Plaintext length, or negative error + * @ret rc Return status code */ - int ( * decrypt ) ( const struct asn1_cursor *key, const void *data, - size_t len, void *out ); + int ( * decrypt ) ( const struct asn1_cursor *key, + const struct asn1_cursor *ciphertext, + struct asn1_builder *plaintext ); /** Sign digest value * * @v key Key @@ -274,14 +274,16 @@ pubkey_max_len ( struct pubkey_algorithm *pubkey, static inline __attribute__ (( always_inline )) int pubkey_encrypt ( struct pubkey_algorithm *pubkey, const struct asn1_cursor *key, - const void *data, size_t len, void *out ) { - return pubkey->encrypt ( key, data, len, out ); + const struct asn1_cursor *plaintext, + struct asn1_builder *ciphertext ) { + return pubkey->encrypt ( key, plaintext, ciphertext ); } static inline __attribute__ (( always_inline )) int pubkey_decrypt ( struct pubkey_algorithm *pubkey, const struct asn1_cursor *key, - const void *data, size_t len, void *out ) { - return pubkey->decrypt ( key, data, len, out ); + const struct asn1_cursor *ciphertext, + struct asn1_builder *plaintext ) { + return pubkey->decrypt ( key, ciphertext, plaintext ); } static inline __attribute__ (( always_inline )) int @@ -325,11 +327,11 @@ extern void cipher_null_auth ( void *ctx, void *auth ); extern size_t pubkey_null_max_len ( const struct asn1_cursor *key ); extern int pubkey_null_encrypt ( const struct asn1_cursor *key, - const void *plaintext, size_t plaintext_len, - void *ciphertext ); + const struct asn1_cursor *plaintext, + struct asn1_builder *ciphertext ); extern int pubkey_null_decrypt ( const struct asn1_cursor *key, - const void *ciphertext, size_t ciphertext_len, - void *plaintext ); + const struct asn1_cursor *ciphertext, + struct asn1_builder *plaintext ); extern int pubkey_null_sign ( const struct asn1_cursor *key, struct digest_algorithm *digest, const void *value, diff --git a/src/net/tls.c b/src/net/tls.c index c01ce9515..6140ca58a 100644 --- a/src/net/tls.c +++ b/src/net/tls.c @@ -1416,59 +1416,69 @@ static int tls_send_certificate ( struct tls_connection *tls ) { static int tls_send_client_key_exchange_pubkey ( struct tls_connection *tls ) { struct tls_cipherspec *cipherspec = &tls->tx.cipherspec.pending; struct pubkey_algorithm *pubkey = cipherspec->suite->pubkey; - size_t max_len = pubkey_max_len ( pubkey, &tls->server.key ); struct { uint16_t version; uint8_t random[46]; } __attribute__ (( packed )) pre_master_secret; - struct { - uint32_t type_length; - uint16_t encrypted_pre_master_secret_len; - uint8_t encrypted_pre_master_secret[max_len]; - } __attribute__ (( packed )) key_xchg; - size_t unused; - int len; + struct asn1_cursor cursor = { + .data = &pre_master_secret, + .len = sizeof ( pre_master_secret ), + }; + struct asn1_builder builder = { NULL, 0 }; int rc; /* Generate pre-master secret */ pre_master_secret.version = htons ( TLS_VERSION_MAX ); if ( ( rc = tls_generate_random ( tls, &pre_master_secret.random, ( sizeof ( pre_master_secret.random ) ) ) ) != 0 ) { - return rc; + goto err_random; } /* Encrypt pre-master secret using server's public key */ - memset ( &key_xchg, 0, sizeof ( key_xchg ) ); - len = pubkey_encrypt ( pubkey, &tls->server.key, &pre_master_secret, - sizeof ( pre_master_secret ), - key_xchg.encrypted_pre_master_secret ); - if ( len < 0 ) { - rc = len; + if ( ( rc = pubkey_encrypt ( pubkey, &tls->server.key, &cursor, + &builder ) ) != 0 ) { DBGC ( tls, "TLS %p could not encrypt pre-master secret: %s\n", tls, strerror ( rc ) ); - return rc; + goto err_encrypt; + } + + /* Construct Client Key Exchange record */ + { + struct { + uint32_t type_length; + uint16_t encrypted_pre_master_secret_len; + } __attribute__ (( packed )) header; + + header.type_length = + ( cpu_to_le32 ( TLS_CLIENT_KEY_EXCHANGE ) | + htonl ( builder.len + sizeof ( header ) - + sizeof ( header.type_length ) ) ); + header.encrypted_pre_master_secret_len = htons ( builder.len ); + + if ( ( rc = asn1_prepend_raw ( &builder, &header, + sizeof ( header ) ) ) != 0 ) { + DBGC ( tls, "TLS %p could not construct Client Key " + "Exchange: %s\n", tls, strerror ( rc ) ); + goto err_prepend; + } } - unused = ( max_len - len ); - key_xchg.type_length = - ( cpu_to_le32 ( TLS_CLIENT_KEY_EXCHANGE ) | - htonl ( sizeof ( key_xchg ) - - sizeof ( key_xchg.type_length ) - unused ) ); - key_xchg.encrypted_pre_master_secret_len = - htons ( sizeof ( key_xchg.encrypted_pre_master_secret ) - - unused ); /* Transmit Client Key Exchange record */ - if ( ( rc = tls_send_handshake ( tls, &key_xchg, - ( sizeof ( key_xchg ) - - unused ) ) ) != 0 ) { - return rc; + if ( ( rc = tls_send_handshake ( tls, builder.data, + builder.len ) ) != 0 ) { + goto err_send; } /* Generate master secret */ tls_generate_master_secret ( tls, &pre_master_secret, sizeof ( pre_master_secret ) ); - return 0; + err_random: + err_encrypt: + err_prepend: + err_send: + free ( builder.data ); + return rc; } /** Public key exchange algorithm */ diff --git a/src/tests/pubkey_test.c b/src/tests/pubkey_test.c index e3fbc3b3f..d110b2946 100644 --- a/src/tests/pubkey_test.c +++ b/src/tests/pubkey_test.c @@ -50,41 +50,47 @@ FILE_LICENCE ( GPL2_OR_LATER_OR_UBDL ); void pubkey_okx ( struct pubkey_test *test, const char *file, unsigned int line ) { struct pubkey_algorithm *pubkey = test->pubkey; - size_t max_len = pubkey_max_len ( pubkey, &test->private ); - uint8_t encrypted[max_len]; - uint8_t decrypted[max_len]; - int encrypted_len; - int decrypted_len; + struct asn1_builder plaintext; + struct asn1_builder ciphertext; /* Test decrypting with private key to obtain known plaintext */ - decrypted_len = pubkey_decrypt ( pubkey, &test->private, - test->ciphertext, test->ciphertext_len, - decrypted ); - okx ( decrypted_len == ( ( int ) test->plaintext_len ), file, line ); - okx ( memcmp ( decrypted, test->plaintext, test->plaintext_len ) == 0, - file, line ); + plaintext.data = NULL; + plaintext.len = 0; + okx ( pubkey_decrypt ( pubkey, &test->private, &test->ciphertext, + &plaintext ) == 0, file, line ); + okx ( asn1_compare ( asn1_built ( &plaintext ), + &test->plaintext ) == 0, file, line ); + free ( plaintext.data ); /* Test encrypting with private key and decrypting with public key */ - encrypted_len = pubkey_encrypt ( pubkey, &test->private, - test->plaintext, test->plaintext_len, - encrypted ); - okx ( encrypted_len >= 0, file, line ); - decrypted_len = pubkey_decrypt ( pubkey, &test->public, encrypted, - encrypted_len, decrypted ); - okx ( decrypted_len == ( ( int ) test->plaintext_len ), file, line ); - okx ( memcmp ( decrypted, test->plaintext, test->plaintext_len ) == 0, - file, line ); + ciphertext.data = NULL; + ciphertext.len = 0; + plaintext.data = NULL; + plaintext.len = 0; + okx ( pubkey_encrypt ( pubkey, &test->private, &test->plaintext, + &ciphertext ) == 0, file, line ); + okx ( pubkey_decrypt ( pubkey, &test->public, + asn1_built ( &ciphertext ), + &plaintext ) == 0, file, line ); + okx ( asn1_compare ( asn1_built ( &plaintext ), + &test->plaintext ) == 0, file, line ); + free ( ciphertext.data ); + free ( plaintext.data ); /* Test encrypting with public key and decrypting with private key */ - encrypted_len = pubkey_encrypt ( pubkey, &test->public, - test->plaintext, test->plaintext_len, - encrypted ); - okx ( encrypted_len >= 0, file, line ); - decrypted_len = pubkey_decrypt ( pubkey, &test->private, encrypted, - encrypted_len, decrypted ); - okx ( decrypted_len == ( ( int ) test->plaintext_len ), file, line ); - okx ( memcmp ( decrypted, test->plaintext, test->plaintext_len ) == 0, - file, line ); + ciphertext.data = NULL; + ciphertext.len = 0; + plaintext.data = NULL; + plaintext.len = 0; + okx ( pubkey_encrypt ( pubkey, &test->public, &test->plaintext, + &ciphertext ) == 0, file, line ); + okx ( pubkey_decrypt ( pubkey, &test->private, + asn1_built ( &ciphertext ), + &plaintext ) == 0, file, line ); + okx ( asn1_compare ( asn1_built ( &plaintext ), + &test->plaintext ) == 0, file, line ); + free ( ciphertext.data ); + free ( plaintext.data ); } /** diff --git a/src/tests/pubkey_test.h b/src/tests/pubkey_test.h index 1bb6caf51..33b301a6e 100644 --- a/src/tests/pubkey_test.h +++ b/src/tests/pubkey_test.h @@ -16,18 +16,14 @@ struct pubkey_test { /** Public key */ const struct asn1_cursor public; /** Plaintext */ - const void *plaintext; - /** Length of plaintext */ - size_t plaintext_len; + const struct asn1_cursor plaintext; /** Ciphertext * * Note that the encryption process may include some random * padding, so a given plaintext will encrypt to multiple * different ciphertexts. */ - const void *ciphertext; - /** Length of ciphertext */ - size_t ciphertext_len; + const struct asn1_cursor ciphertext; }; /** A public-key signature test */ @@ -90,10 +86,14 @@ struct pubkey_sign_test { .data = name ## _public, \ .len = sizeof ( name ## _public ), \ }, \ - .plaintext = name ## _plaintext, \ - .plaintext_len = sizeof ( name ## _plaintext ), \ - .ciphertext = name ## _ciphertext, \ - .ciphertext_len = sizeof ( name ## _ciphertext ), \ + .plaintext = { \ + .data = name ## _plaintext, \ + .len = sizeof ( name ## _plaintext ), \ + }, \ + .ciphertext = { \ + .data = name ## _ciphertext, \ + .len = sizeof ( name ## _ciphertext ), \ + }, \ } /**