xref: /linux/net/mac80211/rate.c (revision 9208c05f9fdfd927ea160b97dfef3c379049fff2)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright 2002-2005, Instant802 Networks, Inc.
4  * Copyright 2005-2006, Devicescape Software, Inc.
5  * Copyright (c) 2006 Jiri Benc <jbenc@suse.cz>
6  * Copyright 2017	Intel Deutschland GmbH
7  * Copyright (C) 2019, 2022-2024 Intel Corporation
8  */
9 
10 #include <linux/kernel.h>
11 #include <linux/rtnetlink.h>
12 #include <linux/module.h>
13 #include <linux/slab.h>
14 #include "rate.h"
15 #include "ieee80211_i.h"
16 #include "debugfs.h"
17 
/*
 * One registered rate control algorithm. Instances are allocated by
 * ieee80211_rate_control_register() and linked on rate_ctrl_algs.
 */
struct rate_control_alg {
	struct list_head list;
	const struct rate_control_ops *ops;
};

/* All registered algorithms; every access to the list holds rate_ctrl_mutex. */
static LIST_HEAD(rate_ctrl_algs);
static DEFINE_MUTEX(rate_ctrl_mutex);

/*
 * Name of the algorithm used when none is requested, or when the requested
 * one is not registered. It is writable at runtime through the module
 * parameter. ieee80211_rate_control_ops_get() reads it under
 * kernel_param_lock().
 */
static char *ieee80211_default_rc_algo = CONFIG_MAC80211_RC_DEFAULT;
module_param(ieee80211_default_rc_algo, charp, 0644);
MODULE_PARM_DESC(ieee80211_default_rc_algo,
		 "Default rate control algorithm for mac80211 to use");
