xref: /linux/net/mctp/test/route-test.c (revision 46ee16462fed5c3a065b603d677a9a36462dab7d)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 #include <kunit/test.h>
4 
5 /* keep clangd happy when compiled outside of the route.c include */
6 #include <net/mctp.h>
7 #include <net/mctpdevice.h>
8 
9 #include "utils.h"
10 
/* Test-only route object wrapping a core struct mctp_route. It is allocated
 * by mctp_route_test_alloc(), linked into a net's route list by
 * mctp_test_create_route() and freed by mctp_test_route_destroy().
 */
struct mctp_test_route {
	struct mctp_route	rt;
};
14 
/* Sentinel stored in mctp_test_pktqueue::magic by mctp_test_pktqueue_init().
 * mctp_test_dst_output() asserts on it, which catches test->priv pointing at
 * something that is not an initialised packet queue.
 */
static const unsigned int test_pktqueue_magic = 0x5f713aef;

/* Collects the skbs handed to mctp_test_dst_output(), so a test can inspect
 * what the MCTP core transmitted.
 */
struct mctp_test_pktqueue {
	unsigned int magic;		/* test_pktqueue_magic once initialised */
	struct sk_buff_head pkts;	/* output skbs, appended at the tail */
};
21 
22 static void mctp_test_pktqueue_init(struct mctp_test_pktqueue *tpq)
23 {
24 	tpq->magic = test_pktqueue_magic;
25 	skb_queue_head_init(&tpq->pkts);
26 }
27 
28 static int mctp_test_dst_output(struct mctp_dst *dst, struct sk_buff *skb)
29 {
30 	struct kunit *test = current->kunit_test;
31 	struct mctp_test_pktqueue *tpq = test->priv;
32 
33 	KUNIT_ASSERT_EQ(test, tpq->magic, test_pktqueue_magic);
34 
35 	skb_queue_tail(&tpq->pkts, skb);
36 
37 	return 0;
38 }
39 
40 /* local version of mctp_route_alloc() */
41 static struct mctp_test_route *mctp_route_test_alloc(void)
42 {
43 	struct mctp_test_route *rt;
44 
45 	rt = kzalloc(sizeof(*rt), GFP_KERNEL);
46 	if (!rt)
47 		return NULL;
48 
49 	INIT_LIST_HEAD(&rt->rt.list);
50 	refcount_set(&rt->rt.refs, 1);
51 	rt->rt.output = mctp_test_dst_output;
52 
53 	return rt;
54 }
55 
56 static struct mctp_test_route *mctp_test_create_route(struct net *net,
57 						      struct mctp_dev *dev,
58 						      mctp_eid_t eid,
59 						      unsigned int mtu)
60 {
61 	struct mctp_test_route *rt;
62 
63 	rt = mctp_route_test_alloc();
64 	if (!rt)
65 		return NULL;
66 
67 	rt->rt.min = eid;
68 	rt->rt.max = eid;
69 	rt->rt.mtu = mtu;
70 	rt->rt.type = RTN_UNSPEC;
71 	if (dev)
72 		mctp_dev_hold(dev);
73 	rt->rt.dev = dev;
74 
75 	list_add_rcu(&rt->rt.list, &net->mctp.routes);
76 
77 	return rt;
78 }
79 
80 /* Convenience function for our test dst; release with mctp_test_dst_release()
81  */
82 static void mctp_test_dst_setup(struct kunit *test, struct mctp_dst *dst,
83 				struct mctp_test_dev *dev,
84 				struct mctp_test_pktqueue *tpq,
85 				unsigned int mtu)
86 {
87 	KUNIT_EXPECT_NOT_ERR_OR_NULL(test, dev);
88 
89 	memset(dst, 0, sizeof(*dst));
90 
91 	dst->dev = dev->mdev;
92 	__mctp_dev_get(dst->dev->dev);
93 	dst->mtu = mtu;
94 	dst->output = mctp_test_dst_output;
95 	mctp_test_pktqueue_init(tpq);
96 	test->priv = tpq;
97 }
98 
99 static void mctp_test_dst_release(struct mctp_dst *dst,
100 				  struct mctp_test_pktqueue *tpq)
101 {
102 	mctp_dst_release(dst);
103 	skb_queue_purge(&tpq->pkts);
104 }
105 
106 static void mctp_test_route_destroy(struct kunit *test,
107 				    struct mctp_test_route *rt)
108 {
109 	unsigned int refs;
110 
111 	rtnl_lock();
112 	list_del_rcu(&rt->rt.list);
113 	rtnl_unlock();
114 
115 	if (rt->rt.dev)
116 		mctp_dev_put(rt->rt.dev);
117 
118 	refs = refcount_read(&rt->rt.refs);
119 	KUNIT_ASSERT_EQ_MSG(test, refs, 1, "route ref imbalance");
120 
121 	kfree_rcu(&rt->rt, rcu);
122 }
123 
124 static void mctp_test_skb_set_dev(struct sk_buff *skb,
125 				  struct mctp_test_dev *dev)
126 {
127 	struct mctp_skb_cb *cb;
128 
129 	cb = mctp_cb(skb);
130 	cb->net = READ_ONCE(dev->mdev->net);
131 	skb->dev = dev->ndev;
132 }
133 
134 static struct sk_buff *mctp_test_create_skb(const struct mctp_hdr *hdr,
135 					    unsigned int data_len)
136 {
137 	size_t hdr_len = sizeof(*hdr);
138 	struct sk_buff *skb;
139 	unsigned int i;
140 	u8 *buf;
141 
142 	skb = alloc_skb(hdr_len + data_len, GFP_KERNEL);
143 	if (!skb)
144 		return NULL;
145 
146 	__mctp_cb(skb);
147 	memcpy(skb_put(skb, hdr_len), hdr, hdr_len);
148 
149 	buf = skb_put(skb, data_len);
150 	for (i = 0; i < data_len; i++)
151 		buf[i] = i & 0xff;
152 
153 	return skb;
154 }
155 
156 static struct sk_buff *__mctp_test_create_skb_data(const struct mctp_hdr *hdr,
157 						   const void *data,
158 						   size_t data_len)
159 {
160 	size_t hdr_len = sizeof(*hdr);
161 	struct sk_buff *skb;
162 
163 	skb = alloc_skb(hdr_len + data_len, GFP_KERNEL);
164 	if (!skb)
165 		return NULL;
166 
167 	__mctp_cb(skb);
168 	memcpy(skb_put(skb, hdr_len), hdr, hdr_len);
169 	memcpy(skb_put(skb, data_len), data, data_len);
170 
171 	return skb;
172 }
173 
174 #define mctp_test_create_skb_data(h, d) \
175 	__mctp_test_create_skb_data(h, d, sizeof(*d))
176 
177 struct mctp_frag_test {
178 	unsigned int mtu;
179 	unsigned int msgsize;
180 	unsigned int n_frags;
181 };
182 
183 static void mctp_test_fragment(struct kunit *test)
184 {
185 	const struct mctp_frag_test *params;
186 	struct mctp_test_pktqueue tpq;
187 	int rc, i, n, mtu, msgsize;
188 	struct mctp_test_dev *dev;
189 	struct mctp_dst dst;
190 	struct sk_buff *skb;
191 	struct mctp_hdr hdr;
192 	u8 seq;
193 
194 	params = test->param_value;
195 	mtu = params->mtu;
196 	msgsize = params->msgsize;
197 
198 	hdr.ver = 1;
199 	hdr.src = 8;
200 	hdr.dest = 10;
201 	hdr.flags_seq_tag = MCTP_HDR_FLAG_TO;
202 
203 	skb = mctp_test_create_skb(&hdr, msgsize);
204 	KUNIT_ASSERT_TRUE(test, skb);
205 
206 	dev = mctp_test_create_dev();
207 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, dev);
208 
209 	mctp_test_dst_setup(test, &dst, dev, &tpq, mtu);
210 
211 	rc = mctp_do_fragment_route(&dst, skb, mtu, MCTP_TAG_OWNER);
212 	KUNIT_EXPECT_FALSE(test, rc);
213 
214 	n = tpq.pkts.qlen;
215 
216 	KUNIT_EXPECT_EQ(test, n, params->n_frags);
217 
218 	for (i = 0;; i++) {
219 		struct mctp_hdr *hdr2;
220 		struct sk_buff *skb2;
221 		u8 tag_mask, seq2;
222 		bool first, last;
223 
224 		first = i == 0;
225 		last = i == (n - 1);
226 
227 		skb2 = skb_dequeue(&tpq.pkts);
228 
229 		if (!skb2)
230 			break;
231 
232 		hdr2 = mctp_hdr(skb2);
233 
234 		tag_mask = MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO;
235 
236 		KUNIT_EXPECT_EQ(test, hdr2->ver, hdr.ver);
237 		KUNIT_EXPECT_EQ(test, hdr2->src, hdr.src);
238 		KUNIT_EXPECT_EQ(test, hdr2->dest, hdr.dest);
239 		KUNIT_EXPECT_EQ(test, hdr2->flags_seq_tag & tag_mask,
240 				hdr.flags_seq_tag & tag_mask);
241 
242 		KUNIT_EXPECT_EQ(test,
243 				!!(hdr2->flags_seq_tag & MCTP_HDR_FLAG_SOM), first);
244 		KUNIT_EXPECT_EQ(test,
245 				!!(hdr2->flags_seq_tag & MCTP_HDR_FLAG_EOM), last);
246 
247 		seq2 = (hdr2->flags_seq_tag >> MCTP_HDR_SEQ_SHIFT) &
248 			MCTP_HDR_SEQ_MASK;
249 
250 		if (first) {
251 			seq = seq2;
252 		} else {
253 			seq++;
254 			KUNIT_EXPECT_EQ(test, seq2, seq & MCTP_HDR_SEQ_MASK);
255 		}
256 
257 		if (!last)
258 			KUNIT_EXPECT_EQ(test, skb2->len, mtu);
259 		else
260 			KUNIT_EXPECT_LE(test, skb2->len, mtu);
261 
262 		kfree_skb(skb2);
263 	}
264 
265 	mctp_test_dst_release(&dst, &tpq);
266 	mctp_test_destroy_dev(dev);
267 }
268 
269 static const struct mctp_frag_test mctp_frag_tests[] = {
270 	{.mtu = 68, .msgsize = 63, .n_frags = 1},
271 	{.mtu = 68, .msgsize = 64, .n_frags = 1},
272 	{.mtu = 68, .msgsize = 65, .n_frags = 2},
273 	{.mtu = 68, .msgsize = 66, .n_frags = 2},
274 	{.mtu = 68, .msgsize = 127, .n_frags = 2},
275 	{.mtu = 68, .msgsize = 128, .n_frags = 2},
276 	{.mtu = 68, .msgsize = 129, .n_frags = 3},
277 	{.mtu = 68, .msgsize = 130, .n_frags = 3},
278 };
279 
280 static void mctp_frag_test_to_desc(const struct mctp_frag_test *t, char *desc)
281 {
282 	sprintf(desc, "mtu %d len %d -> %d frags",
283 		t->msgsize, t->mtu, t->n_frags);
284 }
285 
286 KUNIT_ARRAY_PARAM(mctp_frag, mctp_frag_tests, mctp_frag_test_to_desc);
287 
288 struct mctp_rx_input_test {
289 	struct mctp_hdr hdr;
290 	bool input;
291 };
292 
293 static void mctp_test_rx_input(struct kunit *test)
294 {
295 	const struct mctp_rx_input_test *params;
296 	struct mctp_test_pktqueue tpq;
297 	struct mctp_test_route *rt;
298 	struct mctp_test_dev *dev;
299 	struct sk_buff *skb;
300 
301 	params = test->param_value;
302 	test->priv = &tpq;
303 
304 	dev = mctp_test_create_dev();
305 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, dev);
306 
307 	rt = mctp_test_create_route(&init_net, dev->mdev, 8, 68);
308 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, rt);
309 
310 	skb = mctp_test_create_skb(&params->hdr, 1);
311 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb);
312 
313 	mctp_test_pktqueue_init(&tpq);
314 
315 	mctp_pkttype_receive(skb, dev->ndev, &mctp_packet_type, NULL);
316 
317 	KUNIT_EXPECT_EQ(test, !!tpq.pkts.qlen, params->input);
318 
319 	skb_queue_purge(&tpq.pkts);
320 	mctp_test_route_destroy(test, rt);
321 	mctp_test_destroy_dev(dev);
322 }
323 
324 #define RX_HDR(_ver, _src, _dest, _fst) \
325 	{ .ver = _ver, .src = _src, .dest = _dest, .flags_seq_tag = _fst }
326 
327 /* we have a route for EID 8 only */
328 static const struct mctp_rx_input_test mctp_rx_input_tests[] = {
329 	{ .hdr = RX_HDR(1, 10, 8, 0), .input = true },
330 	{ .hdr = RX_HDR(1, 10, 9, 0), .input = false }, /* no input route */
331 	{ .hdr = RX_HDR(2, 10, 8, 0), .input = false }, /* invalid version */
332 };
333 
334 static void mctp_rx_input_test_to_desc(const struct mctp_rx_input_test *t,
335 				       char *desc)
336 {
337 	sprintf(desc, "{%x,%x,%x,%x}", t->hdr.ver, t->hdr.src, t->hdr.dest,
338 		t->hdr.flags_seq_tag);
339 }
340 
341 KUNIT_ARRAY_PARAM(mctp_rx_input, mctp_rx_input_tests,
342 		  mctp_rx_input_test_to_desc);
343 
344 /* set up a local dev, route on EID 8, and a socket listening on type 0 */
345 static void __mctp_route_test_init(struct kunit *test,
346 				   struct mctp_test_dev **devp,
347 				   struct mctp_dst *dst,
348 				   struct mctp_test_pktqueue *tpq,
349 				   struct socket **sockp,
350 				   unsigned int netid)
351 {
352 	struct sockaddr_mctp addr = {0};
353 	struct mctp_test_dev *dev;
354 	struct socket *sock;
355 	int rc;
356 
357 	dev = mctp_test_create_dev();
358 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, dev);
359 	if (netid != MCTP_NET_ANY)
360 		WRITE_ONCE(dev->mdev->net, netid);
361 
362 	mctp_test_dst_setup(test, dst, dev, tpq, 68);
363 
364 	rc = sock_create_kern(&init_net, AF_MCTP, SOCK_DGRAM, 0, &sock);
365 	KUNIT_ASSERT_EQ(test, rc, 0);
366 
367 	addr.smctp_family = AF_MCTP;
368 	addr.smctp_network = netid;
369 	addr.smctp_addr.s_addr = 8;
370 	addr.smctp_type = 0;
371 	rc = kernel_bind(sock, (struct sockaddr *)&addr, sizeof(addr));
372 	KUNIT_ASSERT_EQ(test, rc, 0);
373 
374 	*devp = dev;
375 	*sockp = sock;
376 }
377 
378 static void __mctp_route_test_fini(struct kunit *test,
379 				   struct mctp_test_dev *dev,
380 				   struct mctp_dst *dst,
381 				   struct mctp_test_pktqueue *tpq,
382 				   struct socket *sock)
383 {
384 	sock_release(sock);
385 	mctp_test_dst_release(dst, tpq);
386 	mctp_test_destroy_dev(dev);
387 }
388 
389 struct mctp_route_input_sk_test {
390 	struct mctp_hdr hdr;
391 	u8 type;
392 	bool deliver;
393 };
394 
395 static void mctp_test_route_input_sk(struct kunit *test)
396 {
397 	const struct mctp_route_input_sk_test *params;
398 	struct mctp_test_pktqueue tpq;
399 	struct sk_buff *skb, *skb2;
400 	struct mctp_test_dev *dev;
401 	struct mctp_dst dst;
402 	struct socket *sock;
403 	int rc;
404 
405 	params = test->param_value;
406 
407 	__mctp_route_test_init(test, &dev, &dst, &tpq, &sock, MCTP_NET_ANY);
408 
409 	skb = mctp_test_create_skb_data(&params->hdr, &params->type);
410 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb);
411 
412 	mctp_test_skb_set_dev(skb, dev);
413 	mctp_test_pktqueue_init(&tpq);
414 
415 	rc = mctp_dst_input(&dst, skb);
416 
417 	if (params->deliver) {
418 		KUNIT_EXPECT_EQ(test, rc, 0);
419 
420 		skb2 = skb_recv_datagram(sock->sk, MSG_DONTWAIT, &rc);
421 		KUNIT_EXPECT_NOT_ERR_OR_NULL(test, skb2);
422 		KUNIT_EXPECT_EQ(test, skb2->len, 1);
423 
424 		skb_free_datagram(sock->sk, skb2);
425 
426 	} else {
427 		KUNIT_EXPECT_NE(test, rc, 0);
428 		skb2 = skb_recv_datagram(sock->sk, MSG_DONTWAIT, &rc);
429 		KUNIT_EXPECT_NULL(test, skb2);
430 	}
431 
432 	__mctp_route_test_fini(test, dev, &dst, &tpq, sock);
433 }
434 
435 #define FL_S	(MCTP_HDR_FLAG_SOM)
436 #define FL_E	(MCTP_HDR_FLAG_EOM)
437 #define FL_TO	(MCTP_HDR_FLAG_TO)
438 #define FL_T(t)	((t) & MCTP_HDR_TAG_MASK)
439 
440 static const struct mctp_route_input_sk_test mctp_route_input_sk_tests[] = {
441 	{ .hdr = RX_HDR(1, 10, 8, FL_S | FL_E | FL_TO), .type = 0, .deliver = true },
442 	{ .hdr = RX_HDR(1, 10, 8, FL_S | FL_E | FL_TO), .type = 1, .deliver = false },
443 	{ .hdr = RX_HDR(1, 10, 8, FL_S | FL_E), .type = 0, .deliver = false },
444 	{ .hdr = RX_HDR(1, 10, 8, FL_E | FL_TO), .type = 0, .deliver = false },
445 	{ .hdr = RX_HDR(1, 10, 8, FL_TO), .type = 0, .deliver = false },
446 	{ .hdr = RX_HDR(1, 10, 8, 0), .type = 0, .deliver = false },
447 };
448 
449 static void mctp_route_input_sk_to_desc(const struct mctp_route_input_sk_test *t,
450 					char *desc)
451 {
452 	sprintf(desc, "{%x,%x,%x,%x} type %d", t->hdr.ver, t->hdr.src,
453 		t->hdr.dest, t->hdr.flags_seq_tag, t->type);
454 }
455 
456 KUNIT_ARRAY_PARAM(mctp_route_input_sk, mctp_route_input_sk_tests,
457 		  mctp_route_input_sk_to_desc);
458 
459 struct mctp_route_input_sk_reasm_test {
460 	const char *name;
461 	struct mctp_hdr hdrs[4];
462 	int n_hdrs;
463 	int rx_len;
464 };
465 
466 static void mctp_test_route_input_sk_reasm(struct kunit *test)
467 {
468 	const struct mctp_route_input_sk_reasm_test *params;
469 	struct mctp_test_pktqueue tpq;
470 	struct sk_buff *skb, *skb2;
471 	struct mctp_test_dev *dev;
472 	struct mctp_dst dst;
473 	struct socket *sock;
474 	int i, rc;
475 	u8 c;
476 
477 	params = test->param_value;
478 
479 	__mctp_route_test_init(test, &dev, &dst, &tpq, &sock, MCTP_NET_ANY);
480 
481 	for (i = 0; i < params->n_hdrs; i++) {
482 		c = i;
483 		skb = mctp_test_create_skb_data(&params->hdrs[i], &c);
484 		KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb);
485 
486 		mctp_test_skb_set_dev(skb, dev);
487 
488 		rc = mctp_dst_input(&dst, skb);
489 	}
490 
491 	skb2 = skb_recv_datagram(sock->sk, MSG_DONTWAIT, &rc);
492 
493 	if (params->rx_len) {
494 		KUNIT_EXPECT_NOT_ERR_OR_NULL(test, skb2);
495 		KUNIT_EXPECT_EQ(test, skb2->len, params->rx_len);
496 		skb_free_datagram(sock->sk, skb2);
497 
498 	} else {
499 		KUNIT_EXPECT_NULL(test, skb2);
500 	}
501 
502 	__mctp_route_test_fini(test, dev, &dst, &tpq, sock);
503 }
504 
505 #define RX_FRAG(f, s) RX_HDR(1, 10, 8, FL_TO | (f) | ((s) << MCTP_HDR_SEQ_SHIFT))
506 
507 static const struct mctp_route_input_sk_reasm_test mctp_route_input_sk_reasm_tests[] = {
508 	{
509 		.name = "single packet",
510 		.hdrs = {
511 			RX_FRAG(FL_S | FL_E, 0),
512 		},
513 		.n_hdrs = 1,
514 		.rx_len = 1,
515 	},
516 	{
517 		.name = "single packet, offset seq",
518 		.hdrs = {
519 			RX_FRAG(FL_S | FL_E, 1),
520 		},
521 		.n_hdrs = 1,
522 		.rx_len = 1,
523 	},
524 	{
525 		.name = "start & end packets",
526 		.hdrs = {
527 			RX_FRAG(FL_S, 0),
528 			RX_FRAG(FL_E, 1),
529 		},
530 		.n_hdrs = 2,
531 		.rx_len = 2,
532 	},
533 	{
534 		.name = "start & end packets, offset seq",
535 		.hdrs = {
536 			RX_FRAG(FL_S, 1),
537 			RX_FRAG(FL_E, 2),
538 		},
539 		.n_hdrs = 2,
540 		.rx_len = 2,
541 	},
542 	{
543 		.name = "start & end packets, out of order",
544 		.hdrs = {
545 			RX_FRAG(FL_E, 1),
546 			RX_FRAG(FL_S, 0),
547 		},
548 		.n_hdrs = 2,
549 		.rx_len = 0,
550 	},
551 	{
552 		.name = "start, middle & end packets",
553 		.hdrs = {
554 			RX_FRAG(FL_S, 0),
555 			RX_FRAG(0,    1),
556 			RX_FRAG(FL_E, 2),
557 		},
558 		.n_hdrs = 3,
559 		.rx_len = 3,
560 	},
561 	{
562 		.name = "missing seq",
563 		.hdrs = {
564 			RX_FRAG(FL_S, 0),
565 			RX_FRAG(FL_E, 2),
566 		},
567 		.n_hdrs = 2,
568 		.rx_len = 0,
569 	},
570 	{
571 		.name = "seq wrap",
572 		.hdrs = {
573 			RX_FRAG(FL_S, 3),
574 			RX_FRAG(FL_E, 0),
575 		},
576 		.n_hdrs = 2,
577 		.rx_len = 2,
578 	},
579 };
580 
581 static void mctp_route_input_sk_reasm_to_desc(
582 				const struct mctp_route_input_sk_reasm_test *t,
583 				char *desc)
584 {
585 	sprintf(desc, "%s", t->name);
586 }
587 
588 KUNIT_ARRAY_PARAM(mctp_route_input_sk_reasm, mctp_route_input_sk_reasm_tests,
589 		  mctp_route_input_sk_reasm_to_desc);
590 
591 struct mctp_route_input_sk_keys_test {
592 	const char	*name;
593 	mctp_eid_t	key_peer_addr;
594 	mctp_eid_t	key_local_addr;
595 	u8		key_tag;
596 	struct mctp_hdr hdr;
597 	bool		deliver;
598 };
599 
600 /* test packet rx in the presence of various key configurations */
601 static void mctp_test_route_input_sk_keys(struct kunit *test)
602 {
603 	const struct mctp_route_input_sk_keys_test *params;
604 	struct mctp_test_pktqueue tpq;
605 	struct sk_buff *skb, *skb2;
606 	struct mctp_test_dev *dev;
607 	struct mctp_sk_key *key;
608 	struct netns_mctp *mns;
609 	struct mctp_sock *msk;
610 	struct socket *sock;
611 	unsigned long flags;
612 	struct mctp_dst dst;
613 	unsigned int net;
614 	int rc;
615 	u8 c;
616 
617 	params = test->param_value;
618 
619 	dev = mctp_test_create_dev();
620 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, dev);
621 	net = READ_ONCE(dev->mdev->net);
622 
623 	mctp_test_dst_setup(test, &dst, dev, &tpq, 68);
624 
625 	rc = sock_create_kern(&init_net, AF_MCTP, SOCK_DGRAM, 0, &sock);
626 	KUNIT_ASSERT_EQ(test, rc, 0);
627 
628 	msk = container_of(sock->sk, struct mctp_sock, sk);
629 	mns = &sock_net(sock->sk)->mctp;
630 
631 	/* set the incoming tag according to test params */
632 	key = mctp_key_alloc(msk, net, params->key_local_addr,
633 			     params->key_peer_addr, params->key_tag,
634 			     GFP_KERNEL);
635 
636 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, key);
637 
638 	spin_lock_irqsave(&mns->keys_lock, flags);
639 	mctp_reserve_tag(&init_net, key, msk);
640 	spin_unlock_irqrestore(&mns->keys_lock, flags);
641 
642 	/* create packet and route */
643 	c = 0;
644 	skb = mctp_test_create_skb_data(&params->hdr, &c);
645 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb);
646 
647 	mctp_test_skb_set_dev(skb, dev);
648 
649 	rc = mctp_dst_input(&dst, skb);
650 
651 	/* (potentially) receive message */
652 	skb2 = skb_recv_datagram(sock->sk, MSG_DONTWAIT, &rc);
653 
654 	if (params->deliver)
655 		KUNIT_EXPECT_NOT_ERR_OR_NULL(test, skb2);
656 	else
657 		KUNIT_EXPECT_PTR_EQ(test, skb2, NULL);
658 
659 	if (skb2)
660 		skb_free_datagram(sock->sk, skb2);
661 
662 	mctp_key_unref(key);
663 	__mctp_route_test_fini(test, dev, &dst, &tpq, sock);
664 }
665 
666 static const struct mctp_route_input_sk_keys_test mctp_route_input_sk_keys_tests[] = {
667 	{
668 		.name = "direct match",
669 		.key_peer_addr = 9,
670 		.key_local_addr = 8,
671 		.key_tag = 1,
672 		.hdr = RX_HDR(1, 9, 8, FL_S | FL_E | FL_T(1)),
673 		.deliver = true,
674 	},
675 	{
676 		.name = "flipped src/dest",
677 		.key_peer_addr = 8,
678 		.key_local_addr = 9,
679 		.key_tag = 1,
680 		.hdr = RX_HDR(1, 9, 8, FL_S | FL_E | FL_T(1)),
681 		.deliver = false,
682 	},
683 	{
684 		.name = "peer addr mismatch",
685 		.key_peer_addr = 9,
686 		.key_local_addr = 8,
687 		.key_tag = 1,
688 		.hdr = RX_HDR(1, 10, 8, FL_S | FL_E | FL_T(1)),
689 		.deliver = false,
690 	},
691 	{
692 		.name = "tag value mismatch",
693 		.key_peer_addr = 9,
694 		.key_local_addr = 8,
695 		.key_tag = 1,
696 		.hdr = RX_HDR(1, 9, 8, FL_S | FL_E | FL_T(2)),
697 		.deliver = false,
698 	},
699 	{
700 		.name = "TO mismatch",
701 		.key_peer_addr = 9,
702 		.key_local_addr = 8,
703 		.key_tag = 1,
704 		.hdr = RX_HDR(1, 9, 8, FL_S | FL_E | FL_T(1) | FL_TO),
705 		.deliver = false,
706 	},
707 	{
708 		.name = "broadcast response",
709 		.key_peer_addr = MCTP_ADDR_ANY,
710 		.key_local_addr = 8,
711 		.key_tag = 1,
712 		.hdr = RX_HDR(1, 11, 8, FL_S | FL_E | FL_T(1)),
713 		.deliver = true,
714 	},
715 	{
716 		.name = "any local match",
717 		.key_peer_addr = 12,
718 		.key_local_addr = MCTP_ADDR_ANY,
719 		.key_tag = 1,
720 		.hdr = RX_HDR(1, 12, 8, FL_S | FL_E | FL_T(1)),
721 		.deliver = true,
722 	},
723 };
724 
725 static void mctp_route_input_sk_keys_to_desc(
726 				const struct mctp_route_input_sk_keys_test *t,
727 				char *desc)
728 {
729 	sprintf(desc, "%s", t->name);
730 }
731 
732 KUNIT_ARRAY_PARAM(mctp_route_input_sk_keys, mctp_route_input_sk_keys_tests,
733 		  mctp_route_input_sk_keys_to_desc);
734 
735 struct test_net {
736 	unsigned int netid;
737 	struct mctp_test_dev *dev;
738 	struct mctp_test_pktqueue tpq;
739 	struct mctp_dst dst;
740 	struct socket *sock;
741 	struct sk_buff *skb;
742 	struct mctp_sk_key *key;
743 	struct {
744 		u8 type;
745 		unsigned int data;
746 	} msg;
747 };
748 
749 static void
750 mctp_test_route_input_multiple_nets_bind_init(struct kunit *test,
751 					      struct test_net *t)
752 {
753 	struct mctp_hdr hdr = RX_HDR(1, 9, 8, FL_S | FL_E | FL_T(1) | FL_TO);
754 
755 	t->msg.data = t->netid;
756 
757 	__mctp_route_test_init(test, &t->dev, &t->dst, &t->tpq, &t->sock,
758 			       t->netid);
759 
760 	t->skb = mctp_test_create_skb_data(&hdr, &t->msg);
761 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, t->skb);
762 	mctp_test_skb_set_dev(t->skb, t->dev);
763 	mctp_test_pktqueue_init(&t->tpq);
764 }
765 
766 static void
767 mctp_test_route_input_multiple_nets_bind_fini(struct kunit *test,
768 					      struct test_net *t)
769 {
770 	__mctp_route_test_fini(test, t->dev, &t->dst, &t->tpq, t->sock);
771 }
772 
773 /* Test that skbs from different nets (otherwise identical) get routed to their
774  * corresponding socket via the sockets' bind()
775  */
776 static void mctp_test_route_input_multiple_nets_bind(struct kunit *test)
777 {
778 	struct sk_buff *rx_skb1, *rx_skb2;
779 	struct test_net t1, t2;
780 	int rc;
781 
782 	t1.netid = 1;
783 	t2.netid = 2;
784 
785 	t1.msg.type = 0;
786 	t2.msg.type = 0;
787 
788 	mctp_test_route_input_multiple_nets_bind_init(test, &t1);
789 	mctp_test_route_input_multiple_nets_bind_init(test, &t2);
790 
791 	rc = mctp_dst_input(&t1.dst, t1.skb);
792 	KUNIT_ASSERT_EQ(test, rc, 0);
793 	rc = mctp_dst_input(&t2.dst, t2.skb);
794 	KUNIT_ASSERT_EQ(test, rc, 0);
795 
796 	rx_skb1 = skb_recv_datagram(t1.sock->sk, MSG_DONTWAIT, &rc);
797 	KUNIT_EXPECT_NOT_ERR_OR_NULL(test, rx_skb1);
798 	KUNIT_EXPECT_EQ(test, rx_skb1->len, sizeof(t1.msg));
799 	KUNIT_EXPECT_EQ(test,
800 			*(unsigned int *)skb_pull(rx_skb1, sizeof(t1.msg.data)),
801 			t1.netid);
802 	kfree_skb(rx_skb1);
803 
804 	rx_skb2 = skb_recv_datagram(t2.sock->sk, MSG_DONTWAIT, &rc);
805 	KUNIT_EXPECT_NOT_ERR_OR_NULL(test, rx_skb2);
806 	KUNIT_EXPECT_EQ(test, rx_skb2->len, sizeof(t2.msg));
807 	KUNIT_EXPECT_EQ(test,
808 			*(unsigned int *)skb_pull(rx_skb2, sizeof(t2.msg.data)),
809 			t2.netid);
810 	kfree_skb(rx_skb2);
811 
812 	mctp_test_route_input_multiple_nets_bind_fini(test, &t1);
813 	mctp_test_route_input_multiple_nets_bind_fini(test, &t2);
814 }
815 
816 static void
817 mctp_test_route_input_multiple_nets_key_init(struct kunit *test,
818 					     struct test_net *t)
819 {
820 	struct mctp_hdr hdr = RX_HDR(1, 9, 8, FL_S | FL_E | FL_T(1));
821 	struct mctp_sock *msk;
822 	struct netns_mctp *mns;
823 	unsigned long flags;
824 
825 	t->msg.data = t->netid;
826 
827 	__mctp_route_test_init(test, &t->dev, &t->dst, &t->tpq, &t->sock,
828 			       t->netid);
829 
830 	msk = container_of(t->sock->sk, struct mctp_sock, sk);
831 
832 	t->key = mctp_key_alloc(msk, t->netid, hdr.dest, hdr.src, 1, GFP_KERNEL);
833 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, t->key);
834 
835 	mns = &sock_net(t->sock->sk)->mctp;
836 	spin_lock_irqsave(&mns->keys_lock, flags);
837 	mctp_reserve_tag(&init_net, t->key, msk);
838 	spin_unlock_irqrestore(&mns->keys_lock, flags);
839 
840 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, t->key);
841 	t->skb = mctp_test_create_skb_data(&hdr, &t->msg);
842 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, t->skb);
843 	mctp_test_skb_set_dev(t->skb, t->dev);
844 }
845 
846 static void
847 mctp_test_route_input_multiple_nets_key_fini(struct kunit *test,
848 					     struct test_net *t)
849 {
850 	mctp_key_unref(t->key);
851 	__mctp_route_test_fini(test, t->dev, &t->dst, &t->tpq, t->sock);
852 }
853 
854 /* test that skbs from different nets (otherwise identical) get routed to their
855  * corresponding socket via the sk_key
856  */
857 static void mctp_test_route_input_multiple_nets_key(struct kunit *test)
858 {
859 	struct sk_buff *rx_skb1, *rx_skb2;
860 	struct test_net t1, t2;
861 	int rc;
862 
863 	t1.netid = 1;
864 	t2.netid = 2;
865 
866 	/* use type 1 which is not bound */
867 	t1.msg.type = 1;
868 	t2.msg.type = 1;
869 
870 	mctp_test_route_input_multiple_nets_key_init(test, &t1);
871 	mctp_test_route_input_multiple_nets_key_init(test, &t2);
872 
873 	rc = mctp_dst_input(&t1.dst, t1.skb);
874 	KUNIT_ASSERT_EQ(test, rc, 0);
875 	rc = mctp_dst_input(&t2.dst, t2.skb);
876 	KUNIT_ASSERT_EQ(test, rc, 0);
877 
878 	rx_skb1 = skb_recv_datagram(t1.sock->sk, MSG_DONTWAIT, &rc);
879 	KUNIT_EXPECT_NOT_ERR_OR_NULL(test, rx_skb1);
880 	KUNIT_EXPECT_EQ(test, rx_skb1->len, sizeof(t1.msg));
881 	KUNIT_EXPECT_EQ(test,
882 			*(unsigned int *)skb_pull(rx_skb1, sizeof(t1.msg.data)),
883 			t1.netid);
884 	kfree_skb(rx_skb1);
885 
886 	rx_skb2 = skb_recv_datagram(t2.sock->sk, MSG_DONTWAIT, &rc);
887 	KUNIT_EXPECT_NOT_ERR_OR_NULL(test, rx_skb2);
888 	KUNIT_EXPECT_EQ(test, rx_skb2->len, sizeof(t2.msg));
889 	KUNIT_EXPECT_EQ(test,
890 			*(unsigned int *)skb_pull(rx_skb2, sizeof(t2.msg.data)),
891 			t2.netid);
892 	kfree_skb(rx_skb2);
893 
894 	mctp_test_route_input_multiple_nets_key_fini(test, &t1);
895 	mctp_test_route_input_multiple_nets_key_fini(test, &t2);
896 }
897 
898 /* Input route to socket, using a single-packet message, where sock delivery
899  * fails. Ensure we're handling the failure appropriately.
900  */
901 static void mctp_test_route_input_sk_fail_single(struct kunit *test)
902 {
903 	const struct mctp_hdr hdr = RX_HDR(1, 10, 8, FL_S | FL_E | FL_TO);
904 	struct mctp_test_pktqueue tpq;
905 	struct mctp_test_dev *dev;
906 	struct mctp_dst dst;
907 	struct socket *sock;
908 	struct sk_buff *skb;
909 	int rc;
910 
911 	__mctp_route_test_init(test, &dev, &dst, &tpq, &sock, MCTP_NET_ANY);
912 
913 	/* No rcvbuf space, so delivery should fail. __sock_set_rcvbuf will
914 	 * clamp the minimum to SOCK_MIN_RCVBUF, so we open-code this.
915 	 */
916 	lock_sock(sock->sk);
917 	WRITE_ONCE(sock->sk->sk_rcvbuf, 0);
918 	release_sock(sock->sk);
919 
920 	skb = mctp_test_create_skb(&hdr, 10);
921 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb);
922 	skb_get(skb);
923 
924 	mctp_test_skb_set_dev(skb, dev);
925 
926 	/* do route input, which should fail */
927 	rc = mctp_dst_input(&dst, skb);
928 	KUNIT_EXPECT_NE(test, rc, 0);
929 
930 	/* we should hold the only reference to skb */
931 	KUNIT_EXPECT_EQ(test, refcount_read(&skb->users), 1);
932 	kfree_skb(skb);
933 
934 	__mctp_route_test_fini(test, dev, &dst, &tpq, sock);
935 }
936 
937 /* Input route to socket, using a fragmented message, where sock delivery fails.
938  */
939 static void mctp_test_route_input_sk_fail_frag(struct kunit *test)
940 {
941 	const struct mctp_hdr hdrs[2] = { RX_FRAG(FL_S, 0), RX_FRAG(FL_E, 1) };
942 	struct mctp_test_pktqueue tpq;
943 	struct mctp_test_dev *dev;
944 	struct sk_buff *skbs[2];
945 	struct mctp_dst dst;
946 	struct socket *sock;
947 	unsigned int i;
948 	int rc;
949 
950 	__mctp_route_test_init(test, &dev, &dst, &tpq, &sock, MCTP_NET_ANY);
951 
952 	lock_sock(sock->sk);
953 	WRITE_ONCE(sock->sk->sk_rcvbuf, 0);
954 	release_sock(sock->sk);
955 
956 	for (i = 0; i < ARRAY_SIZE(skbs); i++) {
957 		skbs[i] = mctp_test_create_skb(&hdrs[i], 10);
958 		KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skbs[i]);
959 		skb_get(skbs[i]);
960 
961 		mctp_test_skb_set_dev(skbs[i], dev);
962 	}
963 
964 	/* first route input should succeed, we're only queueing to the
965 	 * frag list
966 	 */
967 	rc = mctp_dst_input(&dst, skbs[0]);
968 	KUNIT_EXPECT_EQ(test, rc, 0);
969 
970 	/* final route input should fail to deliver to the socket */
971 	rc = mctp_dst_input(&dst, skbs[1]);
972 	KUNIT_EXPECT_NE(test, rc, 0);
973 
974 	/* we should hold the only reference to both skbs */
975 	KUNIT_EXPECT_EQ(test, refcount_read(&skbs[0]->users), 1);
976 	kfree_skb(skbs[0]);
977 
978 	KUNIT_EXPECT_EQ(test, refcount_read(&skbs[1]->users), 1);
979 	kfree_skb(skbs[1]);
980 
981 	__mctp_route_test_fini(test, dev, &dst, &tpq, sock);
982 }
983 
984 /* Input route to socket, using a fragmented message created from clones.
985  */
986 static void mctp_test_route_input_cloned_frag(struct kunit *test)
987 {
988 	/* 5 packet fragments, forming 2 complete messages */
989 	const struct mctp_hdr hdrs[5] = {
990 		RX_FRAG(FL_S, 0),
991 		RX_FRAG(0, 1),
992 		RX_FRAG(FL_E, 2),
993 		RX_FRAG(FL_S, 0),
994 		RX_FRAG(FL_E, 1),
995 	};
996 	const size_t data_len = 3; /* arbitrary */
997 	u8 compare[3 * ARRAY_SIZE(hdrs)];
998 	u8 flat[3 * ARRAY_SIZE(hdrs)];
999 	struct mctp_test_pktqueue tpq;
1000 	struct mctp_test_dev *dev;
1001 	struct sk_buff *skb[5];
1002 	struct sk_buff *rx_skb;
1003 	struct mctp_dst dst;
1004 	struct socket *sock;
1005 	size_t total;
1006 	void *p;
1007 	int rc;
1008 
1009 	total = data_len + sizeof(struct mctp_hdr);
1010 
1011 	__mctp_route_test_init(test, &dev, &dst, &tpq, &sock, MCTP_NET_ANY);
1012 
1013 	/* Create a single skb initially with concatenated packets */
1014 	skb[0] = mctp_test_create_skb(&hdrs[0], 5 * total);
1015 	mctp_test_skb_set_dev(skb[0], dev);
1016 	memset(skb[0]->data, 0 * 0x11, skb[0]->len);
1017 	memcpy(skb[0]->data, &hdrs[0], sizeof(struct mctp_hdr));
1018 
1019 	/* Extract and populate packets */
1020 	for (int i = 1; i < 5; i++) {
1021 		skb[i] = skb_clone(skb[i - 1], GFP_ATOMIC);
1022 		KUNIT_ASSERT_TRUE(test, skb[i]);
1023 		p = skb_pull(skb[i], total);
1024 		KUNIT_ASSERT_TRUE(test, p);
1025 		skb_reset_network_header(skb[i]);
1026 		memcpy(skb[i]->data, &hdrs[i], sizeof(struct mctp_hdr));
1027 		memset(&skb[i]->data[sizeof(struct mctp_hdr)], i * 0x11, data_len);
1028 	}
1029 	for (int i = 0; i < 5; i++)
1030 		skb_trim(skb[i], total);
1031 
1032 	/* SOM packets have a type byte to match the socket */
1033 	skb[0]->data[4] = 0;
1034 	skb[3]->data[4] = 0;
1035 
1036 	skb_dump("pkt1 ", skb[0], false);
1037 	skb_dump("pkt2 ", skb[1], false);
1038 	skb_dump("pkt3 ", skb[2], false);
1039 	skb_dump("pkt4 ", skb[3], false);
1040 	skb_dump("pkt5 ", skb[4], false);
1041 
1042 	for (int i = 0; i < 5; i++) {
1043 		KUNIT_EXPECT_EQ(test, refcount_read(&skb[i]->users), 1);
1044 		/* Take a reference so we can check refcounts at the end */
1045 		skb_get(skb[i]);
1046 	}
1047 
1048 	/* Feed the fragments into MCTP core */
1049 	for (int i = 0; i < 5; i++) {
1050 		rc = mctp_dst_input(&dst, skb[i]);
1051 		KUNIT_EXPECT_EQ(test, rc, 0);
1052 	}
1053 
1054 	/* Receive first reassembled message */
1055 	rx_skb = skb_recv_datagram(sock->sk, MSG_DONTWAIT, &rc);
1056 	KUNIT_EXPECT_EQ(test, rc, 0);
1057 	KUNIT_EXPECT_EQ(test, rx_skb->len, 3 * data_len);
1058 	rc = skb_copy_bits(rx_skb, 0, flat, rx_skb->len);
1059 	for (int i = 0; i < rx_skb->len; i++)
1060 		compare[i] = (i / data_len) * 0x11;
1061 	/* Set type byte */
1062 	compare[0] = 0;
1063 
1064 	KUNIT_EXPECT_MEMEQ(test, flat, compare, rx_skb->len);
1065 	KUNIT_EXPECT_EQ(test, refcount_read(&rx_skb->users), 1);
1066 	kfree_skb(rx_skb);
1067 
1068 	/* Receive second reassembled message */
1069 	rx_skb = skb_recv_datagram(sock->sk, MSG_DONTWAIT, &rc);
1070 	KUNIT_EXPECT_EQ(test, rc, 0);
1071 	KUNIT_EXPECT_EQ(test, rx_skb->len, 2 * data_len);
1072 	rc = skb_copy_bits(rx_skb, 0, flat, rx_skb->len);
1073 	for (int i = 0; i < rx_skb->len; i++)
1074 		compare[i] = (i / data_len + 3) * 0x11;
1075 	/* Set type byte */
1076 	compare[0] = 0;
1077 
1078 	KUNIT_EXPECT_MEMEQ(test, flat, compare, rx_skb->len);
1079 	KUNIT_EXPECT_EQ(test, refcount_read(&rx_skb->users), 1);
1080 	kfree_skb(rx_skb);
1081 
1082 	/* Check input skb refcounts */
1083 	for (int i = 0; i < 5; i++) {
1084 		KUNIT_EXPECT_EQ(test, refcount_read(&skb[i]->users), 1);
1085 		kfree_skb(skb[i]);
1086 	}
1087 
1088 	__mctp_route_test_fini(test, dev, &dst, &tpq, sock);
1089 }
1090 
1091 #if IS_ENABLED(CONFIG_MCTP_FLOWS)
1092 
1093 static void mctp_test_flow_init(struct kunit *test,
1094 				struct mctp_test_dev **devp,
1095 				struct mctp_dst *dst,
1096 				struct mctp_test_pktqueue *tpq,
1097 				struct socket **sock,
1098 				struct sk_buff **skbp,
1099 				unsigned int len)
1100 {
1101 	struct mctp_test_dev *dev;
1102 	struct sk_buff *skb;
1103 
1104 	/* we have a slightly odd routing setup here; the test route
1105 	 * is for EID 8, which is our local EID. We don't do a routing
1106 	 * lookup, so that's fine - all we require is a path through
1107 	 * mctp_local_output, which will call dst->output on whatever
1108 	 * route we provide
1109 	 */
1110 	__mctp_route_test_init(test, &dev, dst, tpq, sock, MCTP_NET_ANY);
1111 
1112 	/* Assign a single EID. ->addrs is freed on mctp netdev release */
1113 	dev->mdev->addrs = kmalloc(sizeof(u8), GFP_KERNEL);
1114 	dev->mdev->num_addrs = 1;
1115 	dev->mdev->addrs[0] = 8;
1116 
1117 	skb = alloc_skb(len + sizeof(struct mctp_hdr) + 1, GFP_KERNEL);
1118 	KUNIT_ASSERT_TRUE(test, skb);
1119 	__mctp_cb(skb);
1120 	skb_reserve(skb, sizeof(struct mctp_hdr) + 1);
1121 	memset(skb_put(skb, len), 0, len);
1122 
1123 
1124 	*devp = dev;
1125 	*skbp = skb;
1126 }
1127 
/* Undo mctp_test_flow_init(). The flow tests keep no state of their own
 * beyond the common route-test fixtures, so this simply hands everything to
 * __mctp_route_test_fini(). The tx skb is not passed here: by the time the
 * test finishes, it has been handed to mctp_local_output().
 */
