xref: /freebsd/crypto/openssl/providers/implementations/kem/mlx_kem.c (revision e7be843b4a162e68651d3911f0357ed464915629)
1 /*
2  * Copyright 2024-2025 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9 
10 #include <openssl/core_dispatch.h>
11 #include <openssl/core_names.h>
12 #include <openssl/crypto.h>
13 #include <openssl/err.h>
14 #include <openssl/evp.h>
15 #include <openssl/params.h>
16 #include <openssl/proverr.h>
17 #include <openssl/rand.h>
18 #include "prov/implementations.h"
19 #include "prov/mlx_kem.h"
20 #include "prov/provider_ctx.h"
21 #include "prov/providercommon.h"
22 
23 static OSSL_FUNC_kem_newctx_fn mlx_kem_newctx;
24 static OSSL_FUNC_kem_freectx_fn mlx_kem_freectx;
25 static OSSL_FUNC_kem_encapsulate_init_fn mlx_kem_encapsulate_init;
26 static OSSL_FUNC_kem_encapsulate_fn mlx_kem_encapsulate;
27 static OSSL_FUNC_kem_decapsulate_init_fn mlx_kem_decapsulate_init;
28 static OSSL_FUNC_kem_decapsulate_fn mlx_kem_decapsulate;
29 static OSSL_FUNC_kem_set_ctx_params_fn mlx_kem_set_ctx_params;
30 static OSSL_FUNC_kem_settable_ctx_params_fn mlx_kem_settable_ctx_params;
31 
32 typedef struct {
33     OSSL_LIB_CTX *libctx;
34     MLX_KEY *key;
35     int op;
36 } PROV_MLX_KEM_CTX;
37 
mlx_kem_newctx(void * provctx)38 static void *mlx_kem_newctx(void *provctx)
39 {
40     PROV_MLX_KEM_CTX *ctx;
41 
42     if ((ctx = OPENSSL_malloc(sizeof(*ctx))) == NULL)
43         return NULL;
44 
45     ctx->libctx = PROV_LIBCTX_OF(provctx);
46     ctx->key = NULL;
47     ctx->op = 0;
48     return ctx;
49 }
50 
mlx_kem_freectx(void * vctx)51 static void mlx_kem_freectx(void *vctx)
52 {
53     OPENSSL_free(vctx);
54 }
55 
/*
 * Shared initialisation for encapsulation and decapsulation: store the key
 * and the requested operation in the context.  |params| is ignored.
 * Returns 1 on success, 0 if the provider is not running.
 */