30 
31 void rate_control_rate_init(struct link_sta_info *link_sta)
32 {
33 	struct sta_info *sta = link_sta->sta;
34 	struct ieee80211_local *local = sta->sdata->local;
35 	struct rate_control_ref *ref = sta->rate_ctrl;
36 	struct ieee80211_sta *ista = &sta->sta;
37 	void *priv_sta = sta->rate_ctrl_priv;
38 	struct ieee80211_supported_band *sband;
39 	struct ieee80211_chanctx_conf *chanctx_conf;
40 
41 	ieee80211_sta_init_nss(link_sta);
42 
43 	if (!ref)
44 		return;
45 
46 	/* SW rate control isn't supported with MLO right now */
47 	if (WARN_ON(ieee80211_vif_is_mld(&sta->sdata->vif)))
48 		return;
49 
50 	rcu_read_lock();
51 
52 	chanctx_conf = rcu_dereference(sta->sdata->vif.bss_conf.chanctx_conf);
53 	if (WARN_ON(!chanctx_conf)) {
54 		rcu_read_unlock();
55 		return;
56 	}
57 
58 	sband = local->hw.wiphy->bands[chanctx_conf->def.chan->band];
59 
60 	/* TODO: check for minstrel_s1g ? */
61 	if (sband->band == NL80211_BAND_S1GHZ) {
62 		ieee80211_s1g_sta_rate_init(sta);
63 		rcu_read_unlock();
64 		return;
65 	}
66 
67 	spin_lock_bh(&sta->rate_ctrl_lock);
68 	ref->ops->rate_init(ref->priv, sband, &chanctx_conf->def, ista,
69 			    priv_sta);
70 	spin_unlock_bh(&sta->rate_ctrl_lock);
71 	rcu_read_unlock();
72 	set_sta_flag(sta, WLAN_STA_RATE_CONTROL);
73 }
74 
75 void rate_control_rate_init_all_links(struct sta_info *sta)
76 {
77 	int link_id;
78 
79 	for (link_id = 0; link_id < ARRAY_SIZE(sta->link); link_id++) {
80 		struct link_sta_info *link_sta;
81 
82 		link_sta = sdata_dereference(sta->link[link_id], sta->sdata);
83 		if (!link_sta)
84 			continue;
85 
86 		rate_control_rate_init(link_sta);
87 	}
88 }
89 
90 void rate_control_tx_status(struct ieee80211_local *local,
91 			    struct ieee80211_tx_status *st)
92 {
93 	struct rate_control_ref *ref = local->rate_ctrl;
94 	struct sta_info *sta = container_of(st->sta, struct sta_info, sta);
95 	void *priv_sta = sta->rate_ctrl_priv;
96 	struct ieee80211_supported_band *sband;
97 
98 	if (!ref || !test_sta_flag(sta, WLAN_STA_RATE_CONTROL))
99 		return;
100 
101 	sband = local->hw.wiphy->bands[st->info->band];
102 
103 	spin_lock_bh(&sta->rate_ctrl_lock);
104 	if (ref->ops->tx_status_ext)
105 		ref->ops->tx_status_ext(ref->priv, sband, priv_sta, st);
106 	else if (st->skb)
107 		ref->ops->tx_status(ref->priv, sband, st->sta, priv_sta, st->skb);
108 	else
109 		WARN_ON_ONCE(1);
110 
111 	spin_unlock_bh(&sta->rate_ctrl_lock);
112 }
113 
114 void rate_control_rate_update(struct ieee80211_local *local,
115 			      struct ieee80211_supported_band *sband,
116 			      struct link_sta_info *link_sta,
117 			      u32 changed)
118 {
119 	struct rate_control_ref *ref = local->rate_ctrl;
120 	struct sta_info *sta = link_sta->sta;
121 	struct ieee80211_sta *ista = &sta->sta;
122 	void *priv_sta = sta->rate_ctrl_priv;
123 	struct ieee80211_chanctx_conf *chanctx_conf;
124 
125 	if (ref && ref->ops->rate_update) {
126 		rcu_read_lock();
127 
128 		chanctx_conf = rcu_dereference(sta->sdata->vif.bss_conf.chanctx_conf);
129 		if (WARN_ON(!chanctx_conf)) {
130 			rcu_read_unlock();
131 			return;
132 		}
133 
134 		spin_lock_bh(&sta->rate_ctrl_lock);
135 		ref->ops->rate_update(ref->priv, sband, &chanctx_conf->def,
136 				      ista, priv_sta, changed);
137 		spin_unlock_bh(&sta->rate_ctrl_lock);
138 		rcu_read_unlock();
139 	}
140 
141 	if (sta->uploaded)
142 		drv_link_sta_rc_update(local, sta->sdata, link_sta->pub,
143 				       changed);
144 }
145 
146 int ieee80211_rate_control_register(const struct rate_control_ops *ops)
147 {
148 	struct rate_control_alg *alg;
149 
150 	if (!ops->name)
151 		return -EINVAL;
152 
153 	mutex_lock(&rate_ctrl_mutex);
154 	list_for_each_entry(alg, &rate_ctrl_algs, list) {
155 		if (!strcmp(alg->ops->name, ops->name)) {
156 			/* don't register an algorithm twice */
157 			WARN_ON(1);
158 			mutex_unlock(&rate_ctrl_mutex);
159 			return -EALREADY;
160 		}
161 	}
162 
163 	alg = kzalloc(sizeof(*alg), GFP_KERNEL);
164 	if (alg == NULL) {
165 		mutex_unlock(&rate_ctrl_mutex);
166 		return -ENOMEM;
167 	}
168 	alg->ops = ops;
169 
170 	list_add_tail(&alg->list, &rate_ctrl_algs);
171 	mutex_unlock(&rate_ctrl_mutex);
172 
173 	return 0;
174 }
175 EXPORT_SYMBOL(ieee80211_rate_control_register);
176 
177 void ieee80211_rate_control_unregister(const struct rate_control_ops *ops)
178 {
179 	struct rate_control_alg *alg;
180 
181 	mutex_lock(&rate_ctrl_mutex);
182 	list_for_each_entry(alg, &rate_ctrl_algs, list) {
183 		if (alg->ops == ops) {
184 			list_del(&alg->list);
185 			kfree(alg);
186 			break;
187 		}
188 	}
189 	mutex_unlock(&rate_ctrl_mutex);
190 }
191 EXPORT_SYMBOL(ieee80211_rate_control_unregister);
192 
193 static const struct rate_control_ops *
194 ieee80211_try_rate_control_ops_get(const char *name)
195 {
196 	struct rate_control_alg *alg;
197 	const struct rate_control_ops *ops = NULL;
198 
199 	if (!name)
200 		return NULL;
201 
202 	mutex_lock(&rate_ctrl_mutex);
203 	list_for_each_entry(alg, &rate_ctrl_algs, list) {
204 		if (!strcmp(alg->ops->name, name)) {
205 			ops = alg->ops;
206 			break;
207 		}
208 	}
209 	mutex_unlock(&rate_ctrl_mutex);
210 	return ops;
211 }
212 
213 /* Get the rate control algorithm. */
214 static const struct rate_control_ops *
215 ieee80211_rate_control_ops_get(const char *name)
216 {
217 	const struct rate_control_ops *ops;
218 	const char *alg_name;
219 
220 	kernel_param_lock(THIS_MODULE);
221 	if (!name)
222 		alg_name = ieee80211_default_rc_algo;
223 	else
224 		alg_name = name;
225 
226 	ops = ieee80211_try_rate_control_ops_get(alg_name);
227 	if (!ops && name)
228 		/* try default if specific alg requested but not found */
229 		ops = ieee80211_try_rate_control_ops_get(ieee80211_default_rc_algo);
230 
231 	/* Note: check for > 0 is intentional to avoid clang warning */
232 	if (!ops && (strlen(CONFIG_MAC80211_RC_DEFAULT) > 0))
233 		/* try built-in one if specific alg requested but not found */
234 		ops = ieee80211_try_rate_control_ops_get(CONFIG_MAC80211_RC_DEFAULT);
235 
236 	kernel_param_unlock(THIS_MODULE);
237 
238 	return ops;
239 }
240 
#ifdef CONFIG_MAC80211_DEBUGFS
/* debugfs read handler: returns the name of the active algorithm. */
static ssize_t rcname_read(struct file *file, char __user *userbuf,
			   size_t count, loff_t *ppos)
{
	struct rate_control_ref *ref = file->private_data;
	const char *alg_name = ref->ops->name;

	return simple_read_from_buffer(userbuf, count, ppos, alg_name,
				       strlen(alg_name));
}

