xref: /linux/samples/bpf/hbm.c (revision 3f0a50f345f78183f6e9b39c2f45ca5dcaa511ca)
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2019 Facebook
3  *
4  * This program is free software; you can redistribute it and/or
5  * modify it under the terms of version 2 of the GNU General Public
6  * License as published by the Free Software Foundation.
7  *
8  * Example program for Host Bandwidth Management
9  *
10  * This program loads a cgroup skb BPF program to enforce cgroup output
11  * (egress) or input (ingress) bandwidth limits.
12  *
13  * USAGE: hbm [-d] [-l] [-n <id>] [-r <rate>] [-s] [-t <secs>] [-w] [-h] [prog]
14  *   Where:
15  *    -d	Print BPF trace debug buffer
16  *    -l	Also limit flows doing loopback
17  *    -n <#>	To create cgroup \"/hbm#\" and attach prog
18  *		Default is /hbm1
19  *    --no_cn   Do not return cn notifications
20  *    -r <rate>	Rate limit in Mbps
21  *    -s	Get HBM stats (marked, dropped, etc.)
22  *    -t <time>	Exit after specified seconds (default is 1)
23  *    -w	Work conserving flag. cgroup can increase its bandwidth
24  *		beyond the rate limit specified while there is available
25  *		bandwidth. Current implementation assumes there is only
26  *		one NIC (eth0), but can be extended to support multiple NICs.
27  *		Currently only supported for egress.
28  *    -h	Print this info
29  *    prog	BPF program file name. Name defaults to hbm_out_kern.o
30  */
31 
32 #define _GNU_SOURCE
33 
#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include <limits.h>
#include <sys/resource.h>
#include <sys/time.h>
#include <unistd.h>
#include <errno.h>
#include <fcntl.h>
#include <linux/unistd.h>
#include <linux/compiler.h>

#include <linux/bpf.h>
#include <bpf/bpf.h>
#include <getopt.h>

#include "bpf_rlimit.h"
#include "cgroup_helpers.h"
#include "hbm.h"
#include "bpf_util.h"
#include <bpf/libbpf.h>
54 
/* Location of the tracing files; "trace_pipe" is appended by the reader. */
#define DEBUGFS "/sys/kernel/debug/tracing/"

/* Helpers defined further down in this file. */
static void Usage(void);
static void read_trace_pipe2(void);
static void do_error(char *msg, bool errno_flag);

/*
 * Rate and timing settings. The -r option overwrites minRate and rate,
 * the -t option overwrites dur.
 */
int minRate = 1000;		/* cgroup rate limit in Mbps */
int rate = 1000;		/* can grow if rate conserving is enabled */
int dur = 1;			/* run time in seconds */

/* Boolean switches, set from the command line in main(). */
bool outFlag = true;		/* false selects the ingress attach type */
bool stats_flag;		/* -s */
bool loopback_flag;		/* -l */
bool debugFlag;			/* -d */
bool work_conserving_flag;	/* -w */
bool no_cn_flag;		/* --no_cn */
bool edt_flag;			/* --edt */

/* Handles filled in by prog_load() and used by run_bpf_prog(). */
static struct bpf_object *obj;
static struct bpf_program *bpf_prog;
static int queue_stats_fd;
75 
76 static void read_trace_pipe2(void)
77 {
78 	int trace_fd;
79 	FILE *outf;
80 	char *outFname = "hbm_out.log";
81 
82 	trace_fd = open(DEBUGFS "trace_pipe", O_RDONLY, 0);
83 	if (trace_fd < 0) {
84 		printf("Error opening trace_pipe\n");
85 		return;
86 	}
87 
88 //	Future support of ingress
89 //	if (!outFlag)
90 //		outFname = "hbm_in.log";
91 	outf = fopen(outFname, "w");
92 
93 	if (outf == NULL)
94 		printf("Error creating %s\n", outFname);
95 
96 	while (1) {
97 		static char buf[4097];
98 		ssize_t sz;
99 
100 		sz = read(trace_fd, buf, sizeof(buf) - 1);
101 		if (sz > 0) {
102 			buf[sz] = 0;
103 			puts(buf);
104 			if (outf != NULL) {
105 				fprintf(outf, "%s\n", buf);
106 				fflush(outf);
107 			}
108 		}
109 	}
110 }
111 
/*
 * Print a fatal error message on stdout and terminate with exit status 1.
 *
 * msg:        text printed after the "ERROR: " prefix
 * errno_flag: when true, the current numeric errno value is appended
 *
 * This function does not return.
 */
