xref: /freebsd/crypto/openssl/test/helpers/ssltestlib.c (revision 35c0a8c449fd2b7f75029ebed5e10852240f0865)
1 /*
2  * Copyright 2016-2024 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9 
10 /*
11  * We need access to the deprecated low level ENGINE APIs for legacy purposes
12  * when the deprecated calls are not hidden
13  */
14 #ifndef OPENSSL_NO_DEPRECATED_3_0
15 # define OPENSSL_SUPPRESS_DEPRECATED
16 #endif
17 
18 #include <string.h>
19 
20 #include <openssl/engine.h>
21 #include "internal/nelem.h"
22 #include "ssltestlib.h"
23 #include "../testutil.h"
24 #include "e_os.h" /* for ossl_sleep() etc. */
25 
26 #ifdef OPENSSL_SYS_UNIX
27 # include <unistd.h>
28 # ifndef OPENSSL_NO_KTLS
29 #  include <netinet/in.h>
30 #  include <netinet/in.h>
31 #  include <arpa/inet.h>
32 #  include <sys/socket.h>
33 #  include <unistd.h>
34 #  include <fcntl.h>
35 # endif
36 #endif
37 
/* Forward declarations of the TLS dump filter BIO callbacks defined below. */
static int tls_dump_new(BIO *bi);
static int tls_dump_free(BIO *a);
static int tls_dump_read(BIO *b, char *out, int outl);
static int tls_dump_write(BIO *b, const char *in, int inl);
static long tls_dump_ctrl(BIO *b, int cmd, long num, void *ptr);
static int tls_dump_gets(BIO *bp, char *buf, int size);
static int tls_dump_puts(BIO *bp, const char *str);

/*
 * BIO type ids for the custom BIOs in this file. Choose sufficiently large
 * values that are likely to be unused by any built-in BIO type.
 */
#define BIO_TYPE_TLS_DUMP_FILTER  (0x80 | BIO_TYPE_FILTER)
#define BIO_TYPE_MEMPACKET_TEST    0x81
#define BIO_TYPE_ALWAYS_RETRY      0x82

/*
 * Lazily created BIO_METHODs, one per custom BIO type. Each is built on first
 * use by its bio_*() accessor and released by the matching *_free() function.
 */
static BIO_METHOD *method_tls_dump = NULL;
static BIO_METHOD *meth_mem = NULL;
static BIO_METHOD *meth_always_retry = NULL;
/* Value returned by the always-retry BIO I/O callbacks (settable by tests) */
static int retry_err = -1;
55 
56 /* Note: Not thread safe! */
57 const BIO_METHOD *bio_f_tls_dump_filter(void)
58 {
59     if (method_tls_dump == NULL) {
60         method_tls_dump = BIO_meth_new(BIO_TYPE_TLS_DUMP_FILTER,
61                                         "TLS dump filter");
62         if (   method_tls_dump == NULL
63             || !BIO_meth_set_write(method_tls_dump, tls_dump_write)
64             || !BIO_meth_set_read(method_tls_dump, tls_dump_read)
65             || !BIO_meth_set_puts(method_tls_dump, tls_dump_puts)
66             || !BIO_meth_set_gets(method_tls_dump, tls_dump_gets)
67             || !BIO_meth_set_ctrl(method_tls_dump, tls_dump_ctrl)
68             || !BIO_meth_set_create(method_tls_dump, tls_dump_new)
69             || !BIO_meth_set_destroy(method_tls_dump, tls_dump_free))
70             return NULL;
71     }
72     return method_tls_dump;
73 }
74 
75 void bio_f_tls_dump_filter_free(void)
76 {
77     BIO_meth_free(method_tls_dump);
78 }
79 
80 static int tls_dump_new(BIO *bio)
81 {
82     BIO_set_init(bio, 1);
83     return 1;
84 }
85 
86 static int tls_dump_free(BIO *bio)
87 {
88     BIO_set_init(bio, 0);
89 
90     return 1;
91 }
92 
93 static void copy_flags(BIO *bio)
94 {
95     int flags;
96     BIO *next = BIO_next(bio);
97 
98     flags = BIO_test_flags(next, BIO_FLAGS_SHOULD_RETRY | BIO_FLAGS_RWS);
99     BIO_clear_flags(bio, BIO_FLAGS_SHOULD_RETRY | BIO_FLAGS_RWS);
100     BIO_set_flags(bio, flags);
101 }
102 
/*
 * Byte offsets of the fields of a DTLS record header (DTLS1_RT_HEADER_LENGTH
 * bytes): content type, two version bytes, 16-bit epoch, 48-bit sequence
 * number (bytes RECORD_SEQUENCE_START..RECORD_SEQUENCE_END) and 16-bit
 * payload length. Multi-byte fields are big-endian (see dump_data()).
 */
#define RECORD_CONTENT_TYPE     0
#define RECORD_VERSION_HI       1
#define RECORD_VERSION_LO       2
#define RECORD_EPOCH_HI         3
#define RECORD_EPOCH_LO         4
#define RECORD_SEQUENCE_START   5
#define RECORD_SEQUENCE_END     10
#define RECORD_LEN_HI           11
#define RECORD_LEN_LO           12

/*
 * Byte offsets of the fields of a DTLS handshake message header
 * (DTLS1_HM_HEADER_LENGTH bytes): message type, 24-bit message length,
 * 16-bit message sequence, 24-bit fragment offset and 24-bit fragment length.
 */
