1 // SPDX-License-Identifier: GPL-2.0-or-later 2 3 #include <crypto/hash.h> 4 #include <linux/cpu.h> 5 #include <linux/kref.h> 6 #include <linux/module.h> 7 #include <linux/mutex.h> 8 #include <linux/percpu.h> 9 #include <linux/workqueue.h> 10 #include <net/tcp.h> 11 12 static size_t __scratch_size; 13 static DEFINE_PER_CPU(void __rcu *, sigpool_scratch); 14 15 struct sigpool_entry { 16 struct crypto_ahash *hash; 17 const char *alg; 18 struct kref kref; 19 uint16_t needs_key:1, 20 reserved:15; 21 }; 22 23 #define CPOOL_SIZE (PAGE_SIZE / sizeof(struct sigpool_entry)) 24 static struct sigpool_entry cpool[CPOOL_SIZE]; 25 static unsigned int cpool_populated; 26 static DEFINE_MUTEX(cpool_mutex); 27 28 /* Slow-path */ 29 struct scratches_to_free { 30 struct rcu_head rcu; 31 unsigned int cnt; 32 void *scratches[]; 33 }; 34 35 static void free_old_scratches(struct rcu_head *head) 36 { 37 struct scratches_to_free *stf; 38 39 stf = container_of(head, struct scratches_to_free, rcu); 40 while (stf->cnt--) 41 kfree(stf->scratches[stf->cnt]); 42 kfree(stf); 43 } 44 45 /** 46 * sigpool_reserve_scratch - re-allocates scratch buffer, slow-path 47 * @size: request size for the scratch/temp buffer 48 */ 49 static int sigpool_reserve_scratch(size_t size) 50 { 51 struct scratches_to_free *stf; 52 size_t stf_sz = struct_size(stf, scratches, num_possible_cpus()); 53 int cpu, err = 0; 54 55 lockdep_assert_held(&cpool_mutex); 56 if (__scratch_size >= size) 57 return 0; 58 59 stf = kmalloc(stf_sz, GFP_KERNEL); 60 if (!stf) 61 return -ENOMEM; 62 stf->cnt = 0; 63 64 size = max(size, __scratch_size); 65 cpus_read_lock(); 66 for_each_possible_cpu(cpu) { 67 void *scratch, *old_scratch; 68 69 scratch = kmalloc_node(size, GFP_KERNEL, cpu_to_node(cpu)); 70 if (!scratch) { 71 err = -ENOMEM; 72 break; 73 } 74 75 old_scratch = rcu_replace_pointer(per_cpu(sigpool_scratch, cpu), 76 scratch, lockdep_is_held(&cpool_mutex)); 77 if (!cpu_online(cpu) || !old_scratch) { 78 kfree(old_scratch); 79 continue; 80 } 81 stf->scratches[stf->cnt++] = old_scratch; 82 } 83 cpus_read_unlock(); 84 if (!err) 85 __scratch_size = size; 86 87 call_rcu(&stf->rcu, free_old_scratches); 88 return err; 89 } 90 91 static void sigpool_scratch_free(void) 92 { 93 int cpu; 94 95 for_each_possible_cpu(cpu) 96 kfree(rcu_replace_pointer(per_cpu(sigpool_scratch, cpu), 97 NULL, lockdep_is_held(&cpool_mutex))); 98 __scratch_size = 0; 99 } 100 101 static int __cpool_try_clone(struct crypto_ahash *hash) 102 { 103 struct crypto_ahash *tmp; 104 105 tmp = crypto_clone_ahash(hash); 106 if (IS_ERR(tmp)) 107 return PTR_ERR(tmp); 108 109 crypto_free_ahash(tmp); 110 return 0; 111 } 112 113 static int __cpool_alloc_ahash(struct sigpool_entry *e, const char *alg) 114 { 115 struct crypto_ahash *cpu0_hash; 116 int ret; 117 118 e->alg = kstrdup(alg, GFP_KERNEL); 119 if (!e->alg) 120 return -ENOMEM; 121 122 cpu0_hash = crypto_alloc_ahash(alg, 0, CRYPTO_ALG_ASYNC); 123 if (IS_ERR(cpu0_hash)) { 124 ret = PTR_ERR(cpu0_hash); 125 goto out_free_alg; 126 } 127 128 e->needs_key = crypto_ahash_get_flags(cpu0_hash) & CRYPTO_TFM_NEED_KEY; 129 130 ret = __cpool_try_clone(cpu0_hash); 131 if (ret) 132 goto out_free_cpu0_hash; 133 e->hash = cpu0_hash; 134 kref_init(&e->kref); 135 return 0; 136 137 out_free_cpu0_hash: 138 crypto_free_ahash(cpu0_hash); 139 out_free_alg: 140 kfree(e->alg); 141 e->alg = NULL; 142 return ret; 143 } 144 145 /** 146 * tcp_sigpool_alloc_ahash - allocates pool for ahash requests 147 * @alg: name of async hash algorithm 148 * @scratch_size: reserve a tcp_sigpool::scratch buffer of this size 149 */ 150 int tcp_sigpool_alloc_ahash(const char *alg, size_t scratch_size) 151 { 152 int i, ret; 153 154 /* slow-path */ 155 mutex_lock(&cpool_mutex); 156 ret = sigpool_reserve_scratch(scratch_size); 157 if (ret) 158 goto out; 159 for (i = 0; i < cpool_populated; i++) { 160 if (!cpool[i].alg) 161 continue; 162 if (strcmp(cpool[i].alg, alg)) 163 continue; 164 165 /* pairs with tcp_sigpool_release() */ 166 if (!kref_get_unless_zero(&cpool[i].kref)) 167 kref_init(&cpool[i].kref); 168 ret = i; 169 goto out; 170 } 171 172 for (i = 0; i < cpool_populated; i++) { 173 if (!cpool[i].alg) 174 break; 175 } 176 if (i >= CPOOL_SIZE) { 177 ret = -ENOSPC; 178 goto out; 179 } 180 181 ret = __cpool_alloc_ahash(&cpool[i], alg); 182 if (!ret) { 183 ret = i; 184 if (i == cpool_populated) 185 cpool_populated++; 186 } 187 out: 188 mutex_unlock(&cpool_mutex); 189 return ret; 190 } 191 EXPORT_SYMBOL_GPL(tcp_sigpool_alloc_ahash); 192 193 static void __cpool_free_entry(struct sigpool_entry *e) 194 { 195 crypto_free_ahash(e->hash); 196 kfree(e->alg); 197 memset(e, 0, sizeof(*e)); 198 } 199 200 static void cpool_cleanup_work_cb(struct work_struct *work) 201 { 202 bool free_scratch = true; 203 unsigned int i; 204 205 mutex_lock(&cpool_mutex); 206 for (i = 0; i < cpool_populated; i++) { 207 if (kref_read(&cpool[i].kref) > 0) { 208 free_scratch = false; 209 continue; 210 } 211 if (!cpool[i].alg) 212 continue; 213 __cpool_free_entry(&cpool[i]); 214 } 215 if (free_scratch) 216 sigpool_scratch_free(); 217 mutex_unlock(&cpool_mutex); 218 } 219 220 static DECLARE_WORK(cpool_cleanup_work, cpool_cleanup_work_cb); 221 static void cpool_schedule_cleanup(struct kref *kref) 222 { 223 schedule_work(&cpool_cleanup_work); 224 } 225 226 /** 227 * tcp_sigpool_release - decreases number of users for a pool. If it was 228 * the last user of the pool, releases any memory that was consumed. 229 * @id: tcp_sigpool that was previously allocated by tcp_sigpool_alloc_ahash() 230 */ 231 void tcp_sigpool_release(unsigned int id) 232 { 233 if (WARN_ON_ONCE(id >= cpool_populated || !cpool[id].alg)) 234 return; 235 236 /* slow-path */ 237 kref_put(&cpool[id].kref, cpool_schedule_cleanup); 238 } 239 EXPORT_SYMBOL_GPL(tcp_sigpool_release); 240 241 /** 242 * tcp_sigpool_get - increases number of users (refcounter) for a pool 243 * @id: tcp_sigpool that was previously allocated by tcp_sigpool_alloc_ahash() 244 */ 245 void tcp_sigpool_get(unsigned int id) 246 { 247 if (WARN_ON_ONCE(id >= cpool_populated || !cpool[id].alg)) 248 return; 249 kref_get(&cpool[id].kref); 250 } 251 EXPORT_SYMBOL_GPL(tcp_sigpool_get); 252 253 int tcp_sigpool_start(unsigned int id, struct tcp_sigpool *c) __cond_acquires(RCU_BH) 254 { 255 struct crypto_ahash *hash; 256 257 rcu_read_lock_bh(); 258 if (WARN_ON_ONCE(id >= cpool_populated || !cpool[id].alg)) { 259 rcu_read_unlock_bh(); 260 return -EINVAL; 261 } 262 263 hash = crypto_clone_ahash(cpool[id].hash); 264 if (IS_ERR(hash)) { 265 rcu_read_unlock_bh(); 266 return PTR_ERR(hash); 267 } 268 269 c->req = ahash_request_alloc(hash, GFP_ATOMIC); 270 if (!c->req) { 271 crypto_free_ahash(hash); 272 rcu_read_unlock_bh(); 273 return -ENOMEM; 274 } 275 ahash_request_set_callback(c->req, 0, NULL, NULL); 276 277 /* Pairs with tcp_sigpool_reserve_scratch(), scratch area is 278 * valid (allocated) until tcp_sigpool_end(). 279 */ 280 c->scratch = rcu_dereference_bh(*this_cpu_ptr(&sigpool_scratch)); 281 return 0; 282 } 283 EXPORT_SYMBOL_GPL(tcp_sigpool_start); 284 285 void tcp_sigpool_end(struct tcp_sigpool *c) __releases(RCU_BH) 286 { 287 struct crypto_ahash *hash = crypto_ahash_reqtfm(c->req); 288 289 rcu_read_unlock_bh(); 290 ahash_request_free(c->req); 291 crypto_free_ahash(hash); 292 } 293 EXPORT_SYMBOL_GPL(tcp_sigpool_end); 294 295 /** 296 * tcp_sigpool_algo - return algorithm of tcp_sigpool 297 * @id: tcp_sigpool that was previously allocated by tcp_sigpool_alloc_ahash() 298 * @buf: buffer to return name of algorithm 299 * @buf_len: size of @buf 300 */ 301 size_t tcp_sigpool_algo(unsigned int id, char *buf, size_t buf_len) 302 { 303 if (WARN_ON_ONCE(id >= cpool_populated || !cpool[id].alg)) 304 return -EINVAL; 305 306 return strscpy(buf, cpool[id].alg, buf_len); 307 } 308 EXPORT_SYMBOL_GPL(tcp_sigpool_algo); 309 310 /** 311 * tcp_sigpool_hash_skb_data - hash data in skb with initialized tcp_sigpool 312 * @hp: tcp_sigpool pointer 313 * @skb: buffer to add sign for 314 * @header_len: TCP header length for this segment 315 */ 316 int tcp_sigpool_hash_skb_data(struct tcp_sigpool *hp, 317 const struct sk_buff *skb, 318 unsigned int header_len) 319 { 320 const unsigned int head_data_len = skb_headlen(skb) > header_len ? 321 skb_headlen(skb) - header_len : 0; 322 const struct skb_shared_info *shi = skb_shinfo(skb); 323 const struct tcphdr *tp = tcp_hdr(skb); 324 struct ahash_request *req = hp->req; 325 struct sk_buff *frag_iter; 326 struct scatterlist sg; 327 unsigned int i; 328 329 sg_init_table(&sg, 1); 330 331 sg_set_buf(&sg, ((u8 *)tp) + header_len, head_data_len); 332 ahash_request_set_crypt(req, &sg, NULL, head_data_len); 333 if (crypto_ahash_update(req)) 334 return 1; 335 336 for (i = 0; i < shi->nr_frags; ++i) { 337 const skb_frag_t *f = &shi->frags[i]; 338 unsigned int offset = skb_frag_off(f); 339 struct page *page; 340 341 page = skb_frag_page(f) + (offset >> PAGE_SHIFT); 342 sg_set_page(&sg, page, skb_frag_size(f), offset_in_page(offset)); 343 ahash_request_set_crypt(req, &sg, NULL, skb_frag_size(f)); 344 if (crypto_ahash_update(req)) 345 return 1; 346 } 347 348 skb_walk_frags(skb, frag_iter) 349 if (tcp_sigpool_hash_skb_data(hp, frag_iter, 0)) 350 return 1; 351 352 return 0; 353 } 354 EXPORT_SYMBOL(tcp_sigpool_hash_skb_data); 355 356 MODULE_LICENSE("GPL"); 357 MODULE_DESCRIPTION("Per-CPU pool of crypto requests"); 358