static void do_error(char *msg, bool errno_flag)
{
	if (!errno_flag) {
		printf("ERROR: %s\n", msg);
		exit(1);
	}

	printf("ERROR: %s, errno: %d\n", msg, errno);
	exit(1);
}
120 
121 static int prog_load(char *prog)
122 {
123 	struct bpf_program *pos;
124 	const char *sec_name;
125 
126 	obj = bpf_object__open_file(prog, NULL);
127 	if (libbpf_get_error(obj)) {
128 		printf("ERROR: opening BPF object file failed\n");
129 		return 1;
130 	}
131 
132 	/* load BPF program */
133 	if (bpf_object__load(obj)) {
134 		printf("ERROR: loading BPF object file failed\n");
135 		goto err;
136 	}
137 
138 	bpf_object__for_each_program(pos, obj) {
139 		sec_name = bpf_program__section_name(pos);
140 		if (sec_name && !strcmp(sec_name, "cgroup_skb/egress")) {
141 			bpf_prog = pos;
142 			break;
143 		}
144 	}
145 	if (!bpf_prog) {
146 		printf("ERROR: finding a prog in obj file failed\n");
147 		goto err;
148 	}
149 
150 	queue_stats_fd = bpf_object__find_map_fd_by_name(obj, "queue_stats");
151 	if (queue_stats_fd < 0) {
152 		printf("ERROR: finding a map in obj file failed\n");
153 		goto err;
154 	}
155 
156 	return 0;
157 
158 err:
159 	bpf_object__close(obj);
160 	return 1;
161 }
162 
163 static int run_bpf_prog(char *prog, int cg_id)
164 {
165 	struct hbm_queue_stats qstats = {0};
166 	char cg_dir[100], cg_pin_path[100];
167 	struct bpf_link *link = NULL;
168 	int key = 0;
169 	int cg1 = 0;
170 	int rc = 0;
171 
172 	sprintf(cg_dir, "/hbm%d", cg_id);
173 	rc = prog_load(prog);
174 	if (rc != 0)
175 		return rc;
176 
177 	if (setup_cgroup_environment()) {
178 		printf("ERROR: setting cgroup environment\n");
179 		goto err;
180 	}
181 	cg1 = create_and_get_cgroup(cg_dir);
182 	if (!cg1) {
183 		printf("ERROR: create_and_get_cgroup\n");
184 		goto err;
185 	}
186 	if (join_cgroup(cg_dir)) {
187 		printf("ERROR: join_cgroup\n");
188 		goto err;
189 	}
190 
191 	qstats.rate = rate;
192 	qstats.stats = stats_flag ? 1 : 0;
193 	qstats.loopback = loopback_flag ? 1 : 0;
194 	qstats.no_cn = no_cn_flag ? 1 : 0;
195 	if (bpf_map_update_elem(queue_stats_fd, &key, &qstats, BPF_ANY)) {
196 		printf("ERROR: Could not update map element\n");
197 		goto err;
198 	}
199 
200 	if (!outFlag)
201 		bpf_program__set_expected_attach_type(bpf_prog, BPF_CGROUP_INET_INGRESS);
202 
203 	link = bpf_program__attach_cgroup(bpf_prog, cg1);
204 	if (libbpf_get_error(link)) {
205 		fprintf(stderr, "ERROR: bpf_program__attach_cgroup failed\n");
206 		goto err;
207 	}
208 
209 	sprintf(cg_pin_path, "/sys/fs/bpf/hbm%d", cg_id);
210 	rc = bpf_link__pin(link, cg_pin_path);
211 	if (rc < 0) {
212 		printf("ERROR: bpf_link__pin failed: %d\n", rc);
213 		goto err;
214 	}
215 
216 	if (work_conserving_flag) {
217 		struct timeval t0, t_last, t_new;
218 		FILE *fin;
219 		unsigned long long last_eth_tx_bytes, new_eth_tx_bytes;
220 		signed long long last_cg_tx_bytes, new_cg_tx_bytes;
221 		signed long long delta_time, delta_bytes, delta_rate;
222 		int delta_ms;
223 #define DELTA_RATE_CHECK 10000		/* in us */
224 #define RATE_THRESHOLD 9500000000	/* 9.5 Gbps */
225 
226 		bpf_map_lookup_elem(queue_stats_fd, &key, &qstats);
227 		if (gettimeofday(&t0, NULL) < 0)
228 			do_error("gettimeofday failed", true);
229 		t_last = t0;
230 		fin = fopen("/sys/class/net/eth0/statistics/tx_bytes", "r");
231 		if (fscanf(fin, "%llu", &last_eth_tx_bytes) != 1)
232 			do_error("fscanf fails", false);
233 		fclose(fin);
234 		last_cg_tx_bytes = qstats.bytes_total;
235 		while (true) {
236 			usleep(DELTA_RATE_CHECK);
237 			if (gettimeofday(&t_new, NULL) < 0)
238 				do_error("gettimeofday failed", true);
239 			delta_ms = (t_new.tv_sec - t0.tv_sec) * 1000 +
240 				(t_new.tv_usec - t0.tv_usec)/1000;
241 			if (delta_ms > dur * 1000)
242 				break;
243 			delta_time = (t_new.tv_sec - t_last.tv_sec) * 1000000 +
244 				(t_new.tv_usec - t_last.tv_usec);
245 			if (delta_time == 0)
246 				continue;
247 			t_last = t_new;
248 			fin = fopen("/sys/class/net/eth0/statistics/tx_bytes",
249 				    "r");
250 			if (fscanf(fin, "%llu", &new_eth_tx_bytes) != 1)
251 				do_error("fscanf fails", false);
252 			fclose(fin);
253 			printf("  new_eth_tx_bytes:%llu\n",
254 			       new_eth_tx_bytes);
255 			bpf_map_lookup_elem(queue_stats_fd, &key, &qstats);
256 			new_cg_tx_bytes = qstats.bytes_total;
257 			delta_bytes = new_eth_tx_bytes - last_eth_tx_bytes;
258 			last_eth_tx_bytes = new_eth_tx_bytes;
259 			delta_rate = (delta_bytes * 8000000) / delta_time;
260 			printf("%5d - eth_rate:%.1fGbps cg_rate:%.3fGbps",
261 			       delta_ms, delta_rate/1000000000.0,
262 			       rate/1000.0);
263 			if (delta_rate < RATE_THRESHOLD) {
264 				/* can increase cgroup rate limit, but first
265 				 * check if we are using the current limit.
266 				 * Currently increasing by 6.25%, unknown
267 				 * if that is the optimal rate.
268 				 */
269 				int rate_diff100;
270 
271 				delta_bytes = new_cg_tx_bytes -
272 					last_cg_tx_bytes;
273 				last_cg_tx_bytes = new_cg_tx_bytes;
274 				delta_rate = (delta_bytes * 8000000) /
275 					delta_time;
276 				printf(" rate:%.3fGbps",
277 				       delta_rate/1000000000.0);
278 				rate_diff100 = (((long long)rate)*1000000 -
279 						     delta_rate) * 100 /
280 					(((long long) rate) * 1000000);
281 				printf("  rdiff:%d", rate_diff100);
282 				if (rate_diff100  <= 3) {
283 					rate += (rate >> 4);
284 					if (rate > RATE_THRESHOLD / 1000000)
285 						rate = RATE_THRESHOLD / 1000000;
286 					qstats.rate = rate;
287 					printf(" INC\n");
288 				} else {
289 					printf("\n");
290 				}
291 			} else {
292 				/* Need to decrease cgroup rate limit.
293 				 * Currently decreasing by 12.5%, unknown
294 				 * if that is optimal
295 				 */
296 				printf(" DEC\n");
297 				rate -= (rate >> 3);
298 				if (rate < minRate)
299 					rate = minRate;
300 				qstats.rate = rate;
301 			}
302 			if (bpf_map_update_elem(queue_stats_fd, &key, &qstats, BPF_ANY))
303 				do_error("update map element fails", false);
304 		}
305 	} else {
306 		sleep(dur);
307 	}
308 	// Get stats!
309 	if (stats_flag && bpf_map_lookup_elem(queue_stats_fd, &key, &qstats)) {
310 		char fname[100];
311 		FILE *fout;
312 
313 		if (!outFlag)
314 			sprintf(fname, "hbm.%d.in", cg_id);
315 		else
316 			sprintf(fname, "hbm.%d.out", cg_id);
317 		fout = fopen(fname, "w");
318 		fprintf(fout, "id:%d\n", cg_id);
319 		fprintf(fout, "ERROR: Could not lookup queue_stats\n");
320 	} else if (stats_flag && qstats.lastPacketTime >
321 		   qstats.firstPacketTime) {
322 		long long delta_us = (qstats.lastPacketTime -
323 				      qstats.firstPacketTime)/1000;
324 		unsigned int rate_mbps = ((qstats.bytes_total -
325 					   qstats.bytes_dropped) * 8 /
326 					  delta_us);
327 		double percent_pkts, percent_bytes;
328 		char fname[100];
329 		FILE *fout;
330 		int k;
331 		static const char *returnValNames[] = {
332 			"DROP_PKT",
333 			"ALLOW_PKT",
334 			"DROP_PKT_CWR",
335 			"ALLOW_PKT_CWR"
336 		};
337 #define RET_VAL_COUNT 4
338 
339 // Future support of ingress
340 //		if (!outFlag)
341 //			sprintf(fname, "hbm.%d.in", cg_id);
342 //		else
343 		sprintf(fname, "hbm.%d.out", cg_id);
344 		fout = fopen(fname, "w");
345 		fprintf(fout, "id:%d\n", cg_id);
346 		fprintf(fout, "rate_mbps:%d\n", rate_mbps);
347 		fprintf(fout, "duration:%.1f secs\n",
348 			(qstats.lastPacketTime - qstats.firstPacketTime) /
349 			1000000000.0);
350 		fprintf(fout, "packets:%d\n", (int)qstats.pkts_total);
351 		fprintf(fout, "bytes_MB:%d\n", (int)(qstats.bytes_total /
352 						     1000000));
353 		fprintf(fout, "pkts_dropped:%d\n", (int)qstats.pkts_dropped);
354 		fprintf(fout, "bytes_dropped_MB:%d\n",
355 			(int)(qstats.bytes_dropped /
356 						       1000000));
357 		// Marked Pkts and Bytes
358 		percent_pkts = (qstats.pkts_marked * 100.0) /
359 			(qstats.pkts_total + 1);
360 		percent_bytes = (qstats.bytes_marked * 100.0) /
361 			(qstats.bytes_total + 1);
362 		fprintf(fout, "pkts_marked_percent:%6.2f\n", percent_pkts);
363 		fprintf(fout, "bytes_marked_percent:%6.2f\n", percent_bytes);
364 
365 		// Dropped Pkts and Bytes
366 		percent_pkts = (qstats.pkts_dropped * 100.0) /
367 			(qstats.pkts_total + 1);
368 		percent_bytes = (qstats.bytes_dropped * 100.0) /
369 			(qstats.bytes_total + 1);
370 		fprintf(fout, "pkts_dropped_percent:%6.2f\n", percent_pkts);
371 		fprintf(fout, "bytes_dropped_percent:%6.2f\n", percent_bytes);
372 
373 		// ECN CE markings
374 		percent_pkts = (qstats.pkts_ecn_ce * 100.0) /
375 			(qstats.pkts_total + 1);
376 		fprintf(fout, "pkts_ecn_ce:%6.2f (%d)\n", percent_pkts,
377 			(int)qstats.pkts_ecn_ce);
378 
379 		// Average cwnd
380 		fprintf(fout, "avg cwnd:%d\n",
381 			(int)(qstats.sum_cwnd / (qstats.sum_cwnd_cnt + 1)));
382 		// Average rtt
383 		fprintf(fout, "avg rtt:%d\n",
384 			(int)(qstats.sum_rtt / (qstats.pkts_total + 1)));
385 		// Average credit
386 		if (edt_flag)
387 			fprintf(fout, "avg credit_ms:%.03f\n",
388 				(qstats.sum_credit /
389 				 (qstats.pkts_total + 1.0)) / 1000000.0);
390 		else
391 			fprintf(fout, "avg credit:%d\n",
392 				(int)(qstats.sum_credit /
393 				      (1500 * ((int)qstats.pkts_total ) + 1)));
394 
395 		// Return values stats
396 		for (k = 0; k < RET_VAL_COUNT; k++) {
397 			percent_pkts = (qstats.returnValCount[k] * 100.0) /
398 				(qstats.pkts_total + 1);
399 			fprintf(fout, "%s:%6.2f (%d)\n", returnValNames[k],
400 				percent_pkts, (int)qstats.returnValCount[k]);
401 		}
402 		fclose(fout);
403 	}
404 
405 	if (debugFlag)
406 		read_trace_pipe2();
407 	goto cleanup;
408 
409 err:
410 	rc = 1;
411 
412 cleanup:
413 	bpf_link__destroy(link);
414 	bpf_object__close(obj);
415 
416 	if (cg1 != -1)
417 		close(cg1);
418 
419 	if (rc != 0)
420 		cleanup_cgroup_environment();
421 	return rc;
422 }
423 
/*
 * Print the command-line help text to stdout.
 *
 * The text states the real default for -t: the run time variable 'dur'
 * is initialised to 1, not 0.
 */