#define MSG_TYPE                0
#define MSG_LEN_HI              1
#define MSG_LEN_MID             2
#define MSG_LEN_LO              3
#define MSG_SEQ_HI              4
#define MSG_SEQ_LO              5
#define MSG_FRAG_OFF_HI         6
#define MSG_FRAG_OFF_MID        7
#define MSG_FRAG_OFF_LO         8
#define MSG_FRAG_LEN_HI         9
#define MSG_FRAG_LEN_MID        10
#define MSG_FRAG_LEN_LO         11
125 
126 
127 static void dump_data(const char *data, int len)
128 {
129     int rem, i, content, reclen, msglen, fragoff, fraglen, epoch;
130     unsigned char *rec;
131 
132     printf("---- START OF PACKET ----\n");
133 
134     rem = len;
135     rec = (unsigned char *)data;
136 
137     while (rem > 0) {
138         if (rem != len)
139             printf("*\n");
140         printf("*---- START OF RECORD ----\n");
141         if (rem < DTLS1_RT_HEADER_LENGTH) {
142             printf("*---- RECORD TRUNCATED ----\n");
143             break;
144         }
145         content = rec[RECORD_CONTENT_TYPE];
146         printf("** Record Content-type: %d\n", content);
147         printf("** Record Version: %02x%02x\n",
148                rec[RECORD_VERSION_HI], rec[RECORD_VERSION_LO]);
149         epoch = (rec[RECORD_EPOCH_HI] << 8) | rec[RECORD_EPOCH_LO];
150         printf("** Record Epoch: %d\n", epoch);
151         printf("** Record Sequence: ");
152         for (i = RECORD_SEQUENCE_START; i <= RECORD_SEQUENCE_END; i++)
153             printf("%02x", rec[i]);
154         reclen = (rec[RECORD_LEN_HI] << 8) | rec[RECORD_LEN_LO];
155         printf("\n** Record Length: %d\n", reclen);
156 
157         /* Now look at message */
158         rec += DTLS1_RT_HEADER_LENGTH;
159         rem -= DTLS1_RT_HEADER_LENGTH;
160         if (content == SSL3_RT_HANDSHAKE) {
161             printf("**---- START OF HANDSHAKE MESSAGE FRAGMENT ----\n");
162             if (epoch > 0) {
163                 printf("**---- HANDSHAKE MESSAGE FRAGMENT ENCRYPTED ----\n");
164             } else if (rem < DTLS1_HM_HEADER_LENGTH
165                     || reclen < DTLS1_HM_HEADER_LENGTH) {
166                 printf("**---- HANDSHAKE MESSAGE FRAGMENT TRUNCATED ----\n");
167             } else {
168                 printf("*** Message Type: %d\n", rec[MSG_TYPE]);
169                 msglen = (rec[MSG_LEN_HI] << 16) | (rec[MSG_LEN_MID] << 8)
170                          | rec[MSG_LEN_LO];
171                 printf("*** Message Length: %d\n", msglen);
172                 printf("*** Message sequence: %d\n",
173                        (rec[MSG_SEQ_HI] << 8) | rec[MSG_SEQ_LO]);
174                 fragoff = (rec[MSG_FRAG_OFF_HI] << 16)
175                           | (rec[MSG_FRAG_OFF_MID] << 8)
176                           | rec[MSG_FRAG_OFF_LO];
177                 printf("*** Message Fragment offset: %d\n", fragoff);
178                 fraglen = (rec[MSG_FRAG_LEN_HI] << 16)
179                           | (rec[MSG_FRAG_LEN_MID] << 8)
180                           | rec[MSG_FRAG_LEN_LO];
181                 printf("*** Message Fragment len: %d\n", fraglen);
182                 if (fragoff + fraglen > msglen)
183                     printf("***---- HANDSHAKE MESSAGE FRAGMENT INVALID ----\n");
184                 else if (reclen < fraglen)
185                     printf("**---- HANDSHAKE MESSAGE FRAGMENT TRUNCATED ----\n");
186                 else
187                     printf("**---- END OF HANDSHAKE MESSAGE FRAGMENT ----\n");
188             }
189         }
190         if (rem < reclen) {
191             printf("*---- RECORD TRUNCATED ----\n");
192             rem = 0;
193         } else {
194             rec += reclen;
195             rem -= reclen;
196             printf("*---- END OF RECORD ----\n");
197         }
198     }
199     printf("---- END OF PACKET ----\n\n");
200     fflush(stdout);
201 }
202 
203 static int tls_dump_read(BIO *bio, char *out, int outl)
204 {
205     int ret;
206     BIO *next = BIO_next(bio);
207 
208     ret = BIO_read(next, out, outl);
209     copy_flags(bio);
210 
211     if (ret > 0) {
212         dump_data(out, ret);
213     }
214 
215     return ret;
216 }
217 
218 static int tls_dump_write(BIO *bio, const char *in, int inl)
219 {
220     int ret;
221     BIO *next = BIO_next(bio);
222 
223     ret = BIO_write(next, in, inl);
224     copy_flags(bio);
225 
226     return ret;
227 }
228 
229 static long tls_dump_ctrl(BIO *bio, int cmd, long num, void *ptr)
230 {
231     long ret;
232     BIO *next = BIO_next(bio);
233 
234     if (next == NULL)
235         return 0;
236 
237     switch (cmd) {
238     case BIO_CTRL_DUP:
239         ret = 0L;
240         break;
241     default:
242         ret = BIO_ctrl(next, cmd, num, ptr);
243         break;
244     }
245     return ret;
246 }
247 
248 static int tls_dump_gets(BIO *bio, char *buf, int size)
249 {
250     /* We don't support this - not needed anyway */
251     return -1;
252 }
253 
254 static int tls_dump_puts(BIO *bio, const char *str)
255 {
256     return tls_dump_write(bio, str, strlen(str));
257 }
258 
259 
/* One queued datagram held by the mempacket test BIO. */
struct mempacket_st {
    unsigned char *data;    /* Packet bytes; owned, freed by mempacket_free() */
    int len;                /* Number of valid bytes in |data| */
    unsigned int num;       /* Packet number; reads release packets in this order */
    unsigned int type;      /* e.g. STANDARD_PACKET, INJECT_PACKET, INJECT_PACKET_IGNORE_REC_SEQ */
};
266 
267 static void mempacket_free(MEMPACKET *pkt)
268 {
269     if (pkt->data != NULL)
270         OPENSSL_free(pkt->data);
271     OPENSSL_free(pkt);
272 }
273 
/* Per-BIO state of the mempacket test BIO (attached via BIO_set_data()). */
typedef struct mempacket_test_ctx_st {
    /* Queued packets, kept in packet-number order by mempacket_test_inject() */
    STACK_OF(MEMPACKET) *pkts;
    /* Epoch of the most recently read record; a change restarts |currrec| */
    unsigned int epoch;
    /* Sequence number written into the next record read in this epoch */
    unsigned int currrec;
    /* Packet number that the next read expects at the head of |pkts| */
    unsigned int currpkt;
    /* Packet number to assign to the next normally-written packet */
    unsigned int lastpkt;
    /* Set once a packet has been injected with an explicit packet number */
    unsigned int injected;
    /* Set once normal writes begin; explicit injection is refused after that */
    unsigned int noinject;
    /* Epoch and record index (-1 = none) of a record to drop on read */
    unsigned int dropepoch;
    int droprec;
    /* When > 0, writes containing more than one record are duplicated */
    int duprec;
} MEMPACKET_TEST_CTX;

