1 // SPDX-License-Identifier: GPL-2.0 2 /* Multipath TCP token management 3 * Copyright (c) 2017 - 2019, Intel Corporation. 4 * 5 * Note: This code is based on mptcp_ctrl.c from multipath-tcp.org, 6 * authored by: 7 * 8 * Sébastien Barré <sebastien.barre@uclouvain.be> 9 * Christoph Paasch <christoph.paasch@uclouvain.be> 10 * Jaakko Korkeaniemi <jaakko.korkeaniemi@aalto.fi> 11 * Gregory Detal <gregory.detal@uclouvain.be> 12 * Fabien Duchêne <fabien.duchene@uclouvain.be> 13 * Andreas Seelinger <Andreas.Seelinger@rwth-aachen.de> 14 * Lavkesh Lahngir <lavkesh51@gmail.com> 15 * Andreas Ripke <ripke@neclab.eu> 16 * Vlad Dogaru <vlad.dogaru@intel.com> 17 * Octavian Purdila <octavian.purdila@intel.com> 18 * John Ronan <jronan@tssg.org> 19 * Catalin Nicutar <catalin.nicutar@gmail.com> 20 * Brandon Heller <brandonh@stanford.edu> 21 */ 22 23 #define pr_fmt(fmt) "MPTCP: " fmt 24 25 #include <linux/kernel.h> 26 #include <linux/module.h> 27 #include <linux/memblock.h> 28 #include <linux/ip.h> 29 #include <linux/tcp.h> 30 #include <net/sock.h> 31 #include <net/inet_common.h> 32 #include <net/protocol.h> 33 #include <net/mptcp.h> 34 #include "protocol.h" 35 36 #define TOKEN_MAX_RETRIES 4 37 #define TOKEN_MAX_CHAIN_LEN 4 38 39 struct token_bucket { 40 spinlock_t lock; 41 int chain_len; 42 struct hlist_nulls_head req_chain; 43 struct hlist_nulls_head msk_chain; 44 }; 45 46 static struct token_bucket *token_hash __read_mostly; 47 static unsigned int token_mask __read_mostly; 48 49 static struct token_bucket *token_bucket(u32 token) 50 { 51 return &token_hash[token & token_mask]; 52 } 53 54 /* called with bucket lock held */ 55 static struct mptcp_subflow_request_sock * 56 __token_lookup_req(struct token_bucket *t, u32 token) 57 { 58 struct mptcp_subflow_request_sock *req; 59 struct hlist_nulls_node *pos; 60 61 hlist_nulls_for_each_entry_rcu(req, pos, &t->req_chain, token_node) 62 if (req->token == token) 63 return req; 64 return NULL; 65 } 66 67 /* called with bucket lock held */ 68 static struct mptcp_sock * 69 __token_lookup_msk(struct token_bucket *t, u32 token) 70 { 71 struct hlist_nulls_node *pos; 72 struct sock *sk; 73 74 sk_nulls_for_each_rcu(sk, pos, &t->msk_chain) 75 if (mptcp_sk(sk)->token == token) 76 return mptcp_sk(sk); 77 return NULL; 78 } 79 80 static bool __token_bucket_busy(struct token_bucket *t, u32 token) 81 { 82 return !token || t->chain_len >= TOKEN_MAX_CHAIN_LEN || 83 __token_lookup_req(t, token) || __token_lookup_msk(t, token); 84 } 85 86 static void mptcp_crypto_key_gen_sha(u64 *key, u32 *token, u64 *idsn) 87 { 88 /* we might consider a faster version that computes the key as a 89 * hash of some information available in the MPTCP socket. Use 90 * random data at the moment, as it's probably the safest option 91 * in case multiple sockets are opened in different namespaces at 92 * the same time. 93 */ 94 get_random_bytes(key, sizeof(u64)); 95 mptcp_crypto_key_sha(*key, token, idsn); 96 } 97 98 /** 99 * mptcp_token_new_request - create new key/idsn/token for subflow_request 100 * @req: the request socket 101 * 102 * This function is called when a new mptcp connection is coming in. 103 * 104 * It creates a unique token to identify the new mptcp connection, 105 * a secret local key and the initial data sequence number (idsn). 106 * 107 * Returns 0 on success. 108 */ 109 int mptcp_token_new_request(struct request_sock *req) 110 { 111 struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req); 112 struct token_bucket *bucket; 113 u32 token; 114 115 mptcp_crypto_key_sha(subflow_req->local_key, 116 &subflow_req->token, 117 &subflow_req->idsn); 118 pr_debug("req=%p local_key=%llu, token=%u, idsn=%llu\n", 119 req, subflow_req->local_key, subflow_req->token, 120 subflow_req->idsn); 121 122 token = subflow_req->token; 123 bucket = token_bucket(token); 124 spin_lock_bh(&bucket->lock); 125 if (__token_bucket_busy(bucket, token)) { 126 spin_unlock_bh(&bucket->lock); 127 return -EBUSY; 128 } 129 130 hlist_nulls_add_head_rcu(&subflow_req->token_node, &bucket->req_chain); 131 bucket->chain_len++; 132 spin_unlock_bh(&bucket->lock); 133 return 0; 134 } 135 136 /** 137 * mptcp_token_new_connect - create new key/idsn/token for subflow 138 * @sk: the socket that will initiate a connection 139 * 140 * This function is called when a new outgoing mptcp connection is 141 * initiated. 142 * 143 * It creates a unique token to identify the new mptcp connection, 144 * a secret local key and the initial data sequence number (idsn). 145 * 146 * On success, the mptcp connection can be found again using 147 * the computed token at a later time, this is needed to process 148 * join requests. 149 * 150 * returns 0 on success. 151 */ 152 int mptcp_token_new_connect(struct sock *sk) 153 { 154 struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); 155 struct mptcp_sock *msk = mptcp_sk(subflow->conn); 156 int retries = TOKEN_MAX_RETRIES; 157 struct token_bucket *bucket; 158 159 pr_debug("ssk=%p, local_key=%llu, token=%u, idsn=%llu\n", 160 sk, subflow->local_key, subflow->token, subflow->idsn); 161 162 again: 163 mptcp_crypto_key_gen_sha(&subflow->local_key, &subflow->token, 164 &subflow->idsn); 165 166 bucket = token_bucket(subflow->token); 167 spin_lock_bh(&bucket->lock); 168 if (__token_bucket_busy(bucket, subflow->token)) { 169 spin_unlock_bh(&bucket->lock); 170 if (!--retries) 171 return -EBUSY; 172 goto again; 173 } 174 175 WRITE_ONCE(msk->token, subflow->token); 176 __sk_nulls_add_node_rcu((struct sock *)msk, &bucket->msk_chain); 177 bucket->chain_len++; 178 spin_unlock_bh(&bucket->lock); 179 return 0; 180 } 181 182 /** 183 * mptcp_token_accept - replace a req sk with full sock in token hash 184 * @req: the request socket to be removed 185 * @msk: the just cloned socket linked to the new connection 186 * 187 * Called when a SYN packet creates a new logical connection, i.e. 188 * is not a join request. 189 */ 190 void mptcp_token_accept(struct mptcp_subflow_request_sock *req, 191 struct mptcp_sock *msk) 192 { 193 struct mptcp_subflow_request_sock *pos; 194 struct token_bucket *bucket; 195 196 bucket = token_bucket(req->token); 197 spin_lock_bh(&bucket->lock); 198 199 /* pedantic lookup check for the moved token */ 200 pos = __token_lookup_req(bucket, req->token); 201 if (!WARN_ON_ONCE(pos != req)) 202 hlist_nulls_del_init_rcu(&req->token_node); 203 __sk_nulls_add_node_rcu((struct sock *)msk, &bucket->msk_chain); 204 spin_unlock_bh(&bucket->lock); 205 } 206 207 bool mptcp_token_exists(u32 token) 208 { 209 struct hlist_nulls_node *pos; 210 struct token_bucket *bucket; 211 struct mptcp_sock *msk; 212 struct sock *sk; 213 214 rcu_read_lock(); 215 bucket = token_bucket(token); 216 217 again: 218 sk_nulls_for_each_rcu(sk, pos, &bucket->msk_chain) { 219 msk = mptcp_sk(sk); 220 if (READ_ONCE(msk->token) == token) 221 goto found; 222 } 223 if (get_nulls_value(pos) != (token & token_mask)) 224 goto again; 225 226 rcu_read_unlock(); 227 return false; 228 found: 229 rcu_read_unlock(); 230 return true; 231 } 232 233 /** 234 * mptcp_token_get_sock - retrieve mptcp connection sock using its token 235 * @token: token of the mptcp connection to retrieve 236 * 237 * This function returns the mptcp connection structure with the given token. 238 * A reference count on the mptcp socket returned is taken. 239 * 240 * returns NULL if no connection with the given token value exists. 241 */ 242 struct mptcp_sock *mptcp_token_get_sock(u32 token) 243 { 244 struct hlist_nulls_node *pos; 245 struct token_bucket *bucket; 246 struct mptcp_sock *msk; 247 struct sock *sk; 248 249 rcu_read_lock(); 250 bucket = token_bucket(token); 251 252 again: 253 sk_nulls_for_each_rcu(sk, pos, &bucket->msk_chain) { 254 msk = mptcp_sk(sk); 255 if (READ_ONCE(msk->token) != token) 256 continue; 257 if (!refcount_inc_not_zero(&sk->sk_refcnt)) 258 goto not_found; 259 if (READ_ONCE(msk->token) != token) { 260 sock_put(sk); 261 goto again; 262 } 263 goto found; 264 } 265 if (get_nulls_value(pos) != (token & token_mask)) 266 goto again; 267 268 not_found: 269 msk = NULL; 270 271 found: 272 rcu_read_unlock(); 273 return msk; 274 } 275 EXPORT_SYMBOL_GPL(mptcp_token_get_sock); 276 277 /** 278 * mptcp_token_iter_next - iterate over the token container from given pos 279 * @net: namespace to be iterated 280 * @s_slot: start slot number 281 * @s_num: start number inside the given lock 282 * 283 * This function returns the first mptcp connection structure found inside the 284 * token container starting from the specified position, or NULL. 285 * 286 * On successful iteration, the iterator is move to the next position and the 287 * the acquires a reference to the returned socket. 288 */ 289 struct mptcp_sock *mptcp_token_iter_next(const struct net *net, long *s_slot, 290 long *s_num) 291 { 292 struct mptcp_sock *ret = NULL; 293 struct hlist_nulls_node *pos; 294 int slot, num = 0; 295 296 for (slot = *s_slot; slot <= token_mask; *s_num = 0, slot++) { 297 struct token_bucket *bucket = &token_hash[slot]; 298 struct sock *sk; 299 300 num = 0; 301 302 if (hlist_nulls_empty(&bucket->msk_chain)) 303 continue; 304 305 rcu_read_lock(); 306 sk_nulls_for_each_rcu(sk, pos, &bucket->msk_chain) { 307 ++num; 308 if (!net_eq(sock_net(sk), net)) 309 continue; 310 311 if (num <= *s_num) 312 continue; 313 314 if (!refcount_inc_not_zero(&sk->sk_refcnt)) 315 continue; 316 317 if (!net_eq(sock_net(sk), net)) { 318 sock_put(sk); 319 continue; 320 } 321 322 ret = mptcp_sk(sk); 323 rcu_read_unlock(); 324 goto out; 325 } 326 rcu_read_unlock(); 327 } 328 329 out: 330 *s_slot = slot; 331 *s_num = num; 332 return ret; 333 } 334 EXPORT_SYMBOL_GPL(mptcp_token_iter_next); 335 336 /** 337 * mptcp_token_destroy_request - remove mptcp connection/token 338 * @req: mptcp request socket dropping the token 339 * 340 * Remove the token associated to @req. 341 */ 342 void mptcp_token_destroy_request(struct request_sock *req) 343 { 344 struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req); 345 struct mptcp_subflow_request_sock *pos; 346 struct token_bucket *bucket; 347 348 if (hlist_nulls_unhashed(&subflow_req->token_node)) 349 return; 350 351 bucket = token_bucket(subflow_req->token); 352 spin_lock_bh(&bucket->lock); 353 pos = __token_lookup_req(bucket, subflow_req->token); 354 if (!WARN_ON_ONCE(pos != subflow_req)) { 355 hlist_nulls_del_init_rcu(&pos->token_node); 356 bucket->chain_len--; 357 } 358 spin_unlock_bh(&bucket->lock); 359 } 360 361 /** 362 * mptcp_token_destroy - remove mptcp connection/token 363 * @msk: mptcp connection dropping the token 364 * 365 * Remove the token associated to @msk 366 */ 367 void mptcp_token_destroy(struct mptcp_sock *msk) 368 { 369 struct token_bucket *bucket; 370 struct mptcp_sock *pos; 371 372 if (sk_unhashed((struct sock *)msk)) 373 return; 374 375 bucket = token_bucket(msk->token); 376 spin_lock_bh(&bucket->lock); 377 pos = __token_lookup_msk(bucket, msk->token); 378 if (!WARN_ON_ONCE(pos != msk)) { 379 __sk_nulls_del_node_init_rcu((struct sock *)pos); 380 bucket->chain_len--; 381 } 382 spin_unlock_bh(&bucket->lock); 383 } 384 385 void __init mptcp_token_init(void) 386 { 387 int i; 388 389 token_hash = alloc_large_system_hash("MPTCP token", 390 sizeof(struct token_bucket), 391 0, 392 20,/* one slot per 1MB of memory */ 393 HASH_ZERO, 394 NULL, 395 &token_mask, 396 0, 397 64 * 1024); 398 for (i = 0; i < token_mask + 1; ++i) { 399 INIT_HLIST_NULLS_HEAD(&token_hash[i].req_chain, i); 400 INIT_HLIST_NULLS_HEAD(&token_hash[i].msk_chain, i); 401 spin_lock_init(&token_hash[i].lock); 402 } 403 } 404 405 #if IS_MODULE(CONFIG_MPTCP_KUNIT_TESTS) 406 EXPORT_SYMBOL_GPL(mptcp_token_new_request); 407 EXPORT_SYMBOL_GPL(mptcp_token_new_connect); 408 EXPORT_SYMBOL_GPL(mptcp_token_accept); 409 EXPORT_SYMBOL_GPL(mptcp_token_destroy_request); 410 EXPORT_SYMBOL_GPL(mptcp_token_destroy); 411 #endif 412