const struct debugfs_short_fops rcname_ops = {
	.read = rcname_read,
	.llseek = default_llseek,
};
#endif
257 
258 static struct rate_control_ref *
259 rate_control_alloc(const char *name, struct ieee80211_local *local)
260 {
261 	struct rate_control_ref *ref;
262 
263 	ref = kmalloc(sizeof(struct rate_control_ref), GFP_KERNEL);
264 	if (!ref)
265 		return NULL;
266 	ref->ops = ieee80211_rate_control_ops_get(name);
267 	if (!ref->ops)
268 		goto free;
269 
270 	ref->priv = ref->ops->alloc(&local->hw);
271 	if (!ref->priv)
272 		goto free;
273 	return ref;
274 
275 free:
276 	kfree(ref);
277 	return NULL;
278 }
279 
280 static void rate_control_free(struct ieee80211_local *local,
281 			      struct rate_control_ref *ctrl_ref)
282 {
283 	ctrl_ref->ops->free(ctrl_ref->priv);
284 
285 #ifdef CONFIG_MAC80211_DEBUGFS
286 	debugfs_remove_recursive(local->debugfs.rcdir);
287 	local->debugfs.rcdir = NULL;
288 #endif
289 
290 	kfree(ctrl_ref);
291 }
292 
293 void ieee80211_check_rate_mask(struct ieee80211_link_data *link)
294 {
295 	struct ieee80211_sub_if_data *sdata = link->sdata;
296 	struct ieee80211_local *local = sdata->local;
297 	struct ieee80211_supported_band *sband;
298 	u32 user_mask, basic_rates = link->conf->basic_rates;
299 	enum nl80211_band band;
300 
301 	if (WARN_ON(!link->conf->chanreq.oper.chan))
302 		return;
303 
304 	band = link->conf->chanreq.oper.chan->band;
305 	if (band == NL80211_BAND_S1GHZ) {
306 		/* TODO */
307 		return;
308 	}
309 
310 	if (WARN_ON_ONCE(!basic_rates))
311 		return;
312 
313 	user_mask = sdata->rc_rateidx_mask[band];
314 	sband = local->hw.wiphy->bands[band];
315 
316 	if (user_mask & basic_rates)
317 		return;
318 
319 	sdata_dbg(sdata,
320 		  "no overlap between basic rates (0x%x) and user mask (0x%x on band %d) - clearing the latter",
321 		  basic_rates, user_mask, band);
322 	sdata->rc_rateidx_mask[band] = (1 << sband->n_bitrates) - 1;
323 }
324 
325 static bool rc_no_data_or_no_ack_use_min(struct ieee80211_tx_rate_control *txrc)
326 {
327 	struct sk_buff *skb = txrc->skb;
328 	struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb);
329 
330 	return (info->flags & (IEEE80211_TX_CTL_NO_ACK |
331 			       IEEE80211_TX_CTL_USE_MINRATE)) ||
332 		!ieee80211_is_tx_data(skb);
333 }
334 
335 static void rc_send_low_basicrate(struct ieee80211_tx_rate *rate,
336 				  u32 basic_rates,
337 				  struct ieee80211_supported_band *sband)
338 {
339 	u8 i;
340 
341 	if (sband->band == NL80211_BAND_S1GHZ) {
342 		/* TODO */
343 		rate->flags |= IEEE80211_TX_RC_S1G_MCS;
344 		rate->idx = 0;
345 		return;
346 	}
347 
348 	if (basic_rates == 0)
349 		return; /* assume basic rates unknown and accept rate */
350 	if (rate->idx < 0)
351 		return;
352 	if (basic_rates & (1 << rate->idx))
353 		return; /* selected rate is a basic rate */
354 
355 	for (i = rate->idx + 1; i <= sband->n_bitrates; i++) {
356 		if (basic_rates & (1 << i)) {
357 			rate->idx = i;
358 			return;
359 		}
360 	}
361 
362 	/* could not find a basic rate; use original selection */
363 }
364 
365 static void __rate_control_send_low(struct ieee80211_hw *hw,
366 				    struct ieee80211_supported_band *sband,
367 				    struct ieee80211_sta *sta,
368 				    struct ieee80211_tx_info *info,
369 				    u32 rate_mask)
370 {
371 	int i;
372 	u32 rate_flags =
373 		ieee80211_chandef_rate_flags(&hw->conf.chandef);
374 
375 	if (sband->band == NL80211_BAND_S1GHZ) {
376 		info->control.rates[0].flags |= IEEE80211_TX_RC_S1G_MCS;
377 		info->control.rates[0].idx = 0;
378 		return;
379 	}
380 
381 	if ((sband->band == NL80211_BAND_2GHZ) &&
382 	    (info->flags & IEEE80211_TX_CTL_NO_CCK_RATE))
383 		rate_flags |= IEEE80211_RATE_ERP_G;
384 
385 	info->control.rates[0].idx = 0;
386 	for (i = 0; i < sband->n_bitrates; i++) {
387 		if (!(rate_mask & BIT(i)))
388 			continue;
389 
390 		if ((rate_flags & sband->bitrates[i].flags) != rate_flags)
391 			continue;
392 
393 		if (!rate_supported(sta, sband->band, i))
394 			continue;
395 
396 		info->control.rates[0].idx = i;
397 		break;
398 	}
399 	WARN_ONCE(i == sband->n_bitrates,
400 		  "no supported rates for sta %pM (0x%x, band %d) in rate_mask 0x%x with flags 0x%x\n",
401 		  sta ? sta->addr : NULL,
402 		  sta ? sta->deflink.supp_rates[sband->band] : -1,
403 		  sband->band,
404 		  rate_mask, rate_flags);
405 
406 	info->control.rates[0].count =
407 		(info->flags & IEEE80211_TX_CTL_NO_ACK) ?
408 		1 : hw->max_rate_tries;
409 
410 	info->control.skip_table = 1;
411 }
412 
413 
414 static bool rate_control_send_low(struct ieee80211_sta *pubsta,
415 				  struct ieee80211_tx_rate_control *txrc)
416 {
417 	struct ieee80211_tx_info *info = IEEE80211_SKB_CB(txrc->skb);
418 	struct ieee80211_supported_band *sband = txrc->sband;
419 	struct sta_info *sta;
420 	int mcast_rate;
421 	bool use_basicrate = false;
422 
423 	if (!pubsta || rc_no_data_or_no_ack_use_min(txrc)) {
424 		__rate_control_send_low(txrc->hw, sband, pubsta, info,
425 					txrc->rate_idx_mask);
426 
427 		if (!pubsta && txrc->bss) {
428 			mcast_rate = txrc->bss_conf->mcast_rate[sband->band];
429 			if (mcast_rate > 0) {
430 				info->control.rates[0].idx = mcast_rate - 1;
431 				return true;
432 			}
433 			use_basicrate = true;
434 		} else if (pubsta) {
435 			sta = container_of(pubsta, struct sta_info, sta);
436 			if (ieee80211_vif_is_mesh(&sta->sdata->vif))
437 				use_basicrate = true;
438 		}
439 
440 		if (use_basicrate)
441 			rc_send_low_basicrate(&info->control.rates[0],
442 					      txrc->bss_conf->basic_rates,
443 					      sband);
444 
445 		return true;
446 	}
447 	return false;
448 }
449 
450 static bool rate_idx_match_legacy_mask(s8 *rate_idx, int n_bitrates, u32 mask)
451 {
452 	int j;
453 
454 	/* See whether the selected rate or anything below it is allowed. */
455 	for (j = *rate_idx; j >= 0; j--) {
456 		if (mask & (1 << j)) {
457 			/* Okay, found a suitable rate. Use it. */
458 			*rate_idx = j;
459 			return true;
460 		}
461 	}
462 
463 	/* Try to find a higher rate that would be allowed */
464 	for (j = *rate_idx + 1; j < n_bitrates; j++) {
465 		if (mask & (1 << j)) {
466 			/* Okay, found a suitable rate. Use it. */
467 			*rate_idx = j;
468 			return true;
469 		}
470 	}
471 	return false;
472 }
473 
474 static bool rate_idx_match_mcs_mask(s8 *rate_idx, u8 *mcs_mask)
475 {
476 	int i, j;
477 	int ridx, rbit;
478 
479 	ridx = *rate_idx / 8;
480 	rbit = *rate_idx % 8;
481 
482 	/* sanity check */
483 	if (ridx < 0 || ridx >= IEEE80211_HT_MCS_MASK_LEN)
484 		return false;
485 
486 	/* See whether the selected rate or anything below it is allowed. */
487 	for (i = ridx; i >= 0; i--) {
488 		for (j = rbit; j >= 0; j--)
489 			if (mcs_mask[i] & BIT(j)) {
490 				*rate_idx = i * 8 + j;
491 				return true;
492 			}
493 		rbit = 7;
494 	}
495 
496 	/* Try to find a higher rate that would be allowed */
497 	ridx = (*rate_idx + 1) / 8;
498 	rbit = (*rate_idx + 1) % 8;
499 
500 	for (i = ridx; i < IEEE80211_HT_MCS_MASK_LEN; i++) {
501 		for (j = rbit; j < 8; j++)
502 			if (mcs_mask[i] & BIT(j)) {
503 				*rate_idx = i * 8 + j;
504 				return true;
505 			}
506 		rbit = 0;
507 	}
508 	return false;
509 }
510 
511 static bool rate_idx_match_vht_mcs_mask(s8 *rate_idx, u16 *vht_mask)
512 {
513 	int i, j;
514 	int ridx, rbit;
515 
516 	ridx = *rate_idx >> 4;
517 	rbit = *rate_idx & 0xf;
518 
519 	if (ridx < 0 || ridx >= NL80211_VHT_NSS_MAX)
520 		return false;
521 
522 	/* See whether the selected rate or anything below it is allowed. */
523 	for (i = ridx; i >= 0; i--) {
524 		for (j = rbit; j >= 0; j--) {
525 			if (vht_mask[i] & BIT(j)) {
526 				*rate_idx = (i << 4) | j;
527 				return true;
528 			}
529 		}
530 		rbit = 15;
531 	}
532 
533 	/* Try to find a higher rate that would be allowed */
534 	ridx = (*rate_idx + 1) >> 4;
535 	rbit = (*rate_idx + 1) & 0xf;
536 
537 	for (i = ridx; i < NL80211_VHT_NSS_MAX; i++) {
538 		for (j = rbit; j < 16; j++) {
539 			if (vht_mask[i] & BIT(j)) {
540 				*rate_idx = (i << 4) | j;
541 				return true;
542 			}
543 		}
544 		rbit = 0;
545 	}
546 	return false;
547 }
548 
/*
 * rate_idx_match_mask - force one TX rate into the configured masks
 * @rate_idx: in/out rate index (legacy, HT MCS or VHT, per *@rate_flags)
 * @rate_flags: in/out IEEE80211_TX_RC_* flags of the rate
 * @sband: band used for the legacy rate count
 * @chan_width: operating width; decides the 40 MHz flag and whether HT is tried
 * @mask: allowed legacy rate indices
 * @mcs_mask: allowed HT MCS bits
 * @vht_mask: allowed VHT MCS bits per NSS slot
 *
 * Fallback chains, tried in order:
 *  - VHT rate: VHT mask -> HT mask -> legacy mask.
 *  - HT rate:  HT mask -> legacy mask.
 *  - Legacy:   legacy mask -> HT mask. HT is skipped on 20_NOHT, 5 and
 *    10 MHz channels.
 *
 * When switching class, the index restarts at 0 and only the RTS/CTS,
 * CTS-protect and short-preamble flags are kept.
 */