/* Forward declarations of the mempacket test BIO callbacks defined below. */
static int mempacket_test_new(BIO *bi);
static int mempacket_test_free(BIO *a);
static int mempacket_test_read(BIO *b, char *out, int outl);
static int mempacket_test_write(BIO *b, const char *in, int inl);
static long mempacket_test_ctrl(BIO *b, int cmd, long num, void *ptr);
static int mempacket_test_gets(BIO *bp, char *buf, int size);
static int mempacket_test_puts(BIO *bp, const char *str);
294 
295 const BIO_METHOD *bio_s_mempacket_test(void)
296 {
297     if (meth_mem == NULL) {
298         if (!TEST_ptr(meth_mem = BIO_meth_new(BIO_TYPE_MEMPACKET_TEST,
299                                               "Mem Packet Test"))
300             || !TEST_true(BIO_meth_set_write(meth_mem, mempacket_test_write))
301             || !TEST_true(BIO_meth_set_read(meth_mem, mempacket_test_read))
302             || !TEST_true(BIO_meth_set_puts(meth_mem, mempacket_test_puts))
303             || !TEST_true(BIO_meth_set_gets(meth_mem, mempacket_test_gets))
304             || !TEST_true(BIO_meth_set_ctrl(meth_mem, mempacket_test_ctrl))
305             || !TEST_true(BIO_meth_set_create(meth_mem, mempacket_test_new))
306             || !TEST_true(BIO_meth_set_destroy(meth_mem, mempacket_test_free)))
307             return NULL;
308     }
309     return meth_mem;
310 }
311 
312 void bio_s_mempacket_test_free(void)
313 {
314     BIO_meth_free(meth_mem);
315 }
316 
317 static int mempacket_test_new(BIO *bio)
318 {
319     MEMPACKET_TEST_CTX *ctx;
320 
321     if (!TEST_ptr(ctx = OPENSSL_zalloc(sizeof(*ctx))))
322         return 0;
323     if (!TEST_ptr(ctx->pkts = sk_MEMPACKET_new_null())) {
324         OPENSSL_free(ctx);
325         return 0;
326     }
327     ctx->dropepoch = 0;
328     ctx->droprec = -1;
329     BIO_set_init(bio, 1);
330     BIO_set_data(bio, ctx);
331     return 1;
332 }
333 
334 static int mempacket_test_free(BIO *bio)
335 {
336     MEMPACKET_TEST_CTX *ctx = BIO_get_data(bio);
337 
338     sk_MEMPACKET_pop_free(ctx->pkts, mempacket_free);
339     OPENSSL_free(ctx);
340     BIO_set_data(bio, NULL);
341     BIO_set_init(bio, 0);
342     return 1;
343 }
344 
/*
 * Record Header values: byte offsets within a DTLS record header of the
 * epoch, the last (least significant) byte of the sequence number, and the
 * length field. RECORD_LEN_HI/RECORD_LEN_LO are identical redefinitions of
 * the offsets defined earlier in this file.
 */
#define EPOCH_HI        3
#define EPOCH_LO        4
#define RECORD_SEQUENCE 10
#define RECORD_LEN_HI   11
#define RECORD_LEN_LO   12

