1 // SPDX-License-Identifier: GPL-2.0-or-later 2 3 #include <crypto/hash.h> 4 #include <linux/cpu.h> 5 #include <linux/kref.h> 6 #include <linux/module.h> 7 #include <linux/mutex.h> 8 #include <linux/percpu.h> 9 #include <linux/workqueue.h> 10 #include <net/tcp.h> 11 12 static size_t __scratch_size; 13 struct sigpool_scratch { 14 local_lock_t bh_lock; 15 void __rcu *pad; 16 }; 17 18 static DEFINE_PER_CPU(struct sigpool_scratch, sigpool_scratch) = { 19 .bh_lock = INIT_LOCAL_LOCK(bh_lock), 20 }; 21 22 struct sigpool_entry { 23 struct crypto_ahash *hash; 24 const char *alg; 25 struct kref kref; 26 uint16_t needs_key:1, 27 reserved:15; 28 }; 29 30 #define CPOOL_SIZE (PAGE_SIZE / sizeof(struct sigpool_entry)) 31 static struct sigpool_entry cpool[CPOOL_SIZE]; 32 static unsigned int cpool_populated; 33 static DEFINE_MUTEX(cpool_mutex); 34 35 /* Slow-path */ 36 struct scratches_to_free { 37 struct rcu_head rcu; 38 unsigned int cnt; 39 void *scratches[]; 40 }; 41 42 static void free_old_scratches(struct rcu_head *head) 43 { 44 struct scratches_to_free *stf; 45 46 stf = container_of(head, struct scratches_to_free, rcu); 47 while (stf->cnt--) 48 kfree(stf->scratches[stf->cnt]); 49 kfree(stf); 50 } 51 52 /** 53 * sigpool_reserve_scratch - re-allocates scratch buffer, slow-path 54 * @size: request size for the scratch/temp buffer 55 */ 56 static int sigpool_reserve_scratch(size_t size) 57 { 58 struct scratches_to_free *stf; 59 size_t stf_sz = struct_size(stf, scratches, num_possible_cpus()); 60 int cpu, err = 0; 61 62 lockdep_assert_held(&cpool_mutex); 63 if (__scratch_size >= size) 64 return 0; 65 66 stf = kmalloc(stf_sz, GFP_KERNEL); 67 if (!stf) 68 return -ENOMEM; 69 stf->cnt = 0; 70 71 size = max(size, __scratch_size); 72 cpus_read_lock(); 73 for_each_possible_cpu(cpu) { 74 void *scratch, *old_scratch; 75 76 scratch = kmalloc_node(size, GFP_KERNEL, cpu_to_node(cpu)); 77 if (!scratch) { 78 err = -ENOMEM; 79 break; 80 } 81 82 old_scratch = rcu_replace_pointer(per_cpu(sigpool_scratch.pad, cpu), 83 scratch, lockdep_is_held(&cpool_mutex)); 84 if (!cpu_online(cpu) || !old_scratch) { 85 kfree(old_scratch); 86 continue; 87 } 88 stf->scratches[stf->cnt++] = old_scratch; 89 } 90 cpus_read_unlock(); 91 if (!err) 92 __scratch_size = size; 93 94 call_rcu(&stf->rcu, free_old_scratches); 95 return err; 96 } 97 98 static void sigpool_scratch_free(void) 99 { 100 int cpu; 101 102 for_each_possible_cpu(cpu) 103 kfree(rcu_replace_pointer(per_cpu(sigpool_scratch.pad, cpu), 104 NULL, lockdep_is_held(&cpool_mutex))); 105 __scratch_size = 0; 106 } 107 108 static int __cpool_try_clone(struct crypto_ahash *hash) 109 { 110 struct crypto_ahash *tmp; 111 112 tmp = crypto_clone_ahash(hash); 113 if (IS_ERR(tmp)) 114 return PTR_ERR(tmp); 115 116 crypto_free_ahash(tmp); 117 return 0; 118 } 119 120 static int __cpool_alloc_ahash(struct sigpool_entry *e, const char *alg) 121 { 122 struct crypto_ahash *cpu0_hash; 123 int ret; 124 125 e->alg = kstrdup(alg, GFP_KERNEL); 126 if (!e->alg) 127 return -ENOMEM; 128 129 cpu0_hash = crypto_alloc_ahash(alg, 0, CRYPTO_ALG_ASYNC); 130 if (IS_ERR(cpu0_hash)) { 131 ret = PTR_ERR(cpu0_hash); 132 goto out_free_alg; 133 } 134 135 e->needs_key = crypto_ahash_get_flags(cpu0_hash) & CRYPTO_TFM_NEED_KEY; 136 137 ret = __cpool_try_clone(cpu0_hash); 138 if (ret) 139 goto out_free_cpu0_hash; 140 e->hash = cpu0_hash; 141 kref_init(&e->kref); 142 return 0; 143 144 out_free_cpu0_hash: 145 crypto_free_ahash(cpu0_hash); 146 out_free_alg: 147 kfree(e->alg); 148 e->alg = NULL; 149 return ret; 150 } 151 152 /** 153 * tcp_sigpool_alloc_ahash - allocates pool for ahash requests 154 * @alg: name of async hash algorithm 155 * @scratch_size: reserve a tcp_sigpool::scratch buffer of this size 156 */ 157 int tcp_sigpool_alloc_ahash(const char *alg, size_t scratch_size) 158 { 159 int i, ret; 160 161 /* slow-path */ 162 mutex_lock(&cpool_mutex); 163 ret = sigpool_reserve_scratch(scratch_size); 164 if (ret) 165 goto out; 166 for (i = 0; i < cpool_populated; i++) { 167 if (!cpool[i].alg) 168 continue; 169 if (strcmp(cpool[i].alg, alg)) 170 continue; 171 172 /* pairs with tcp_sigpool_release() */ 173 if (!kref_get_unless_zero(&cpool[i].kref)) 174 kref_init(&cpool[i].kref); 175 ret = i; 176 goto out; 177 } 178 179 for (i = 0; i < cpool_populated; i++) { 180 if (!cpool[i].alg) 181 break; 182 } 183 if (i >= CPOOL_SIZE) { 184 ret = -ENOSPC; 185 goto out; 186 } 187 188 ret = __cpool_alloc_ahash(&cpool[i], alg); 189 if (!ret) { 190 ret = i; 191 if (i == cpool_populated) 192 cpool_populated++; 193 } 194 out: 195 mutex_unlock(&cpool_mutex); 196 return ret; 197 } 198 EXPORT_SYMBOL_GPL(tcp_sigpool_alloc_ahash); 199 200 static void __cpool_free_entry(struct sigpool_entry *e) 201 { 202 crypto_free_ahash(e->hash); 203 kfree(e->alg); 204 memset(e, 0, sizeof(*e)); 205 } 206 207 static void cpool_cleanup_work_cb(struct work_struct *work) 208 { 209 bool free_scratch = true; 210 unsigned int i; 211 212 mutex_lock(&cpool_mutex); 213 for (i = 0; i < cpool_populated; i++) { 214 if (kref_read(&cpool[i].kref) > 0) { 215 free_scratch = false; 216 continue; 217 } 218 if (!cpool[i].alg) 219 continue; 220 __cpool_free_entry(&cpool[i]); 221 } 222 if (free_scratch) 223 sigpool_scratch_free(); 224 mutex_unlock(&cpool_mutex); 225 } 226 227 static DECLARE_WORK(cpool_cleanup_work, cpool_cleanup_work_cb); 228 static void cpool_schedule_cleanup(struct kref *kref) 229 { 230 schedule_work(&cpool_cleanup_work); 231 } 232 233 /** 234 * tcp_sigpool_release - decreases number of users for a pool. If it was 235 * the last user of the pool, releases any memory that was consumed. 236 * @id: tcp_sigpool that was previously allocated by tcp_sigpool_alloc_ahash() 237 */ 238 void tcp_sigpool_release(unsigned int id) 239 { 240 if (WARN_ON_ONCE(id >= cpool_populated || !cpool[id].alg)) 241 return; 242 243 /* slow-path */ 244 kref_put(&cpool[id].kref, cpool_schedule_cleanup); 245 } 246 EXPORT_SYMBOL_GPL(tcp_sigpool_release); 247 248 /** 249 * tcp_sigpool_get - increases number of users (refcounter) for a pool 250 * @id: tcp_sigpool that was previously allocated by tcp_sigpool_alloc_ahash() 251 */ 252 void tcp_sigpool_get(unsigned int id) 253 { 254 if (WARN_ON_ONCE(id >= cpool_populated || !cpool[id].alg)) 255 return; 256 kref_get(&cpool[id].kref); 257 } 258 EXPORT_SYMBOL_GPL(tcp_sigpool_get); 259 260 int tcp_sigpool_start(unsigned int id, struct tcp_sigpool *c) __cond_acquires(RCU_BH) 261 { 262 struct crypto_ahash *hash; 263 264 rcu_read_lock_bh(); 265 if (WARN_ON_ONCE(id >= cpool_populated || !cpool[id].alg)) { 266 rcu_read_unlock_bh(); 267 return -EINVAL; 268 } 269 270 hash = crypto_clone_ahash(cpool[id].hash); 271 if (IS_ERR(hash)) { 272 rcu_read_unlock_bh(); 273 return PTR_ERR(hash); 274 } 275 276 c->req = ahash_request_alloc(hash, GFP_ATOMIC); 277 if (!c->req) { 278 crypto_free_ahash(hash); 279 rcu_read_unlock_bh(); 280 return -ENOMEM; 281 } 282 ahash_request_set_callback(c->req, 0, NULL, NULL); 283 284 /* Pairs with tcp_sigpool_reserve_scratch(), scratch area is 285 * valid (allocated) until tcp_sigpool_end(). 286 */ 287 local_lock_nested_bh(&sigpool_scratch.bh_lock); 288 c->scratch = rcu_dereference_bh(*this_cpu_ptr(&sigpool_scratch.pad)); 289 return 0; 290 } 291 EXPORT_SYMBOL_GPL(tcp_sigpool_start); 292 293 void tcp_sigpool_end(struct tcp_sigpool *c) __releases(RCU_BH) 294 { 295 struct crypto_ahash *hash = crypto_ahash_reqtfm(c->req); 296 297 local_unlock_nested_bh(&sigpool_scratch.bh_lock); 298 rcu_read_unlock_bh(); 299 ahash_request_free(c->req); 300 crypto_free_ahash(hash); 301 } 302 EXPORT_SYMBOL_GPL(tcp_sigpool_end); 303 304 /** 305 * tcp_sigpool_algo - return algorithm of tcp_sigpool 306 * @id: tcp_sigpool that was previously allocated by tcp_sigpool_alloc_ahash() 307 * @buf: buffer to return name of algorithm 308 * @buf_len: size of @buf 309 */ 310 size_t tcp_sigpool_algo(unsigned int id, char *buf, size_t buf_len) 311 { 312 if (WARN_ON_ONCE(id >= cpool_populated || !cpool[id].alg)) 313 return -EINVAL; 314 315 return strscpy(buf, cpool[id].alg, buf_len); 316 } 317 EXPORT_SYMBOL_GPL(tcp_sigpool_algo); 318 319 /** 320 * tcp_sigpool_hash_skb_data - hash data in skb with initialized tcp_sigpool 321 * @hp: tcp_sigpool pointer 322 * @skb: buffer to add sign for 323 * @header_len: TCP header length for this segment 324 */ 325 int tcp_sigpool_hash_skb_data(struct tcp_sigpool *hp, 326 const struct sk_buff *skb, 327 unsigned int header_len) 328 { 329 const unsigned int head_data_len = skb_headlen(skb) > header_len ? 330 skb_headlen(skb) - header_len : 0; 331 const struct skb_shared_info *shi = skb_shinfo(skb); 332 const struct tcphdr *tp = tcp_hdr(skb); 333 struct ahash_request *req = hp->req; 334 struct sk_buff *frag_iter; 335 struct scatterlist sg; 336 unsigned int i; 337 338 sg_init_table(&sg, 1); 339 340 sg_set_buf(&sg, ((u8 *)tp) + header_len, head_data_len); 341 ahash_request_set_crypt(req, &sg, NULL, head_data_len); 342 if (crypto_ahash_update(req)) 343 return 1; 344 345 for (i = 0; i < shi->nr_frags; ++i) { 346 const skb_frag_t *f = &shi->frags[i]; 347 unsigned int offset = skb_frag_off(f); 348 struct page *page; 349 350 page = skb_frag_page(f) + (offset >> PAGE_SHIFT); 351 sg_set_page(&sg, page, skb_frag_size(f), offset_in_page(offset)); 352 ahash_request_set_crypt(req, &sg, NULL, skb_frag_size(f)); 353 if (crypto_ahash_update(req)) 354 return 1; 355 } 356 357 skb_walk_frags(skb, frag_iter) 358 if (tcp_sigpool_hash_skb_data(hp, frag_iter, 0)) 359 return 1; 360 361 return 0; 362 } 363 EXPORT_SYMBOL(tcp_sigpool_hash_skb_data); 364 365 MODULE_LICENSE("GPL"); 366 MODULE_DESCRIPTION("Per-CPU pool of crypto requests"); 367