static void mctp_test_flow_fini(struct kunit *test,
				struct mctp_test_dev *dev,
				struct mctp_dst *dst,
				struct mctp_test_pktqueue *tpq,
				struct socket *sock)
{
	__mctp_route_test_fini(test, dev, dst, tpq, sock);
}
1136 
1137 /* test that an outgoing skb has the correct MCTP extension data set */
1138 static void mctp_test_packet_flow(struct kunit *test)
1139 {
1140 	struct mctp_test_pktqueue tpq;
1141 	struct sk_buff *skb, *skb2;
1142 	struct mctp_test_dev *dev;
1143 	struct mctp_dst dst;
1144 	struct mctp_flow *flow;
1145 	struct socket *sock;
1146 	u8 dst_eid = 8;
1147 	int n, rc;
1148 
1149 	mctp_test_flow_init(test, &dev, &dst, &tpq, &sock, &skb, 30);
1150 
1151 	rc = mctp_local_output(sock->sk, &dst, skb, dst_eid, MCTP_TAG_OWNER);
1152 	KUNIT_ASSERT_EQ(test, rc, 0);
1153 
1154 	n = tpq.pkts.qlen;
1155 	KUNIT_ASSERT_EQ(test, n, 1);
1156 
1157 	skb2 = skb_dequeue(&tpq.pkts);
1158 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb2);
1159 
1160 	flow = skb_ext_find(skb2, SKB_EXT_MCTP);
1161 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, flow);
1162 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, flow->key);
1163 	KUNIT_ASSERT_PTR_EQ(test, flow->key->sk, sock->sk);
1164 
1165 	kfree_skb(skb2);
1166 	mctp_test_flow_fini(test, dev, &dst, &tpq, sock);
1167 }
1168 
1169 /* test that outgoing skbs, after fragmentation, all have the correct MCTP
1170  * extension data set.
1171  */
1172 static void mctp_test_fragment_flow(struct kunit *test)
1173 {
1174 	struct mctp_test_pktqueue tpq;
1175 	struct mctp_flow *flows[2];
1176 	struct sk_buff *tx_skbs[2];
1177 	struct mctp_test_dev *dev;
1178 	struct mctp_dst dst;
1179 	struct sk_buff *skb;
1180 	struct socket *sock;
1181 	u8 dst_eid = 8;
1182 	int n, rc;
1183 
1184 	mctp_test_flow_init(test, &dev, &dst, &tpq, &sock, &skb, 100);
1185 
1186 	rc = mctp_local_output(sock->sk, &dst, skb, dst_eid, MCTP_TAG_OWNER);
1187 	KUNIT_ASSERT_EQ(test, rc, 0);
1188 
1189 	n = tpq.pkts.qlen;
1190 	KUNIT_ASSERT_EQ(test, n, 2);
1191 
1192 	/* both resulting packets should have the same flow data */
1193 	tx_skbs[0] = skb_dequeue(&tpq.pkts);
1194 	tx_skbs[1] = skb_dequeue(&tpq.pkts);
1195 
1196 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, tx_skbs[0]);
1197 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, tx_skbs[1]);
1198 
1199 	flows[0] = skb_ext_find(tx_skbs[0], SKB_EXT_MCTP);
1200 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, flows[0]);
1201 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, flows[0]->key);
1202 	KUNIT_ASSERT_PTR_EQ(test, flows[0]->key->sk, sock->sk);
1203 
1204 	flows[1] = skb_ext_find(tx_skbs[1], SKB_EXT_MCTP);
1205 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, flows[1]);
1206 	KUNIT_ASSERT_PTR_EQ(test, flows[1]->key, flows[0]->key);
1207 
1208 	kfree_skb(tx_skbs[0]);
1209 	kfree_skb(tx_skbs[1]);
1210 	mctp_test_flow_fini(test, dev, &dst, &tpq, sock);
1211 }
1212 
1213 #else
/* With CONFIG_MCTP_FLOWS disabled the real flow tests are compiled out;
 * both test entry points report themselves as skipped instead.
 */