/* Packet type used for ordinary writes via mempacket_test_write() */
#define STANDARD_PACKET                 0
353 
354 static int mempacket_test_read(BIO *bio, char *out, int outl)
355 {
356     MEMPACKET_TEST_CTX *ctx = BIO_get_data(bio);
357     MEMPACKET *thispkt;
358     unsigned char *rec;
359     int rem;
360     unsigned int seq, offset, len, epoch;
361 
362     BIO_clear_retry_flags(bio);
363     if ((thispkt = sk_MEMPACKET_value(ctx->pkts, 0)) == NULL
364         || thispkt->num != ctx->currpkt) {
365         /* Probably run out of data */
366         BIO_set_retry_read(bio);
367         return -1;
368     }
369     (void)sk_MEMPACKET_shift(ctx->pkts);
370     ctx->currpkt++;
371 
372     if (outl > thispkt->len)
373         outl = thispkt->len;
374 
375     if (thispkt->type != INJECT_PACKET_IGNORE_REC_SEQ
376             && (ctx->injected || ctx->droprec >= 0)) {
377         /*
378          * Overwrite the record sequence number. We strictly number them in
379          * the order received. Since we are actually a reliable transport
380          * we know that there won't be any re-ordering. We overwrite to deal
381          * with any packets that have been injected
382          */
383         for (rem = thispkt->len, rec = thispkt->data; rem > 0; rem -= len) {
384             if (rem < DTLS1_RT_HEADER_LENGTH)
385                 return -1;
386             epoch = (rec[EPOCH_HI] << 8) | rec[EPOCH_LO];
387             if (epoch != ctx->epoch) {
388                 ctx->epoch = epoch;
389                 ctx->currrec = 0;
390             }
391             seq = ctx->currrec;
392             offset = 0;
393             do {
394                 rec[RECORD_SEQUENCE - offset] = seq & 0xFF;
395                 seq >>= 8;
396                 offset++;
397             } while (seq > 0);
398 
399             len = ((rec[RECORD_LEN_HI] << 8) | rec[RECORD_LEN_LO])
400                   + DTLS1_RT_HEADER_LENGTH;
401             if (rem < (int)len)
402                 return -1;
403             if (ctx->droprec == (int)ctx->currrec && ctx->dropepoch == epoch) {
404                 if (rem > (int)len)
405                     memmove(rec, rec + len, rem - len);
406                 outl -= len;
407                 ctx->droprec = -1;
408                 if (outl == 0)
409                     BIO_set_retry_read(bio);
410             } else {
411                 rec += len;
412             }
413 
414             ctx->currrec++;
415         }
416     }
417 
418     memcpy(out, thispkt->data, outl);
419     mempacket_free(thispkt);
420     return outl;
421 }
422 
423 /*
424  * Look for records from different epochs in the last datagram and swap them
425  * around
426  */
427 int mempacket_swap_epoch(BIO *bio)
428 {
429     MEMPACKET_TEST_CTX *ctx = BIO_get_data(bio);
430     MEMPACKET *thispkt;
431     int rem, len, prevlen = 0, pktnum;
432     unsigned char *rec, *prevrec = NULL, *tmp;
433     unsigned int epoch;
434     int numpkts = sk_MEMPACKET_num(ctx->pkts);
435 
436     if (numpkts <= 0)
437         return 0;
438 
439     /*
440      * If there are multiple packets we only look in the last one. This should
441      * always be the one where any epoch change occurs.
442      */
443     thispkt = sk_MEMPACKET_value(ctx->pkts, numpkts - 1);
444     if (thispkt == NULL)
445         return 0;
446 
447     for (rem = thispkt->len, rec = thispkt->data; rem > 0; rem -= len, rec += len) {
448         if (rem < DTLS1_RT_HEADER_LENGTH)
449             return 0;
450         epoch = (rec[EPOCH_HI] << 8) | rec[EPOCH_LO];
451         len = ((rec[RECORD_LEN_HI] << 8) | rec[RECORD_LEN_LO])
452                 + DTLS1_RT_HEADER_LENGTH;
453         if (rem < len)
454             return 0;
455 
456         /* Assumes the epoch change does not happen on the first record */
457         if (epoch != ctx->epoch) {
458             if (prevrec == NULL)
459                 return 0;
460 
461             /*
462              * We found 2 records with different epochs. Take a copy of the
463              * earlier record
464              */
465             tmp = OPENSSL_malloc(prevlen);
466             if (tmp == NULL)
467                 return 0;
468 
469             memcpy(tmp, prevrec, prevlen);
470             /*
471              * Move everything from this record onwards, including any trailing
472              * records, and overwrite the earlier record
473              */
474             memmove(prevrec, rec, rem);
475             thispkt->len -= prevlen;
476             pktnum = thispkt->num;
477 
478             /*
479              * Create a new packet for the earlier record that we took out and
480              * add it to the end of the packet list.
481              */
482             thispkt = OPENSSL_malloc(sizeof(*thispkt));
483             if (thispkt == NULL) {
484                 OPENSSL_free(tmp);
485                 return 0;
486             }
487             thispkt->type = INJECT_PACKET;
488             thispkt->data = tmp;
489             thispkt->len = prevlen;
490             thispkt->num = pktnum + 1;
491             if (sk_MEMPACKET_insert(ctx->pkts, thispkt, numpkts) <= 0) {
492                 OPENSSL_free(tmp);
493                 OPENSSL_free(thispkt);
494                 return 0;
495             }
496 
497             return 1;
498         }
499         prevrec = rec;
500         prevlen = len;
501     }
502 
503     return 0;
504 }
505 
506 /* Move packet from position s to position d in the list (d < s) */
507 int mempacket_move_packet(BIO *bio, int d, int s)
508 {
509     MEMPACKET_TEST_CTX *ctx = BIO_get_data(bio);
510     MEMPACKET *thispkt;
511     int numpkts = sk_MEMPACKET_num(ctx->pkts);
512     int i;
513 
514     if (d >= s)
515         return 0;
516 
517     /* We need at least s + 1 packets to be able to swap them */
518     if (numpkts <= s)
519         return 0;
520 
521     /* Get the packet at position s */
522     thispkt = sk_MEMPACKET_value(ctx->pkts, s);
523     if (thispkt == NULL)
524         return 0;
525 
526     /* Remove and re-add it */
527     if (sk_MEMPACKET_delete(ctx->pkts, s) != thispkt)
528         return 0;
529 
530     thispkt->num -= (s - d);
531     if (sk_MEMPACKET_insert(ctx->pkts, thispkt, d) <= 0)
532         return 0;
533 
534     /* Increment the packet numbers for moved packets */
535     for (i = d + 1; i <= s; i++) {
536         thispkt = sk_MEMPACKET_value(ctx->pkts, i);
537         thispkt->num++;
538     }
539     return 1;
540 }
541 
542 int mempacket_test_inject(BIO *bio, const char *in, int inl, int pktnum,
543                           int type)
544 {
545     MEMPACKET_TEST_CTX *ctx = BIO_get_data(bio);
546     MEMPACKET *thispkt = NULL, *looppkt, *nextpkt, *allpkts[3];
547     int i, duprec;
548     const unsigned char *inu = (const unsigned char *)in;
549     size_t len = ((inu[RECORD_LEN_HI] << 8) | inu[RECORD_LEN_LO])
550                  + DTLS1_RT_HEADER_LENGTH;
551 
552     if (ctx == NULL)
553         return -1;
554 
555     if ((size_t)inl < len)
556         return -1;
557 
558     if ((size_t)inl == len)
559         duprec = 0;
560     else
561         duprec = ctx->duprec > 0;
562 
563     /* We don't support arbitrary injection when duplicating records */
564     if (duprec && pktnum != -1)
565         return -1;
566 
567     /* We only allow injection before we've started writing any data */
568     if (pktnum >= 0) {
569         if (ctx->noinject)
570             return -1;
571         ctx->injected  = 1;
572     } else {
573         ctx->noinject = 1;
574     }
575 
576     for (i = 0; i < (duprec ? 3 : 1); i++) {
577         if (!TEST_ptr(allpkts[i] = OPENSSL_malloc(sizeof(*thispkt))))
578             goto err;
579         thispkt = allpkts[i];
580 
581         if (!TEST_ptr(thispkt->data = OPENSSL_malloc(inl)))
582             goto err;
583         /*
584          * If we are duplicating the packet, we duplicate it three times. The
585          * first two times we drop the first record if there are more than one.
586          * In this way we know that libssl will not be able to make progress
587          * until it receives the last packet, and hence will be forced to
588          * buffer these records.
589          */
590         if (duprec && i != 2) {
591             memcpy(thispkt->data, in + len, inl - len);
592             thispkt->len = inl - len;
593         } else {
594             memcpy(thispkt->data, in, inl);
595             thispkt->len = inl;
596         }
597         thispkt->num = (pktnum >= 0) ? (unsigned int)pktnum : ctx->lastpkt + i;
598         thispkt->type = type;
599     }
600 
601     for (i = 0; i < sk_MEMPACKET_num(ctx->pkts); i++) {
602         if (!TEST_ptr(looppkt = sk_MEMPACKET_value(ctx->pkts, i)))
603             goto err;
604         /* Check if we found the right place to insert this packet */
605         if (looppkt->num > thispkt->num) {
606             if (sk_MEMPACKET_insert(ctx->pkts, thispkt, i) == 0)
607                 goto err;
608             /* If we're doing up front injection then we're done */
609             if (pktnum >= 0)
610                 return inl;
611             /*
612              * We need to do some accounting on lastpkt. We increment it first,
613              * but it might now equal the value of injected packets, so we need
614              * to skip over those
615              */
616             ctx->lastpkt++;
617             do {
618                 i++;
619                 nextpkt = sk_MEMPACKET_value(ctx->pkts, i);
620                 if (nextpkt != NULL && nextpkt->num == ctx->lastpkt)
621                     ctx->lastpkt++;
622                 else
623                     return inl;
624             } while(1);
625         } else if (looppkt->num == thispkt->num) {
626             if (!ctx->noinject) {
627                 /* We injected two packets with the same packet number! */
628                 goto err;
629             }
630             ctx->lastpkt++;
631             thispkt->num++;
632         }
633     }
634     /*
635      * We didn't find any packets with a packet number equal to or greater than
636      * this one, so we just add it onto the end
637      */
638     for (i = 0; i < (duprec ? 3 : 1); i++) {
639         thispkt = allpkts[i];
640         if (!sk_MEMPACKET_push(ctx->pkts, thispkt))
641             goto err;
642 
643         if (pktnum < 0)
644             ctx->lastpkt++;
645     }
646 
647     return inl;
648 
649  err:
650     for (i = 0; i < (ctx->duprec > 0 ? 3 : 1); i++)
651         mempacket_free(allpkts[i]);
652     return -1;
653 }
654 
655 static int mempacket_test_write(BIO *bio, const char *in, int inl)
656 {
657     return mempacket_test_inject(bio, in, inl, -1, STANDARD_PACKET);
658 }
659 
660 static long mempacket_test_ctrl(BIO *bio, int cmd, long num, void *ptr)
661 {
662     long ret = 1;
663     MEMPACKET_TEST_CTX *ctx = BIO_get_data(bio);
664     MEMPACKET *thispkt;
665 
666     switch (cmd) {
667     case BIO_CTRL_EOF:
668         ret = (long)(sk_MEMPACKET_num(ctx->pkts) == 0);
669         break;
670     case BIO_CTRL_GET_CLOSE:
671         ret = BIO_get_shutdown(bio);
672         break;
673     case BIO_CTRL_SET_CLOSE:
674         BIO_set_shutdown(bio, (int)num);
675         break;
676     case BIO_CTRL_WPENDING:
677         ret = 0L;
678         break;
679     case BIO_CTRL_PENDING:
680         thispkt = sk_MEMPACKET_value(ctx->pkts, 0);
681         if (thispkt == NULL)
682             ret = 0;
683         else
684             ret = thispkt->len;
685         break;
686     case BIO_CTRL_FLUSH:
687         ret = 1;
688         break;
689     case MEMPACKET_CTRL_SET_DROP_EPOCH:
690         ctx->dropepoch = (unsigned int)num;
691         break;
692     case MEMPACKET_CTRL_SET_DROP_REC:
693         ctx->droprec = (int)num;
694         break;
695     case MEMPACKET_CTRL_GET_DROP_REC:
696         ret = ctx->droprec;
697         break;
698     case MEMPACKET_CTRL_SET_DUPLICATE_REC:
699         ctx->duprec = (int)num;
700         break;
701     case BIO_CTRL_RESET:
702     case BIO_CTRL_DUP:
703     case BIO_CTRL_PUSH:
704     case BIO_CTRL_POP:
705     default:
706         ret = 0;
707         break;
708     }
709     return ret;
710 }
711 
712 static int mempacket_test_gets(BIO *bio, char *buf, int size)
713 {
714     /* We don't support this - not needed anyway */
715     return -1;
716 }
717 
718 static int mempacket_test_puts(BIO *bio, const char *str)
719 {
720     return mempacket_test_write(bio, str, strlen(str));
721 }
722 
/* Forward declarations of the always-retry BIO callbacks defined below. */
static int always_retry_new(BIO *bi);
static int always_retry_free(BIO *a);
static int always_retry_read(BIO *b, char *out, int outl);
static int always_retry_write(BIO *b, const char *in, int inl);
static long always_retry_ctrl(BIO *b, int cmd, long num, void *ptr);
static int always_retry_gets(BIO *bp, char *buf, int size);
static int always_retry_puts(BIO *bp, const char *str);
730 
731 const BIO_METHOD *bio_s_always_retry(void)
732 {
733     if (meth_always_retry == NULL) {
734         if (!TEST_ptr(meth_always_retry = BIO_meth_new(BIO_TYPE_ALWAYS_RETRY,
735                                                        "Always Retry"))
736             || !TEST_true(BIO_meth_set_write(meth_always_retry,
737                                              always_retry_write))
738             || !TEST_true(BIO_meth_set_read(meth_always_retry,
739                                             always_retry_read))
740             || !TEST_true(BIO_meth_set_puts(meth_always_retry,
741                                             always_retry_puts))
742             || !TEST_true(BIO_meth_set_gets(meth_always_retry,
743                                             always_retry_gets))
744             || !TEST_true(BIO_meth_set_ctrl(meth_always_retry,
745                                             always_retry_ctrl))
746             || !TEST_true(BIO_meth_set_create(meth_always_retry,
747                                               always_retry_new))
748             || !TEST_true(BIO_meth_set_destroy(meth_always_retry,
749                                                always_retry_free)))
750             return NULL;
751     }
752     return meth_always_retry;
753 }
754 
755 void bio_s_always_retry_free(void)
756 {
757     BIO_meth_free(meth_always_retry);
758 }
759 
760 static int always_retry_new(BIO *bio)
761 {
762     BIO_set_init(bio, 1);
763     return 1;
764 }
765 
766 static int always_retry_free(BIO *bio)
767 {
768     BIO_set_data(bio, NULL);
769     BIO_set_init(bio, 0);
770     return 1;
771 }
772 
773 void set_always_retry_err_val(int err)
774 {
775     retry_err = err;
776 }
777 
778 static int always_retry_read(BIO *bio, char *out, int outl)
779 {
780     BIO_set_retry_read(bio);
781     return retry_err;
782 }
783 
784 static int always_retry_write(BIO *bio, const char *in, int inl)
785 {
786     BIO_set_retry_write(bio);
787     return retry_err;
788 }
789 
790 static long always_retry_ctrl(BIO *bio, int cmd, long num, void *ptr)
791 {
792     long ret = 1;
793 
794     switch (cmd) {
795     case BIO_CTRL_FLUSH:
796         BIO_set_retry_write(bio);
797         /* fall through */
798     case BIO_CTRL_EOF:
799     case BIO_CTRL_RESET:
800     case BIO_CTRL_DUP:
801     case BIO_CTRL_PUSH:
802     case BIO_CTRL_POP:
803     default:
804         ret = 0;
805         break;
806     }
807     return ret;
808 }
809 
810 static int always_retry_gets(BIO *bio, char *buf, int size)
811 {
812     BIO_set_retry_read(bio);
813     return retry_err;
814 }
815 
816 static int always_retry_puts(BIO *bio, const char *str)
817 {
818     BIO_set_retry_write(bio);
819     return retry_err;
820 }
821 
/*
 * Create (or reuse) a server and/or client SSL_CTX for a test.
 *
 * For each of |sctx| / |cctx| that is non-NULL:
 *  - if it already points at a context, that context is reused;
 *  - otherwise a new one is created from |libctx| and |sm| / |cm|. A newly
 *    created server context also gets SSL_OP_ALLOW_CLIENT_RENEGOTIATION.
 *
 * A |min_proto_version| / |max_proto_version| greater than 0 is applied to
 * every context in use. If |certfile| and |privkeyfile| are both non-NULL,
 * they are loaded (PEM) into the server context and checked to match.
 *
 * Returns 1 on success and stores the contexts back through |sctx| / |cctx|.
 * Returns 0 on failure: only contexts created by this call are freed, and
 * the caller's pointers are left unchanged.
 */