static int mlx_kem_init(void *vctx, int op, void *key,
                        ossl_unused const OSSL_PARAM params[])
{
    PROV_MLX_KEM_CTX *kctx = vctx;

    if (ossl_prov_is_running()) {
        kctx->key = key;
        kctx->op = op;
        return 1;
    }
    return 0;
}
67 
68 static int
mlx_kem_encapsulate_init(void * vctx,void * vkey,const OSSL_PARAM params[])69 mlx_kem_encapsulate_init(void *vctx, void *vkey, const OSSL_PARAM params[])
70 {
71     MLX_KEY *key = vkey;
72 
73     if (!mlx_kem_have_pubkey(key)) {
74         ERR_raise(ERR_LIB_PROV, PROV_R_MISSING_KEY);
75         return 0;
76     }
77     return mlx_kem_init(vctx, EVP_PKEY_OP_ENCAPSULATE, key, params);
78 }
79 
80 static int
mlx_kem_decapsulate_init(void * vctx,void * vkey,const OSSL_PARAM params[])81 mlx_kem_decapsulate_init(void *vctx, void *vkey, const OSSL_PARAM params[])
82 {
83     MLX_KEY *key = vkey;
84 
85     if (!mlx_kem_have_prvkey(key)) {
86         ERR_raise(ERR_LIB_PROV, PROV_R_MISSING_KEY);
87         return 0;
88     }
89     return mlx_kem_init(vctx, EVP_PKEY_OP_DECAPSULATE, key, params);
90 }
91 
mlx_kem_settable_ctx_params(ossl_unused void * vctx,ossl_unused void * provctx)92 static const OSSL_PARAM *mlx_kem_settable_ctx_params(ossl_unused void *vctx,
93                                                      ossl_unused void *provctx)
94 {
95     static const OSSL_PARAM params[] = { OSSL_PARAM_END };
96 
97     return params;
98 }
99 
100 static int
mlx_kem_set_ctx_params(void * vctx,const OSSL_PARAM params[])101 mlx_kem_set_ctx_params(void *vctx, const OSSL_PARAM params[])
102 {
103     return 1;
104 }
105 
mlx_kem_encapsulate(void * vctx,unsigned char * ctext,size_t * clen,unsigned char * shsec,size_t * slen)106 static int mlx_kem_encapsulate(void *vctx, unsigned char *ctext, size_t *clen,
107                                unsigned char *shsec, size_t *slen)
108 {
109     MLX_KEY *key = ((PROV_MLX_KEM_CTX *) vctx)->key;
110     EVP_PKEY_CTX *ctx = NULL;
111     EVP_PKEY *xkey = NULL;
112     size_t encap_clen;
113     size_t encap_slen;
114     uint8_t *cbuf;
115     uint8_t *sbuf;
116     int ml_kem_slot = key->xinfo->ml_kem_slot;
117     int ret = 0;
118 
119     if (!mlx_kem_have_pubkey(key)) {
120         ERR_raise(ERR_LIB_PROV, PROV_R_MISSING_KEY);
121         goto end;
122     }
123     encap_clen = key->minfo->ctext_bytes + key->xinfo->pubkey_bytes;
124     encap_slen = ML_KEM_SHARED_SECRET_BYTES + key->xinfo->shsec_bytes;
125 
126     if (ctext == NULL) {
127         if (clen == NULL && slen == NULL)
128             return 0;
129         if (clen != NULL)
130             *clen = encap_clen;
131         if (slen != NULL)
132             *slen = encap_slen;
133         return 1;
134     }
135     if (shsec == NULL) {
136         ERR_raise_data(ERR_LIB_PROV, PROV_R_NULL_OUTPUT_BUFFER,
137                        "null shared-secret output buffer");
138         return 0;
139     }
140 
141     if (clen == NULL) {
142         ERR_raise_data(ERR_LIB_PROV, PROV_R_NULL_LENGTH_POINTER,
143                        "null ciphertext input/output length pointer");
144         return 0;
145     } else if (*clen < encap_clen) {
146         ERR_raise_data(ERR_LIB_PROV, PROV_R_OUTPUT_BUFFER_TOO_SMALL,
147                        "ciphertext buffer too small");
148         return 0;
149     } else {
150         *clen = encap_clen;
151     }
152 
153     if (slen == NULL) {
154         ERR_raise_data(ERR_LIB_PROV, PROV_R_NULL_LENGTH_POINTER,
155                        "null shared secret input/output length pointer");
156         return 0;
157     } else if (*slen < encap_slen) {
158         ERR_raise_data(ERR_LIB_PROV, PROV_R_OUTPUT_BUFFER_TOO_SMALL,
159                        "shared-secret buffer too small");
160         return 0;
161     } else {
162         *slen = encap_slen;
163     }
164 
165     /* ML-KEM encapsulation */
166     encap_clen = key->minfo->ctext_bytes;
167     encap_slen = ML_KEM_SHARED_SECRET_BYTES;
168     cbuf = ctext + ml_kem_slot * key->xinfo->pubkey_bytes;
169     sbuf = shsec + ml_kem_slot * key->xinfo->shsec_bytes;
170     ctx = EVP_PKEY_CTX_new_from_pkey(key->libctx, key->mkey, key->propq);
171     if (ctx == NULL
172         || EVP_PKEY_encapsulate_init(ctx, NULL) <= 0
173         || EVP_PKEY_encapsulate(ctx, cbuf, &encap_clen, sbuf, &encap_slen) <= 0)
174         goto end;
175     if (encap_clen != key->minfo->ctext_bytes) {
176         ERR_raise_data(ERR_LIB_PROV, ERR_R_INTERNAL_ERROR,
177                        "unexpected %s ciphertext output size: %lu",
178                        key->minfo->algorithm_name, (unsigned long) encap_clen);
179         goto end;
180     }
181     if (encap_slen != ML_KEM_SHARED_SECRET_BYTES) {
182         ERR_raise_data(ERR_LIB_PROV, ERR_R_INTERNAL_ERROR,
183                        "unexpected %s shared secret output size: %lu",
184                        key->minfo->algorithm_name, (unsigned long) encap_slen);
185         goto end;
186     }
187     EVP_PKEY_CTX_free(ctx);
188 
189     /*-
190      * ECDHE encapsulation
191      *
192      * Generate own ephemeral private key and add its public key to ctext.
193      *
194      * Note, we could support a settable parameter that sets an extant ECDH
195      * keypair as the keys to use in encap, making it possible to reuse the
196      * same (TLS client) ECDHE keypair for both the classical EC keyshare and a
197      * corresponding ECDHE + ML-KEM keypair.  But the TLS layer would then need
198      * know that this is a hybrid, and that it can partly reuse the same keys
199      * as another group for which a keyshare will be sent.  Deferred until we
200      * support generating multiple keyshares, there's a workable keyshare
201      * prediction specification, and the optimisation is justified.
202      */
203     cbuf = ctext + (1 - ml_kem_slot) * key->minfo->ctext_bytes;
204     encap_clen = key->xinfo->pubkey_bytes;
205     ctx = EVP_PKEY_CTX_new_from_pkey(key->libctx, key->xkey, key->propq);
206     if (ctx == NULL
207         || EVP_PKEY_keygen_init(ctx) <= 0
208         || EVP_PKEY_keygen(ctx, &xkey) <= 0
209         || EVP_PKEY_get_octet_string_param(xkey, OSSL_PKEY_PARAM_ENCODED_PUBLIC_KEY,
210                                            cbuf, encap_clen, &encap_clen) <= 0)
211         goto end;
212     if (encap_clen != key->xinfo->pubkey_bytes) {
213         ERR_raise_data(ERR_LIB_PROV, ERR_R_INTERNAL_ERROR,
214                        "unexpected %s public key output size: %lu",
215                        key->xinfo->algorithm_name, (unsigned long) encap_clen);
216         goto end;
217     }
218     EVP_PKEY_CTX_free(ctx);
219 
220     /* Derive the ECDH shared secret */
221     encap_slen = key->xinfo->shsec_bytes;
222     sbuf = shsec + (1 - ml_kem_slot) * ML_KEM_SHARED_SECRET_BYTES;
223     ctx = EVP_PKEY_CTX_new_from_pkey(key->libctx, xkey, key->propq);
224     if (ctx == NULL
225         || EVP_PKEY_derive_init(ctx) <= 0
226         || EVP_PKEY_derive_set_peer(ctx, key->xkey) <= 0
227         || EVP_PKEY_derive(ctx, sbuf, &encap_slen) <= 0)
228         goto end;
229     if (encap_slen != key->xinfo->shsec_bytes) {
230         ERR_raise_data(ERR_LIB_PROV, ERR_R_INTERNAL_ERROR,
231                        "unexpected %s shared secret output size: %lu",
232                        key->xinfo->algorithm_name, (unsigned long) encap_slen);
233         goto end;
234     }
235 
236     ret = 1;
237  end:
238     EVP_PKEY_free(xkey);
239     EVP_PKEY_CTX_free(ctx);
240     return ret;
241 }
242 
mlx_kem_decapsulate(void * vctx,uint8_t * shsec,size_t * slen,const uint8_t * ctext,size_t clen)243 static int mlx_kem_decapsulate(void *vctx, uint8_t *shsec, size_t *slen,
244                                const uint8_t *ctext, size_t clen)
245 {
246     MLX_KEY *key = ((PROV_MLX_KEM_CTX *) vctx)->key;
247     EVP_PKEY_CTX *ctx = NULL;
248     EVP_PKEY *xkey = NULL;
249     const uint8_t *cbuf;
250     uint8_t *sbuf;
251     size_t decap_slen = ML_KEM_SHARED_SECRET_BYTES + key->xinfo->shsec_bytes;
252     size_t decap_clen = key->minfo->ctext_bytes + key->xinfo->pubkey_bytes;
253     int ml_kem_slot = key->xinfo->ml_kem_slot;
254     int ret = 0;
255 
256     if (!mlx_kem_have_prvkey(key)) {
257         ERR_raise(ERR_LIB_PROV, PROV_R_MISSING_KEY);
258         return 0;
259     }
260 
261     if (shsec == NULL) {
262         if (slen == NULL)
263             return 0;
264         *slen = decap_slen;
265         return 1;
266     }
267 
268     /* For now tolerate newly-deprecated NULL length pointers. */
269     if (slen == NULL) {
270         slen = &decap_slen;
271     } else if (*slen < decap_slen) {
272         ERR_raise_data(ERR_LIB_PROV, PROV_R_OUTPUT_BUFFER_TOO_SMALL,
273                        "shared-secret buffer too small");
274         return 0;
275     } else {
276         *slen = decap_slen;
277     }
278     if (clen != decap_clen) {
279         ERR_raise_data(ERR_LIB_PROV, PROV_R_WRONG_CIPHERTEXT_SIZE,
280                        "wrong decapsulation input ciphertext size: %lu",
281                        (unsigned long) clen);
282         return 0;
283     }
284 
285     /* ML-KEM decapsulation */
286     decap_clen = key->minfo->ctext_bytes;
287     decap_slen = ML_KEM_SHARED_SECRET_BYTES;
288     cbuf = ctext + ml_kem_slot * key->xinfo->pubkey_bytes;
289     sbuf = shsec + ml_kem_slot * key->xinfo->shsec_bytes;
290     ctx = EVP_PKEY_CTX_new_from_pkey(key->libctx, key->mkey, key->propq);
291     if (ctx == NULL
292         || EVP_PKEY_decapsulate_init(ctx, NULL) <= 0
293         || EVP_PKEY_decapsulate(ctx, sbuf, &decap_slen, cbuf, decap_clen) <= 0)
294         goto end;
295     if (decap_slen != ML_KEM_SHARED_SECRET_BYTES) {
296         ERR_raise_data(ERR_LIB_PROV, ERR_R_INTERNAL_ERROR,
297                        "unexpected %s shared secret output size: %lu",
298                        key->minfo->algorithm_name, (unsigned long) decap_slen);
299         goto end;
300     }
301     EVP_PKEY_CTX_free(ctx);
302 
303     /* ECDH decapsulation */
304     decap_clen = key->xinfo->pubkey_bytes;
305     decap_slen = key->xinfo->shsec_bytes;
306     cbuf = ctext + (1 - ml_kem_slot) * key->minfo->ctext_bytes;
307     sbuf = shsec + (1 - ml_kem_slot) * ML_KEM_SHARED_SECRET_BYTES;
308     ctx = EVP_PKEY_CTX_new_from_pkey(key->libctx, key->xkey, key->propq);
309     if (ctx == NULL
310         || (xkey = EVP_PKEY_new()) == NULL
311         || EVP_PKEY_copy_parameters(xkey, key->xkey) <= 0
312         || EVP_PKEY_set1_encoded_public_key(xkey, cbuf, decap_clen) <= 0
313         || EVP_PKEY_derive_init(ctx) <= 0
314         || EVP_PKEY_derive_set_peer(ctx, xkey) <= 0
315         || EVP_PKEY_derive(ctx, sbuf, &decap_slen) <= 0)
316         goto end;
317     if (decap_slen != key->xinfo->shsec_bytes) {
318         ERR_raise_data(ERR_LIB_PROV, ERR_R_INTERNAL_ERROR,
319                        "unexpected %s shared secret output size: %lu",
320                        key->xinfo->algorithm_name, (unsigned long) decap_slen);
321         goto end;
322     }
323 
324     ret = 1;
325  end:
326     EVP_PKEY_CTX_free(ctx);
327     EVP_PKEY_free(xkey);
328     return ret;
329 }
330 
331 const OSSL_DISPATCH ossl_mlx_kem_asym_kem_functions[] = {
332     { OSSL_FUNC_KEM_NEWCTX, (OSSL_FUNC) mlx_kem_newctx },
333     { OSSL_FUNC_KEM_ENCAPSULATE_INIT, (OSSL_FUNC) mlx_kem_encapsulate_init },
334     { OSSL_FUNC_KEM_ENCAPSULATE, (OSSL_FUNC) mlx_kem_encapsulate },
335     { OSSL_FUNC_KEM_DECAPSULATE_INIT, (OSSL_FUNC) mlx_kem_decapsulate_init },
336     { OSSL_FUNC_KEM_DECAPSULATE, (OSSL_FUNC) mlx_kem_decapsulate },
337     { OSSL_FUNC_KEM_FREECTX, (OSSL_FUNC) mlx_kem_freectx },
338     { OSSL_FUNC_KEM_SET_CTX_PARAMS, (OSSL_FUNC) mlx_kem_set_ctx_params },
339     { OSSL_FUNC_KEM_SETTABLE_CTX_PARAMS, (OSSL_FUNC) mlx_kem_settable_ctx_params },
340     OSSL_DISPATCH_END
341 };
342