static void rate_idx_match_mask(s8 *rate_idx, u16 *rate_flags,
				struct ieee80211_supported_band *sband,
				enum nl80211_chan_width chan_width,
				u32 mask,
				u8 mcs_mask[IEEE80211_HT_MCS_MASK_LEN],
				u16 vht_mask[NL80211_VHT_NSS_MAX])
{
	if (*rate_flags & IEEE80211_TX_RC_VHT_MCS) {
		/* handle VHT rates */
		if (rate_idx_match_vht_mcs_mask(rate_idx, vht_mask))
			return;

		*rate_idx = 0;
		/* keep protection flags */
		*rate_flags &= (IEEE80211_TX_RC_USE_RTS_CTS |
				IEEE80211_TX_RC_USE_CTS_PROTECT |
				IEEE80211_TX_RC_USE_SHORT_PREAMBLE);

		*rate_flags |= IEEE80211_TX_RC_MCS;
		if (chan_width == NL80211_CHAN_WIDTH_40)
			*rate_flags |= IEEE80211_TX_RC_40_MHZ_WIDTH;

		if (rate_idx_match_mcs_mask(rate_idx, mcs_mask))
			return;

		/* also try the legacy rates. */
		*rate_flags &= ~(IEEE80211_TX_RC_MCS |
				 IEEE80211_TX_RC_40_MHZ_WIDTH);
		if (rate_idx_match_legacy_mask(rate_idx, sband->n_bitrates,
					       mask))
			return;
	} else if (*rate_flags & IEEE80211_TX_RC_MCS) {
		/* handle HT rates */
		if (rate_idx_match_mcs_mask(rate_idx, mcs_mask))
			return;

		/* also try the legacy rates. */
		*rate_idx = 0;
		/* keep protection flags */
		*rate_flags &= (IEEE80211_TX_RC_USE_RTS_CTS |
				IEEE80211_TX_RC_USE_CTS_PROTECT |
				IEEE80211_TX_RC_USE_SHORT_PREAMBLE);
		if (rate_idx_match_legacy_mask(rate_idx, sband->n_bitrates,
					       mask))
			return;
	} else {
		/* handle legacy rates */
		if (rate_idx_match_legacy_mask(rate_idx, sband->n_bitrates,
					       mask))
			return;

		/*
		 * If HT BSS, also try HT rates. Only data frames get here: the
		 * caller applies the mask only to data frames.
		 */
		switch (chan_width) {
		case NL80211_CHAN_WIDTH_20_NOHT:
		case NL80211_CHAN_WIDTH_5:
		case NL80211_CHAN_WIDTH_10:
			return;
		default:
			break;
		}

		*rate_idx = 0;
		/* keep protection flags */
		*rate_flags &= (IEEE80211_TX_RC_USE_RTS_CTS |
				IEEE80211_TX_RC_USE_CTS_PROTECT |
				IEEE80211_TX_RC_USE_SHORT_PREAMBLE);

		*rate_flags |= IEEE80211_TX_RC_MCS;

		if (chan_width == NL80211_CHAN_WIDTH_40)
			*rate_flags |= IEEE80211_TX_RC_40_MHZ_WIDTH;

		if (rate_idx_match_mcs_mask(rate_idx, mcs_mask))
			return;
	}

	/*
	 * Uh.. No suitable rate exists. This should not really happen with
	 * sane TX rate mask configurations. However, should someone manage to
	 * configure supported rates and TX rate mask in incompatible way,
	 * allow the frame to be transmitted with whatever the rate control
	 * selected.
	 *
	 * NOTE(review): on the fallback paths above, *rate_idx has already
	 * been reset to 0 and *rate_flags rewritten by the time we get here.
	 * The frame therefore goes out with that fallback (index 0, possibly
	 * with MCS/40 MHz flags), not with the original RC choice.
	 */
}
633 
/*
 * rate_fixup_ratelist - finalise a rate list before it reaches the driver
 * @vif: interface; its basic rates select the RTS/CTS rate
 * @sband: band the rates index into
 * @info: TX info; rts_cts_rate_idx is written and use_cts_prot may be cleared
 * @rates: rate list to fix up in place
 * @max_rates: number of entries in @rates
 *
 * Steps:
 *  - If rates[0] is a legacy rate, choose info->control.rts_cts_rate_idx.
 *  - Terminate the list after the first negative index.
 *  - Add CTS protection to HT entries.
 *  - Sanity-check MCS indices.
 *  - For legacy entries: add RTS/CTS, short-preamble and ERP (G)
 *    protection flags, and invalidate out-of-range indices.
 *
 * NOTE(review): rates[0].idx indexes sband->bitrates[] for the RTS/CTS
 * lookup before any range check. The in-file caller replaces a negative
 * rates[0].idx first, but an index >= n_bitrates is not caught until the
 * loop below.
 */