int create_ssl_ctx_pair(OSSL_LIB_CTX *libctx, const SSL_METHOD *sm,
                        const SSL_METHOD *cm, int min_proto_version,
                        int max_proto_version, SSL_CTX **sctx, SSL_CTX **cctx,
                        char *certfile, char *privkeyfile)
{
    SSL_CTX *serverctx = NULL;
    SSL_CTX *clientctx = NULL;

    if (sctx != NULL) {
        if (*sctx != NULL)
            serverctx = *sctx;
        else if (!TEST_ptr(serverctx = SSL_CTX_new_ex(libctx, NULL, sm))
            || !TEST_true(SSL_CTX_set_options(serverctx,
                                              SSL_OP_ALLOW_CLIENT_RENEGOTIATION)))
            goto err;
    }

    if (cctx != NULL) {
        if (*cctx != NULL)
            clientctx = *cctx;
        else if (!TEST_ptr(clientctx = SSL_CTX_new_ex(libctx, NULL, cm)))
            goto err;
    }

#if !defined(OPENSSL_NO_TLS1_3) \
    && defined(OPENSSL_NO_EC) \
    && defined(OPENSSL_NO_DH)
    /*
     * There are no usable built-in TLSv1.3 groups if ec and dh are both
     * disabled, so cap the generic TLS methods at TLSv1.2 when no explicit
     * maximum was requested.
     */
    if (max_proto_version == 0
            && (sm == TLS_server_method() || cm == TLS_client_method()))
        max_proto_version = TLS1_2_VERSION;
#endif

    /* A version bound of 0 means "leave the library default" */
    if (serverctx != NULL
            && ((min_proto_version > 0
                 && !TEST_true(SSL_CTX_set_min_proto_version(serverctx,
                                                            min_proto_version)))
                || (max_proto_version > 0
                    && !TEST_true(SSL_CTX_set_max_proto_version(serverctx,
                                                                max_proto_version)))))
        goto err;
    if (clientctx != NULL
        && ((min_proto_version > 0
             && !TEST_true(SSL_CTX_set_min_proto_version(clientctx,
                                                         min_proto_version)))
            || (max_proto_version > 0
                && !TEST_true(SSL_CTX_set_max_proto_version(clientctx,
                                                            max_proto_version)))))
        goto err;

    if (serverctx != NULL && certfile != NULL && privkeyfile != NULL) {
        if (!TEST_int_eq(SSL_CTX_use_certificate_file(serverctx, certfile,
                                                      SSL_FILETYPE_PEM), 1)
                || !TEST_int_eq(SSL_CTX_use_PrivateKey_file(serverctx,
                                                            privkeyfile,
                                                            SSL_FILETYPE_PEM), 1)
                || !TEST_int_eq(SSL_CTX_check_private_key(serverctx), 1))
            goto err;
    }

    if (sctx != NULL)
        *sctx = serverctx;
    if (cctx != NULL)
        *cctx = clientctx;
    return 1;

 err:
    /* Only free contexts this call created; caller-supplied ones stay theirs */
    if (sctx != NULL && *sctx == NULL)
        SSL_CTX_free(serverctx);
    if (cctx != NULL && *cctx == NULL)
        SSL_CTX_free(clientctx);
    return 0;
}
898 
/*
 * Iteration cap for polling loops. Not referenced in this part of the file;
 * presumably it bounds the handshake/I/O retry loops further down (TODO
 * confirm).
 */