static void Usage(void)
{
	printf("This program loads a cgroup skb BPF program to enforce\n"
	       "cgroup output (egress) bandwidth limits.\n\n"
	       "USAGE: hbm [-o] [-d]  [-l] [-n <id>] [--no_cn] [-r <rate>]\n"
	       "           [-s] [-t <secs>] [-w] [-h] [prog]\n"
	       "  Where:\n"
	       "    -o         indicates egress direction (default)\n"
	       "    -d         print BPF trace debug buffer\n"
	       "    --edt      use fq's Earliest Departure Time\n"
	       "    -l         also limit flows using loopback\n"
	       "    -n <#>     to create cgroup \"/hbm#\" and attach prog\n"
	       "               Default is /hbm1\n"
	       "    --no_cn    disable CN notifications\n"
	       "    -r <rate>  Rate in Mbps\n"
	       "    -s         Update HBM stats\n"
	       "    -t <time>  Exit after specified seconds (default is 1)\n"
	       "    -w	       Work conserving flag. cgroup can increase\n"
	       "               bandwidth beyond the rate limit specified\n"
	       "               while there is available bandwidth. Current\n"
	       "               implementation assumes there is only eth0\n"
	       "               but can be extended to support multiple NICs\n"
	       "    -h         print this info\n"
	       "    prog       BPF program file name. Name defaults to\n"
	       "                 hbm_out_kern.o\n");
}
450 
451 int main(int argc, char **argv)
452 {
453 	char *prog = "hbm_out_kern.o";
454 	int  k;
455 	int cg_id = 1;
456 	char *optstring = "iodln:r:st:wh";
457 	struct option loptions[] = {
458 		{"no_cn", 0, NULL, 1},
459 		{"edt", 0, NULL, 2},
460 		{NULL, 0, NULL, 0}
461 	};
462 
463 	while ((k = getopt_long(argc, argv, optstring, loptions, NULL)) != -1) {
464 		switch (k) {
465 		case 1:
466 			no_cn_flag = true;
467 			break;
468 		case 2:
469 			prog = "hbm_edt_kern.o";
470 			edt_flag = true;
471 			break;
472 		case'o':
473 			break;
474 		case 'd':
475 			debugFlag = true;
476 			break;
477 		case 'l':
478 			loopback_flag = true;
479 			break;
480 		case 'n':
481 			cg_id = atoi(optarg);
482 			break;
483 		case 'r':
484 			minRate = atoi(optarg) * 1.024;
485 			rate = minRate;
486 			break;
487 		case 's':
488 			stats_flag = true;
489 			break;
490 		case 't':
491 			dur = atoi(optarg);
492 			break;
493 		case 'w':
494 			work_conserving_flag = true;
495 			break;
496 		case '?':
497 			if (optopt == 'n' || optopt == 'r' || optopt == 't')
498 				fprintf(stderr,
499 					"Option -%c requires an argument.\n\n",
500 					optopt);
501 		case 'h':
502 			__fallthrough;
503 		default:
504 			Usage();
505 			return 0;
506 		}
507 	}
508 
509 	if (optind < argc)
510 		prog = argv[optind];
511 	printf("HBM prog: %s\n", prog != NULL ? prog : "NULL");
512 
513 	return run_bpf_prog(prog, cg_id);
514 }
515