static void rate_fixup_ratelist(struct ieee80211_vif *vif,
				struct ieee80211_supported_band *sband,
				struct ieee80211_tx_info *info,
				struct ieee80211_tx_rate *rates,
				int max_rates)
{
	struct ieee80211_rate *rate;
	bool inval = false;
	int i;

	/*
	 * Set up the RTS/CTS rate as the fastest basic rate
	 * that is not faster than the data rate unless there
	 * is no basic rate slower than the data rate, in which
	 * case we pick the slowest basic rate
	 *
	 * XXX: Should this check all retry rates?
	 */
	if (!(rates[0].flags &
	      (IEEE80211_TX_RC_MCS | IEEE80211_TX_RC_VHT_MCS))) {
		u32 basic_rates = vif->bss_conf.basic_rates;
		/* lowest basic rate index as the starting candidate */
		s8 baserate = basic_rates ? ffs(basic_rates) - 1 : 0;

		rate = &sband->bitrates[rates[0].idx];

		for (i = 0; i < sband->n_bitrates; i++) {
			/* must be a basic rate */
			if (!(basic_rates & BIT(i)))
				continue;
			/* must not be faster than the data rate */
			if (sband->bitrates[i].bitrate > rate->bitrate)
				continue;
			/* maximum */
			if (sband->bitrates[baserate].bitrate <
			     sband->bitrates[i].bitrate)
				baserate = i;
		}

		info->control.rts_cts_rate_idx = baserate;
	}

	for (i = 0; i < max_rates; i++) {
		/*
		 * make sure there's no valid rate following
		 * an invalid one, just in case drivers don't
		 * take the API seriously to stop at -1.
		 */
		if (inval) {
			rates[i].idx = -1;
			continue;
		}
		if (rates[i].idx < 0) {
			inval = true;
			continue;
		}

		/*
		 * For now assume MCS is already set up correctly, this
		 * needs to be fixed.
		 */
		if (rates[i].flags & IEEE80211_TX_RC_MCS) {
			/* presumably 76 is the highest defined HT MCS index */
			WARN_ON(rates[i].idx > 76);

			if (!(rates[i].flags & IEEE80211_TX_RC_USE_RTS_CTS) &&
			    info->control.use_cts_prot)
				rates[i].flags |=
					IEEE80211_TX_RC_USE_CTS_PROTECT;
			continue;
		}

		if (rates[i].flags & IEEE80211_TX_RC_VHT_MCS) {
			WARN_ON(ieee80211_rate_get_vht_mcs(&rates[i]) > 9);
			continue;
		}

		/*
		 * Set up RTS protection if desired. Clearing use_cts_prot here
		 * also affects every later entry in this loop.
		 */
		if (info->control.use_rts) {
			rates[i].flags |= IEEE80211_TX_RC_USE_RTS_CTS;
			info->control.use_cts_prot = false;
		}

		/* RC is busted */
		if (WARN_ON_ONCE(rates[i].idx >= sband->n_bitrates)) {
			rates[i].idx = -1;
			continue;
		}

		rate = &sband->bitrates[rates[i].idx];

		/* set up short preamble */
		if (info->control.short_preamble &&
		    rate->flags & IEEE80211_RATE_SHORT_PREAMBLE)
			rates[i].flags |= IEEE80211_TX_RC_USE_SHORT_PREAMBLE;

		/* set up G protection */
		if (!(rates[i].flags & IEEE80211_TX_RC_USE_RTS_CTS) &&
		    info->control.use_cts_prot &&
		    rate->flags & IEEE80211_RATE_ERP_G)
			rates[i].flags |= IEEE80211_TX_RC_USE_CTS_PROTECT;
	}
}
735 
736 
/*
 * rate_control_fill_sta_table - build a rate list from TX info and sta table
 * @sta: station (may be NULL); supplies sta->rates via rcu_dereference()
 * @info: TX info; control.rates[] entries take precedence
 * @rates: output list (may alias info->control.rates)
 * @max_rates: output capacity, capped at IEEE80211_TX_RATE_TABLE_SIZE
 *
 * For each slot:
 *  - A valid info->control.rates[] entry (idx >= 0, non-zero count) is
 *    kept or copied.
 *  - Otherwise the entry comes from the station's rate table, if any, and
 *    control.skip_table is clear. Its count is count_rts, count_cts or
 *    count, depending on use_rts / use_cts_prot.
 *  - Otherwise the slot is terminated with idx -1, count 0.
 *
 * Filling stops after the first terminating entry.
 *
 * NOTE(review): rcu_dereference() presumably requires the caller to hold
 * the RCU read lock — confirm at the call sites.
 */