#define MAXLOOPS    1000000
900 
901 #if !defined(OPENSSL_NO_KTLS) && !defined(OPENSSL_NO_SOCK)
/*
 * Put |fd| into non-blocking mode while preserving its other file status
 * flags. Returns -1 if either fcntl() call fails, otherwise the result of
 * the F_SETFL call.
 */
static int set_nb(int fd)
{
    int status = fcntl(fd, F_GETFL, 0);

    return status == -1 ? -1 : fcntl(fd, F_SETFL, status | O_NONBLOCK);
}
912 
/*
 * Create a connected pair of TCP sockets over IPv4 loopback.
 *
 * A temporary listener is bound to 127.0.0.1 on a kernel-chosen port. A
 * client socket connects to it and the server side is accepted. Both
 * returned sockets are then switched to non-blocking mode. The listener is
 * always closed before returning.
 *
 * On success stores the client fd in |*cfdp| and the server fd in |*sfdp|,
 * then returns 1. On failure returns 0 and closes every socket it opened.
 */
int create_test_sockets(int *cfdp, int *sfdp)
{
    struct sockaddr_in addr;
    socklen_t addrlen = sizeof(addr);
    const char *host = "127.0.0.1";
    int listenfd, clientfd = -1, serverfd = -1;
    int connected = 0, ok = 0;

    memset(&addr, 0, sizeof(addr));
    addr.sin_family = AF_INET;
    addr.sin_addr.s_addr = inet_addr(host);

    listenfd = socket(AF_INET, SOCK_STREAM, 0);
    if (listenfd < 0)
        return 0;

    /* Port 0 in |addr| lets bind() pick a port; getsockname() reads it back */
    if (bind(listenfd, (struct sockaddr *)&addr, sizeof(addr)) < 0
        || getsockname(listenfd, (struct sockaddr *)&addr, &addrlen) < 0
        || listen(listenfd, 1) < 0)
        goto done;

    clientfd = socket(AF_INET, SOCK_STREAM, 0);
    if (clientfd < 0 || set_nb(listenfd) == -1)
        goto done;

    /*
     * The listener is non-blocking, so accept() fails with EAGAIN until the
     * blocking connect() below has queued a connection; then loop once more.
     */
    while (serverfd == -1 || !connected) {
        serverfd = accept(listenfd, NULL, 0);
        if (serverfd == -1 && errno != EAGAIN)
            goto done;

        if (!connected) {
            if (connect(clientfd, (struct sockaddr *)&addr, sizeof(addr)) < 0)
                goto done;
            connected = 1;
        }
    }

    if (set_nb(clientfd) == -1 || set_nb(serverfd) == -1)
        goto done;

    ok = 1;
    *cfdp = clientfd;
    *sfdp = serverfd;

 done:
    if (!ok) {
        if (clientfd != -1)
            close(clientfd);
        if (serverfd != -1)
            close(serverfd);
    }
    close(listenfd);
    return ok;
}
973 
974 int create_ssl_objects2(SSL_CTX *serverctx, SSL_CTX *clientctx, SSL **sssl,
975                           SSL **cssl, int sfd, int cfd)
976 {
977     SSL *serverssl = NULL, *clientssl = NULL;
978     BIO *s_to_c_bio = NULL, *c_to_s_bio = NULL;
979 
980     if (*sssl != NULL)
981         serverssl = *sssl;
982     else if (!TEST_ptr(serverssl = SSL_new(serverctx)))
983         goto error;
984     if (*cssl != NULL)
985         clientssl = *cssl;
986     else if (!TEST_ptr(clientssl = SSL_new(clientctx)))
987         goto error;
988 
989     if (!TEST_ptr(s_to_c_bio = BIO_new_socket(sfd, BIO_NOCLOSE))
990             || !TEST_ptr(c_to_s_bio = BIO_new_socket(cfd, BIO_NOCLOSE)))
991         goto error;
992 
993     SSL_set_bio(clientssl, c_to_s_bio, c_to_s_bio);
994     SSL_set_bio(serverssl, s_to_c_bio, s_to_c_bio);
995     *sssl = serverssl;
996     *cssl = clientssl;
997     return 1;
998 
999  error:
1000     SSL_free(serverssl);
1001     SSL_free(clientssl);
1002     BIO_free(s_to_c_bio);
1003     BIO_free(c_to_s_bio);
1004     return 0;
1005 }
1006 #endif
1007 
1008 /*
1009  * NOTE: Transfers control of the BIOs - this function will free them on error
1010  */
1011 int create_ssl_objects(SSL_CTX *serverctx, SSL_CTX *clientctx, SSL **sssl,
1012                           SSL **cssl, BIO *s_to_c_fbio, BIO *c_to_s_fbio)
1013 {
1014     SSL *serverssl = NULL, *clientssl = NULL;
1015     BIO *s_to_c_bio = NULL, *c_to_s_bio = NULL;
1016 
1017     if (*sssl != NULL)
1018         serverssl = *sssl;
1019     else if (!TEST_ptr(serverssl = SSL_new(serverctx)))
1020         goto error;
1021     if (*cssl != NULL)
1022         clientssl = *cssl;
1023     else if (!TEST_ptr(clientssl = SSL_new(clientctx)))
1024         goto error;
1025 
1026     if (SSL_is_dtls(clientssl)) {
1027         if (!TEST_ptr(s_to_c_bio = BIO_new(bio_s_mempacket_test()))
1028                 || !TEST_ptr(c_to_s_bio = BIO_new(bio_s_mempacket_test())))
1029             goto error;
1030     } else {
1031         if (!TEST_ptr(s_to_c_bio = BIO_new(BIO_s_mem()))
1032                 || !TEST_ptr(c_to_s_bio = BIO_new(BIO_s_mem())))
1033             goto error;
1034     }
1035 
1036     if (s_to_c_fbio != NULL
1037             && !TEST_ptr(s_to_c_bio = BIO_push(s_to_c_fbio, s_to_c_bio)))
1038         goto error;
1039     if (c_to_s_fbio != NULL
1040             && !TEST_ptr(c_to_s_bio = BIO_push(c_to_s_fbio, c_to_s_bio)))
1041         goto error;
1042 
1043     /* Set Non-blocking IO behaviour */
1044     BIO_set_mem_eof_return(s_to_c_bio, -1);
1045     BIO_set_mem_eof_return(c_to_s_bio, -1);
1046 
1047     /* Up ref these as we are passing them to two SSL objects */
1048     SSL_set_bio(serverssl, c_to_s_bio, s_to_c_bio);
1049     BIO_up_ref(s_to_c_bio);
1050     BIO_up_ref(c_to_s_bio);
1051     SSL_set_bio(clientssl, s_to_c_bio, c_to_s_bio);
1052     *sssl = serverssl;
1053     *cssl = clientssl;
1054     return 1;
1055 
1056  error:
1057     SSL_free(serverssl);
1058     SSL_free(clientssl);
1059     BIO_free(s_to_c_bio);
1060     BIO_free(c_to_s_bio);
1061     BIO_free(s_to_c_fbio);
1062     BIO_free(c_to_s_fbio);
1063 
1064     return 0;
1065 }
1066 
1067 /*
1068  * Create an SSL connection, but does not read any post-handshake
1069  * NewSessionTicket messages.
1070  * If |read| is set and we're using DTLS then we will attempt to SSL_read on
1071  * the connection once we've completed one half of it, to ensure any retransmits
1072  * get triggered.
1073  * We stop the connection attempt (and return a failure value) if either peer
1074  * has SSL_get_error() return the value in the |want| parameter. The connection
1075  * attempt could be restarted by a subsequent call to this function.
1076  */
1077 int create_bare_ssl_connection(SSL *serverssl, SSL *clientssl, int want,
1078                                int read)
1079 {
1080     int retc = -1, rets = -1, err, abortctr = 0;
1081     int clienterr = 0, servererr = 0;
1082     int isdtls = SSL_is_dtls(serverssl);
1083 
1084     do {
1085         err = SSL_ERROR_WANT_WRITE;
1086         while (!clienterr && retc <= 0 && err == SSL_ERROR_WANT_WRITE) {
1087             retc = SSL_connect(clientssl);
1088             if (retc <= 0)
1089                 err = SSL_get_error(clientssl, retc);
1090         }
1091 
1092         if (!clienterr && retc <= 0 && err != SSL_ERROR_WANT_READ) {
1093             TEST_info("SSL_connect() failed %d, %d", retc, err);
1094             if (want != SSL_ERROR_SSL)
1095                 TEST_openssl_errors();
1096             clienterr = 1;
1097         }
1098         if (want != SSL_ERROR_NONE && err == want)
1099             return 0;
1100 
1101         err = SSL_ERROR_WANT_WRITE;
1102         while (!servererr && rets <= 0 && err == SSL_ERROR_WANT_WRITE) {
1103             rets = SSL_accept(serverssl);
1104             if (rets <= 0)
1105                 err = SSL_get_error(serverssl, rets);
1106         }
1107 
1108         if (!servererr && rets <= 0
1109                 && err != SSL_ERROR_WANT_READ
1110                 && err != SSL_ERROR_WANT_X509_LOOKUP) {
1111             TEST_info("SSL_accept() failed %d, %d", rets, err);
1112             if (want != SSL_ERROR_SSL)
1113                 TEST_openssl_errors();
1114             servererr = 1;
1115         }
1116         if (want != SSL_ERROR_NONE && err == want)
1117             return 0;
1118         if (clienterr && servererr)
1119             return 0;
1120         if (isdtls && read) {
1121             unsigned char buf[20];
1122 
1123             /* Trigger any retransmits that may be appropriate */
1124             if (rets > 0 && retc <= 0) {
1125                 if (SSL_read(serverssl, buf, sizeof(buf)) > 0) {
1126                     /* We don't expect this to succeed! */
1127                     TEST_info("Unexpected SSL_read() success!");
1128                     return 0;
1129                 }
1130             }
1131             if (retc > 0 && rets <= 0) {
1132                 if (SSL_read(clientssl, buf, sizeof(buf)) > 0) {
1133                     /* We don't expect this to succeed! */
1134                     TEST_info("Unexpected SSL_read() success!");
1135                     return 0;
1136                 }
1137             }
1138         }
1139         if (++abortctr == MAXLOOPS) {
1140             TEST_info("No progress made");
1141             return 0;
1142         }
1143         if (isdtls && abortctr <= 50 && (abortctr % 10) == 0) {
1144             /*
1145              * It looks like we're just spinning. Pause for a short period to
1146              * give the DTLS timer a chance to do something. We only do this for
1147              * the first few times to prevent hangs.
1148              */
1149             ossl_sleep(50);
1150         }
1151     } while (retc <=0 || rets <= 0);
1152 
1153     return 1;
1154 }
1155 
1156 /*
1157  * Create an SSL connection including any post handshake NewSessionTicket
1158  * messages.
1159  */
1160 int create_ssl_connection(SSL *serverssl, SSL *clientssl, int want)
1161 {
1162     int i;
1163     unsigned char buf;
1164     size_t readbytes;
1165 
1166     if (!create_bare_ssl_connection(serverssl, clientssl, want, 1))
1167         return 0;
1168 
1169     /*
1170      * We attempt to read some data on the client side which we expect to fail.
1171      * This will ensure we have received the NewSessionTicket in TLSv1.3 where
1172      * appropriate. We do this twice because there are 2 NewSessionTickets.
1173      */
1174     for (i = 0; i < 2; i++) {
1175         if (SSL_read_ex(clientssl, &buf, sizeof(buf), &readbytes) > 0) {
1176             if (!TEST_ulong_eq(readbytes, 0))
1177                 return 0;
1178         } else if (!TEST_int_eq(SSL_get_error(clientssl, 0),
1179                                 SSL_ERROR_WANT_READ)) {
1180             return 0;
1181         }
1182     }
1183 
1184     return 1;
1185 }
1186 
1187 void shutdown_ssl_connection(SSL *serverssl, SSL *clientssl)
1188 {
1189     SSL_shutdown(clientssl);
1190     SSL_shutdown(serverssl);
1191     SSL_free(serverssl);
1192     SSL_free(clientssl);
1193 }
1194 
1195 ENGINE *load_dasync(void)
1196 {
1197 #if !defined(OPENSSL_NO_TLS1_2) && !defined(OPENSSL_NO_DYNAMIC_ENGINE)
1198     ENGINE *e;
1199 
1200     if (!TEST_ptr(e = ENGINE_by_id("dasync")))
1201         return NULL;
1202 
1203     if (!TEST_true(ENGINE_init(e))) {
1204         ENGINE_free(e);
1205         return NULL;
1206     }
1207 
1208     if (!TEST_true(ENGINE_register_ciphers(e))) {
1209         ENGINE_free(e);
1210         return NULL;
1211     }
1212 
1213     return e;
1214 #else
1215     return NULL;
1216 #endif
1217 }
1218