static void mctp_test_flow_skip(struct kunit *test)
{
	kunit_skip(test, "Requires CONFIG_MCTP_FLOWS=y");
}

static void mctp_test_packet_flow(struct kunit *test)
{
	mctp_test_flow_skip(test);
}

static void mctp_test_fragment_flow(struct kunit *test)
{
	mctp_test_flow_skip(test);
}
1223 #endif
1224 
1225 /* Test that outgoing skbs cause a suitable tag to be created */
1226 static void mctp_test_route_output_key_create(struct kunit *test)
1227 {
1228 	const u8 dst_eid = 26, src_eid = 15;
1229 	struct mctp_test_pktqueue tpq;
1230 	const unsigned int netid = 50;
1231 	struct mctp_test_dev *dev;
1232 	struct mctp_sk_key *key;
1233 	struct netns_mctp *mns;
1234 	unsigned long flags;
1235 	struct socket *sock;
1236 	struct sk_buff *skb;
1237 	struct mctp_dst dst;
1238 	bool empty, single;
1239 	const int len = 2;
1240 	int rc;
1241 
1242 	dev = mctp_test_create_dev();
1243 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, dev);
1244 	WRITE_ONCE(dev->mdev->net, netid);
1245 
1246 	mctp_test_dst_setup(test, &dst, dev, &tpq, 68);
1247 
1248 	rc = sock_create_kern(&init_net, AF_MCTP, SOCK_DGRAM, 0, &sock);
1249 	KUNIT_ASSERT_EQ(test, rc, 0);
1250 
1251 	dev->mdev->addrs = kmalloc(sizeof(u8), GFP_KERNEL);
1252 	dev->mdev->num_addrs = 1;
1253 	dev->mdev->addrs[0] = src_eid;
1254 
1255 	skb = alloc_skb(sizeof(struct mctp_hdr) + 1 + len, GFP_KERNEL);
1256 	KUNIT_ASSERT_TRUE(test, skb);
1257 	__mctp_cb(skb);
1258 	skb_reserve(skb, sizeof(struct mctp_hdr) + 1 + len);
1259 	memset(skb_put(skb, len), 0, len);
1260 
1261 	mns = &sock_net(sock->sk)->mctp;
1262 
1263 	/* We assume we're starting from an empty keys list, which requires
1264 	 * preceding tests to clean up correctly!
1265 	 */
1266 	spin_lock_irqsave(&mns->keys_lock, flags);
1267 	empty = hlist_empty(&mns->keys);
1268 	spin_unlock_irqrestore(&mns->keys_lock, flags);
1269 	KUNIT_ASSERT_TRUE(test, empty);
1270 
1271 	rc = mctp_local_output(sock->sk, &dst, skb, dst_eid, MCTP_TAG_OWNER);
1272 	KUNIT_ASSERT_EQ(test, rc, 0);
1273 
1274 	key = NULL;
1275 	single = false;
1276 	spin_lock_irqsave(&mns->keys_lock, flags);
1277 	if (!hlist_empty(&mns->keys)) {
1278 		key = hlist_entry(mns->keys.first, struct mctp_sk_key, hlist);
1279 		single = hlist_is_singular_node(&key->hlist, &mns->keys);
1280 	}
1281 	spin_unlock_irqrestore(&mns->keys_lock, flags);
1282 
1283 	KUNIT_ASSERT_NOT_NULL(test, key);
1284 	KUNIT_ASSERT_TRUE(test, single);
1285 
1286 	KUNIT_EXPECT_EQ(test, key->net, netid);
1287 	KUNIT_EXPECT_EQ(test, key->local_addr, src_eid);
1288 	KUNIT_EXPECT_EQ(test, key->peer_addr, dst_eid);
1289 	/* key has incoming tag, so inverse of what we sent */
1290 	KUNIT_EXPECT_FALSE(test, key->tag & MCTP_TAG_OWNER);
1291 
1292 	sock_release(sock);
1293 	mctp_test_dst_release(&dst, &tpq);
1294 	mctp_test_destroy_dev(dev);
1295 }
1296 
1297 static void mctp_test_route_extaddr_input(struct kunit *test)
1298 {
1299 	static const unsigned char haddr[] = { 0xaa, 0x55 };
1300 	struct mctp_test_pktqueue tpq;
1301 	struct mctp_skb_cb *cb, *cb2;
1302 	const unsigned int len = 40;
1303 	struct mctp_test_dev *dev;
1304 	struct sk_buff *skb, *skb2;
1305 	struct mctp_dst dst;
1306 	struct mctp_hdr hdr;
1307 	struct socket *sock;
1308 	int rc;
1309 
1310 	hdr.ver = 1;
1311 	hdr.src = 10;
1312 	hdr.dest = 8;
1313 	hdr.flags_seq_tag = FL_S | FL_E | FL_TO;
1314 
1315 	__mctp_route_test_init(test, &dev, &dst, &tpq, &sock, MCTP_NET_ANY);
1316 
1317 	skb = mctp_test_create_skb(&hdr, len);
1318 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb);
1319 
1320 	/* set our hardware addressing data */
1321 	cb = mctp_cb(skb);
1322 	memcpy(cb->haddr, haddr, sizeof(haddr));
1323 	cb->halen = sizeof(haddr);
1324 
1325 	mctp_test_skb_set_dev(skb, dev);
1326 
1327 	rc = mctp_dst_input(&dst, skb);
1328 	KUNIT_ASSERT_EQ(test, rc, 0);
1329 
1330 	mctp_test_dst_release(&dst, &tpq);
1331 
1332 	skb2 = skb_recv_datagram(sock->sk, MSG_DONTWAIT, &rc);
1333 	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb2);
1334 	KUNIT_ASSERT_EQ(test, skb2->len, len);
1335 
1336 	cb2 = mctp_cb(skb2);
1337 
1338 	/* Received SKB should have the hardware addressing as set above.
1339 	 * We're likely to have the same actual cb here (ie., cb == cb2),
1340 	 * but it's the comparison that we care about
1341 	 */
1342 	KUNIT_EXPECT_EQ(test, cb2->halen, sizeof(haddr));
1343 	KUNIT_EXPECT_MEMEQ(test, cb2->haddr, haddr, sizeof(haddr));
1344 
1345 	skb_free_datagram(sock->sk, skb2);
1346 	mctp_test_destroy_dev(dev);
1347 }
1348 
1349 static struct kunit_case mctp_test_cases[] = {
1350 	KUNIT_CASE_PARAM(mctp_test_fragment, mctp_frag_gen_params),
1351 	KUNIT_CASE_PARAM(mctp_test_rx_input, mctp_rx_input_gen_params),
1352 	KUNIT_CASE_PARAM(mctp_test_route_input_sk, mctp_route_input_sk_gen_params),
1353 	KUNIT_CASE_PARAM(mctp_test_route_input_sk_reasm,
1354 			 mctp_route_input_sk_reasm_gen_params),
1355 	KUNIT_CASE_PARAM(mctp_test_route_input_sk_keys,
1356 			 mctp_route_input_sk_keys_gen_params),
1357 	KUNIT_CASE(mctp_test_route_input_sk_fail_single),
1358 	KUNIT_CASE(mctp_test_route_input_sk_fail_frag),
1359 	KUNIT_CASE(mctp_test_route_input_multiple_nets_bind),
1360 	KUNIT_CASE(mctp_test_route_input_multiple_nets_key),
1361 	KUNIT_CASE(mctp_test_packet_flow),
1362 	KUNIT_CASE(mctp_test_fragment_flow),
1363 	KUNIT_CASE(mctp_test_route_output_key_create),
1364 	KUNIT_CASE(mctp_test_route_input_cloned_frag),
1365 	KUNIT_CASE(mctp_test_route_extaddr_input),
1366 	{}
1367 };
1368 
1369 static struct kunit_suite mctp_test_suite = {
1370 	.name = "mctp",
1371 	.test_cases = mctp_test_cases,
1372 };
1373 
1374 kunit_test_suite(mctp_test_suite);
1375