static void rate_control_fill_sta_table(struct ieee80211_sta *sta,
					struct ieee80211_tx_info *info,
					struct ieee80211_tx_rate *rates,
					int max_rates)
{
	struct ieee80211_sta_rates *ratetbl = NULL;
	int i;

	if (sta && !info->control.skip_table)
		ratetbl = rcu_dereference(sta->rates);

	/* Fill remaining rate slots with data from the sta rate table. */
	max_rates = min_t(int, max_rates, IEEE80211_TX_RATE_TABLE_SIZE);
	for (i = 0; i < max_rates; i++) {
		if (i < ARRAY_SIZE(info->control.rates) &&
		    info->control.rates[i].idx >= 0 &&
		    info->control.rates[i].count) {
			/* only copy when the output is a separate buffer */
			if (rates != info->control.rates)
				rates[i] = info->control.rates[i];
		} else if (ratetbl) {
			rates[i].idx = ratetbl->rate[i].idx;
			rates[i].flags = ratetbl->rate[i].flags;
			if (info->control.use_rts)
				rates[i].count = ratetbl->rate[i].count_rts;
			else if (info->control.use_cts_prot)
				rates[i].count = ratetbl->rate[i].count_cts;
			else
				rates[i].count = ratetbl->rate[i].count;
		} else {
			rates[i].idx = -1;
			rates[i].count = 0;
		}

		if (rates[i].idx < 0 || !rates[i].count)
			break;
	}
}
774 
775 static bool rate_control_cap_mask(struct ieee80211_sub_if_data *sdata,
776 				  struct ieee80211_supported_band *sband,
777 				  struct ieee80211_sta *sta, u32 *mask,
778 				  u8 mcs_mask[IEEE80211_HT_MCS_MASK_LEN],
779 				  u16 vht_mask[NL80211_VHT_NSS_MAX])
780 {
781 	u32 i, flags;
782 
783 	*mask = sdata->rc_rateidx_mask[sband->band];
784 	flags = ieee80211_chandef_rate_flags(&sdata->vif.bss_conf.chanreq.oper);
785 	for (i = 0; i < sband->n_bitrates; i++) {
786 		if ((flags & sband->bitrates[i].flags) != flags)
787 			*mask &= ~BIT(i);
788 	}
789 
790 	if (*mask == (1 << sband->n_bitrates) - 1 &&
791 	    !sdata->rc_has_mcs_mask[sband->band] &&
792 	    !sdata->rc_has_vht_mcs_mask[sband->band])
793 		return false;
794 
795 	if (sdata->rc_has_mcs_mask[sband->band])
796 		memcpy(mcs_mask, sdata->rc_rateidx_mcs_mask[sband->band],
797 		       IEEE80211_HT_MCS_MASK_LEN);
798 	else
799 		memset(mcs_mask, 0xff, IEEE80211_HT_MCS_MASK_LEN);
800 
801 	if (sdata->rc_has_vht_mcs_mask[sband->band])
802 		memcpy(vht_mask, sdata->rc_rateidx_vht_mcs_mask[sband->band],
803 		       sizeof(u16) * NL80211_VHT_NSS_MAX);
804 	else
805 		memset(vht_mask, 0xff, sizeof(u16) * NL80211_VHT_NSS_MAX);
806 
807 	if (sta) {
808 		__le16 sta_vht_cap;
809 		u16 sta_vht_mask[NL80211_VHT_NSS_MAX];
810 
811 		/* Filter out rates that the STA does not support */
812 		*mask &= sta->deflink.supp_rates[sband->band];
813 		for (i = 0; i < IEEE80211_HT_MCS_MASK_LEN; i++)
814 			mcs_mask[i] &= sta->deflink.ht_cap.mcs.rx_mask[i];
815 
816 		sta_vht_cap = sta->deflink.vht_cap.vht_mcs.rx_mcs_map;
817 		ieee80211_get_vht_mask_from_cap(sta_vht_cap, sta_vht_mask);
818 		for (i = 0; i < NL80211_VHT_NSS_MAX; i++)
819 			vht_mask[i] &= sta_vht_mask[i];
820 	}
821 
822 	return true;
823 }
824 
825 static void
826 rate_control_apply_mask_ratetbl(struct sta_info *sta,
827 				struct ieee80211_supported_band *sband,
828 				struct ieee80211_sta_rates *rates)
829 {
830 	int i;
831 	u32 mask;
832 	u8 mcs_mask[IEEE80211_HT_MCS_MASK_LEN];
833 	u16 vht_mask[NL80211_VHT_NSS_MAX];
834 	enum nl80211_chan_width chan_width;
835 
836 	if (!rate_control_cap_mask(sta->sdata, sband, &sta->sta, &mask,
837 				   mcs_mask, vht_mask))
838 		return;
839 
840 	chan_width = sta->sdata->vif.bss_conf.chanreq.oper.width;
841 	for (i = 0; i < IEEE80211_TX_RATE_TABLE_SIZE; i++) {
842 		if (rates->rate[i].idx < 0)
843 			break;
844 
845 		rate_idx_match_mask(&rates->rate[i].idx, &rates->rate[i].flags,
846 				    sband, chan_width, mask, mcs_mask,
847 				    vht_mask);
848 	}
849 }
850 
851 static void rate_control_apply_mask(struct ieee80211_sub_if_data *sdata,
852 				    struct ieee80211_sta *sta,
853 				    struct ieee80211_supported_band *sband,
854 				    struct ieee80211_tx_rate *rates,
855 				    int max_rates)
856 {
857 	enum nl80211_chan_width chan_width;
858 	u8 mcs_mask[IEEE80211_HT_MCS_MASK_LEN];
859 	u32 mask;
860 	u16 rate_flags, vht_mask[NL80211_VHT_NSS_MAX];
861 	int i;
862 
863 	/*
864 	 * Try to enforce the rateidx mask the user wanted. skip this if the
865 	 * default mask (allow all rates) is used to save some processing for
866 	 * the common case.
867 	 */
868 	if (!rate_control_cap_mask(sdata, sband, sta, &mask, mcs_mask,
869 				   vht_mask))
870 		return;
871 
872 	/*
873 	 * Make sure the rate index selected for each TX rate is
874 	 * included in the configured mask and change the rate indexes
875 	 * if needed.
876 	 */
877 	chan_width = sdata->vif.bss_conf.chanreq.oper.width;
878 	for (i = 0; i < max_rates; i++) {
879 		/* Skip invalid rates */
880 		if (rates[i].idx < 0)
881 			break;
882 
883 		rate_flags = rates[i].flags;
884 		rate_idx_match_mask(&rates[i].idx, &rate_flags, sband,
885 				    chan_width, mask, mcs_mask, vht_mask);
886 		rates[i].flags = rate_flags;
887 	}
888 }
889 
890 void ieee80211_get_tx_rates(struct ieee80211_vif *vif,
891 			    struct ieee80211_sta *sta,
892 			    struct sk_buff *skb,
893 			    struct ieee80211_tx_rate *dest,
894 			    int max_rates)
895 {
896 	struct ieee80211_sub_if_data *sdata;
897 	struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb);
898 	struct ieee80211_supported_band *sband;
899 	u32 mask = ~0;
900 
901 	rate_control_fill_sta_table(sta, info, dest, max_rates);
902 
903 	if (!vif)
904 		return;
905 
906 	sdata = vif_to_sdata(vif);
907 	sband = sdata->local->hw.wiphy->bands[info->band];
908 
909 	if (ieee80211_is_tx_data(skb))
910 		rate_control_apply_mask(sdata, sta, sband, dest, max_rates);
911 
912 	if (!(info->control.flags & IEEE80211_TX_CTRL_DONT_USE_RATE_MASK))
913 		mask = sdata->rc_rateidx_mask[info->band];
914 
915 	if (dest[0].idx < 0)
916 		__rate_control_send_low(&sdata->local->hw, sband, sta, info,
917 					mask);
918 
919 	if (sta)
920 		rate_fixup_ratelist(vif, sband, info, dest, max_rates);
921 }
922 EXPORT_SYMBOL(ieee80211_get_tx_rates);
923 
924 void rate_control_get_rate(struct ieee80211_sub_if_data *sdata,
925 			   struct sta_info *sta,
926 			   struct ieee80211_tx_rate_control *txrc)
927 {
928 	struct rate_control_ref *ref = sdata->local->rate_ctrl;
929 	void *priv_sta = NULL;
930 	struct ieee80211_sta *ista = NULL;
931 	struct ieee80211_tx_info *info = IEEE80211_SKB_CB(txrc->skb);
932 	int i;
933 
934 	for (i = 0; i < IEEE80211_TX_MAX_RATES; i++) {
935 		info->control.rates[i].idx = -1;
936 		info->control.rates[i].flags = 0;
937 		info->control.rates[i].count = 0;
938 	}
939 
940 	if (rate_control_send_low(sta ? &sta->sta : NULL, txrc))
941 		return;
942 
943 	if (ieee80211_hw_check(&sdata->local->hw, HAS_RATE_CONTROL))
944 		return;
945 
946 	if (sta && test_sta_flag(sta, WLAN_STA_RATE_CONTROL)) {
947 		ista = &sta->sta;
948 		priv_sta = sta->rate_ctrl_priv;
949 	}
950 
951 	if (ista) {
952 		spin_lock_bh(&sta->rate_ctrl_lock);
953 		ref->ops->get_rate(ref->priv, ista, priv_sta, txrc);
954 		spin_unlock_bh(&sta->rate_ctrl_lock);
955 	} else {
956 		rate_control_send_low(NULL, txrc);
957 	}
958 
959 	if (ieee80211_hw_check(&sdata->local->hw, SUPPORTS_RC_TABLE))
960 		return;
961 
962 	ieee80211_get_tx_rates(&sdata->vif, ista, txrc->skb,
963 			       info->control.rates,
964 			       ARRAY_SIZE(info->control.rates));
965 }
966 
967 int rate_control_set_rates(struct ieee80211_hw *hw,
968 			   struct ieee80211_sta *pubsta,
969 			   struct ieee80211_sta_rates *rates)
970 {
971 	struct sta_info *sta = container_of(pubsta, struct sta_info, sta);
972 	struct ieee80211_sta_rates *old;
973 	struct ieee80211_supported_band *sband;
974 
975 	sband = ieee80211_get_sband(sta->sdata);
976 	if (!sband)
977 		return -EINVAL;
978 	rate_control_apply_mask_ratetbl(sta, sband, rates);
979 	/*
980 	 * mac80211 guarantees that this function will not be called
981 	 * concurrently, so the following RCU access is safe, even without
982 	 * extra locking. This can not be checked easily, so we just set
983 	 * the condition to true.
984 	 */
985 	old = rcu_dereference_protected(pubsta->rates, true);
986 	rcu_assign_pointer(pubsta->rates, rates);
987 	if (old)
988 		kfree_rcu(old, rcu_head);
989 
990 	if (sta->uploaded)
991 		drv_sta_rate_tbl_update(hw_to_local(hw), sta->sdata, pubsta);
992 
993 	ieee80211_sta_set_expected_throughput(pubsta, sta_get_expected_throughput(sta));
994 
995 	return 0;
996 }
997 EXPORT_SYMBOL(rate_control_set_rates);
998 
999 int ieee80211_init_rate_ctrl_alg(struct ieee80211_local *local,
1000 				 const char *name)
1001 {
1002 	struct rate_control_ref *ref;
1003 
1004 	ASSERT_RTNL();
1005 
1006 	if (local->open_count)
1007 		return -EBUSY;
1008 
1009 	if (ieee80211_hw_check(&local->hw, HAS_RATE_CONTROL)) {
1010 		if (WARN_ON(!local->ops->set_rts_threshold))
1011 			return -EINVAL;
1012 		return 0;
1013 	}
1014 
1015 	ref = rate_control_alloc(name, local);
1016 	if (!ref) {
1017 		wiphy_warn(local->hw.wiphy,
1018 			   "Failed to select rate control algorithm\n");
1019 		return -ENOENT;
1020 	}
1021 
1022 	WARN_ON(local->rate_ctrl);
1023 	local->rate_ctrl = ref;
1024 
1025 	wiphy_debug(local->hw.wiphy, "Selected rate control algorithm '%s'\n",
1026 		    ref->ops->name);
1027 
1028 	return 0;
1029 }
1030 
1031 void rate_control_deinitialize(struct ieee80211_local *local)
1032 {
1033 	struct rate_control_ref *ref;
1034 
1035 	ref = local->rate_ctrl;
1036 
1037 	if (!ref)
1038 		return;
1039 
1040 	local->rate_ctrl = NULL;
1041 	rate_control_free(local, ref);
1042 }
1043