// SPDX-License-Identifier: GPL-2.0-or-later

#include <crypto/hash.h>
#include <linux/cpu.h>
#include <linux/kref.h>
#include <linux/module.h>
#include <linux/mutex.h>
#include <linux/percpu.h>
#include <linux/workqueue.h>
#include <net/tcp.h>

static size_t __scratch_size;
static DEFINE_PER_CPU(void __rcu *, sigpool_scratch);

struct sigpool_entry {
	struct crypto_ahash	*hash;
	const char		*alg;
	struct kref		kref;
	uint16_t		needs_key:1,
				reserved:15;
};

#define CPOOL_SIZE (PAGE_SIZE / sizeof(struct sigpool_entry))
static struct sigpool_entry cpool[CPOOL_SIZE];
static unsigned int cpool_populated;
static DEFINE_MUTEX(cpool_mutex);

/* Slow-path */
struct scratches_to_free {
	struct rcu_head rcu;
	unsigned int cnt;
	void *scratches[];
};

static void free_old_scratches(struct rcu_head *head)
{
	struct scratches_to_free *stf;

	stf = container_of(head, struct scratches_to_free, rcu);
	while (stf->cnt--)
		kfree(stf->scratches[stf->cnt]);
	kfree(stf);
}

/**
 * sigpool_reserve_scratch - re-allocates scratch buffer, slow-path
 * @size: request size for the scratch/temp buffer
 */
static int sigpool_reserve_scratch(size_t size)
{
	struct scratches_to_free *stf;
	size_t stf_sz = struct_size(stf, scratches, num_possible_cpus());
	int cpu, err = 0;

	lockdep_assert_held(&cpool_mutex);
	if (__scratch_size >= size)
		return 0;

	stf = kmalloc(stf_sz, GFP_KERNEL);
	if (!stf)
		return -ENOMEM;
	stf->cnt = 0;

	size = max(size, __scratch_size);
	cpus_read_lock();
	for_each_possible_cpu(cpu) {
		void *scratch, *old_scratch;

		scratch = kmalloc_node(size, GFP_KERNEL, cpu_to_node(cpu));
		if (!scratch) {
			err = -ENOMEM;
			break;
		}

		old_scratch = rcu_replace_pointer(per_cpu(sigpool_scratch, cpu),
					scratch, lockdep_is_held(&cpool_mutex));
		if (!cpu_online(cpu) || !old_scratch) {
			kfree(old_scratch);
			continue;
		}
		stf->scratches[stf->cnt++] = old_scratch;
	}
	cpus_read_unlock();
	if (!err)
		__scratch_size = size;

	call_rcu(&stf->rcu, free_old_scratches);
	return err;
}

static void sigpool_scratch_free(void)
{
	int cpu;

	for_each_possible_cpu(cpu)
		kfree(rcu_replace_pointer(per_cpu(sigpool_scratch, cpu),
					  NULL, lockdep_is_held(&cpool_mutex)));
	__scratch_size = 0;
}

static int __cpool_try_clone(struct crypto_ahash *hash)
{
	struct crypto_ahash *tmp;

	tmp = crypto_clone_ahash(hash);
	if (IS_ERR(tmp))
		return PTR_ERR(tmp);

	crypto_free_ahash(tmp);
	return 0;
}

static int __cpool_alloc_ahash(struct sigpool_entry *e, const char *alg)
{
	struct crypto_ahash *cpu0_hash;
	int ret;

	e->alg = kstrdup(alg, GFP_KERNEL);
	if (!e->alg)
		return -ENOMEM;

	cpu0_hash = crypto_alloc_ahash(alg, 0, CRYPTO_ALG_ASYNC);
	if (IS_ERR(cpu0_hash)) {
		ret = PTR_ERR(cpu0_hash);
		goto out_free_alg;
	}

	e->needs_key = crypto_ahash_get_flags(cpu0_hash) & CRYPTO_TFM_NEED_KEY;

	ret = __cpool_try_clone(cpu0_hash);
	if (ret)
		goto out_free_cpu0_hash;
	e->hash = cpu0_hash;
	kref_init(&e->kref);
	return 0;

out_free_cpu0_hash:
	crypto_free_ahash(cpu0_hash);
out_free_alg:
	kfree(e->alg);
	e->alg = NULL;
	return ret;
}

/**
 * tcp_sigpool_alloc_ahash - allocates pool for ahash requests
 * @alg: name of async hash algorithm
 * @scratch_size: reserve a tcp_sigpool::scratch buffer of this size
 *
 * Return: id (>= 0) of the allocated pool or a negative error on failure.
 */
int tcp_sigpool_alloc_ahash(const char *alg, size_t scratch_size)
{
	int i, ret;

	/* slow-path */
	mutex_lock(&cpool_mutex);
	ret = sigpool_reserve_scratch(scratch_size);
	if (ret)
		goto out;
	for (i = 0; i < cpool_populated; i++) {
		if (!cpool[i].alg)
			continue;
		if (strcmp(cpool[i].alg, alg))
			continue;

		/* pairs with tcp_sigpool_release() */
		if (!kref_get_unless_zero(&cpool[i].kref))
			kref_init(&cpool[i].kref);
		ret = i;
		goto out;
	}

	for (i = 0; i < cpool_populated; i++) {
		if (!cpool[i].alg)
			break;
	}
	if (i >= CPOOL_SIZE) {
		ret = -ENOSPC;
		goto out;
	}

	ret = __cpool_alloc_ahash(&cpool[i], alg);
	if (!ret) {
		ret = i;
		if (i == cpool_populated)
			cpool_populated++;
	}
out:
	mutex_unlock(&cpool_mutex);
	return ret;
}
EXPORT_SYMBOL_GPL(tcp_sigpool_alloc_ahash);

static void __cpool_free_entry(struct sigpool_entry *e)
{
	crypto_free_ahash(e->hash);
	kfree(e->alg);
	memset(e, 0, sizeof(*e));
}

static void cpool_cleanup_work_cb(struct work_struct *work)
{
	bool free_scratch = true;
	unsigned int i;

	mutex_lock(&cpool_mutex);
	for (i = 0; i < cpool_populated; i++) {
		if (kref_read(&cpool[i].kref) > 0) {
			free_scratch = false;
			continue;
		}
		if (!cpool[i].alg)
			continue;
		__cpool_free_entry(&cpool[i]);
	}
	if (free_scratch)
		sigpool_scratch_free();
	mutex_unlock(&cpool_mutex);
}

static DECLARE_WORK(cpool_cleanup_work, cpool_cleanup_work_cb);
static void cpool_schedule_cleanup(struct kref *kref)
{
	schedule_work(&cpool_cleanup_work);
}

/**
 * tcp_sigpool_release - decreases number of users for a pool. If it was
 * the last user of the pool, releases any memory that was consumed.
 * @id: tcp_sigpool that was previously allocated by tcp_sigpool_alloc_ahash()
 */
void tcp_sigpool_release(unsigned int id)
{
	if (WARN_ON_ONCE(id >= cpool_populated || !cpool[id].alg))
		return;

	/* slow-path */
	kref_put(&cpool[id].kref, cpool_schedule_cleanup);
}
EXPORT_SYMBOL_GPL(tcp_sigpool_release);

/**
 * tcp_sigpool_get - increases number of users (refcounter) for a pool
 * @id: tcp_sigpool that was previously allocated by tcp_sigpool_alloc_ahash()
 */
void tcp_sigpool_get(unsigned int id)
{
	if (WARN_ON_ONCE(id >= cpool_populated || !cpool[id].alg))
		return;
	kref_get(&cpool[id].kref);
}
EXPORT_SYMBOL_GPL(tcp_sigpool_get);

/**
 * tcp_sigpool_start - disable bh and start using tcp_sigpool
 * @id: tcp_sigpool that was previously allocated by tcp_sigpool_alloc_ahash()
 * @c: returns tcp_sigpool for usage (uninitialized on failure)
 *
 * Return: 0 on success, error otherwise.
 */
int tcp_sigpool_start(unsigned int id, struct tcp_sigpool *c) __cond_acquires(RCU_BH)
{
	struct crypto_ahash *hash;

	rcu_read_lock_bh();
	if (WARN_ON_ONCE(id >= cpool_populated || !cpool[id].alg)) {
		rcu_read_unlock_bh();
		return -EINVAL;
	}

	hash = crypto_clone_ahash(cpool[id].hash);
	if (IS_ERR(hash)) {
		rcu_read_unlock_bh();
		return PTR_ERR(hash);
	}

	c->req = ahash_request_alloc(hash, GFP_ATOMIC);
	if (!c->req) {
		crypto_free_ahash(hash);
		rcu_read_unlock_bh();
		return -ENOMEM;
	}
	ahash_request_set_callback(c->req, 0, NULL, NULL);

	/* Pairs with sigpool_reserve_scratch(); the scratch area is
	 * valid (allocated) until tcp_sigpool_end().
	 */
	c->scratch = rcu_dereference_bh(*this_cpu_ptr(&sigpool_scratch));
	return 0;
}
EXPORT_SYMBOL_GPL(tcp_sigpool_start);

/**
 * tcp_sigpool_end - stop using tcp_sigpool and enable bh again
 * @c: tcp_sigpool context that was returned by tcp_sigpool_start()
 */
void tcp_sigpool_end(struct tcp_sigpool *c) __releases(RCU_BH)
{
	struct crypto_ahash *hash = crypto_ahash_reqtfm(c->req);

	rcu_read_unlock_bh();
	ahash_request_free(c->req);
	crypto_free_ahash(hash);
}
EXPORT_SYMBOL_GPL(tcp_sigpool_end);

/**
 * tcp_sigpool_algo - return algorithm of tcp_sigpool
 * @id: tcp_sigpool that was previously allocated by tcp_sigpool_alloc_ahash()
 * @buf: buffer to return name of algorithm
 * @buf_len: size of @buf
 */
size_t tcp_sigpool_algo(unsigned int id, char *buf, size_t buf_len)
{
	if (WARN_ON_ONCE(id >= cpool_populated || !cpool[id].alg))
		return -EINVAL;

	return strscpy(buf, cpool[id].alg, buf_len);
}
EXPORT_SYMBOL_GPL(tcp_sigpool_algo);

/**
 * tcp_sigpool_hash_skb_data - hash data in skb with initialized tcp_sigpool
 * @hp: tcp_sigpool pointer
 * @skb: buffer whose payload is to be hashed
 * @header_len: TCP header length for this segment
 */
int tcp_sigpool_hash_skb_data(struct tcp_sigpool *hp,
			      const struct sk_buff *skb,
			      unsigned int header_len)
{
	const unsigned int head_data_len = skb_headlen(skb) > header_len ?
					   skb_headlen(skb) - header_len : 0;
	const struct skb_shared_info *shi = skb_shinfo(skb);
	const struct tcphdr *tp = tcp_hdr(skb);
	struct ahash_request *req = hp->req;
	struct sk_buff *frag_iter;
	struct scatterlist sg;
	unsigned int i;

	sg_init_table(&sg, 1);

	sg_set_buf(&sg, ((u8 *)tp) + header_len, head_data_len);
	ahash_request_set_crypt(req, &sg, NULL, head_data_len);
	if (crypto_ahash_update(req))
		return 1;

	for (i = 0; i < shi->nr_frags; ++i) {
		const skb_frag_t *f = &shi->frags[i];
		unsigned int offset = skb_frag_off(f);
		struct page *page;

		page = skb_frag_page(f) + (offset >> PAGE_SHIFT);
		sg_set_page(&sg, page, skb_frag_size(f), offset_in_page(offset));
		ahash_request_set_crypt(req, &sg, NULL, skb_frag_size(f));
		if (crypto_ahash_update(req))
			return 1;
	}

	skb_walk_frags(skb, frag_iter)
		if (tcp_sigpool_hash_skb_data(hp, frag_iter, 0))
			return 1;

	return 0;
}
EXPORT_SYMBOL(tcp_sigpool_hash_skb_data);

MODULE_LICENSE("GPL");
MODULE_DESCRIPTION("Per-CPU pool of crypto requests");
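
/* A minimal, illustrative caller sketch (kept under "#if 0", never compiled):
 * it shows the intended slow-path/fast-path split of this API. The function
 * name, its parameters and the "md5" algorithm choice are hypothetical; real
 * users (e.g. TCP-MD5/TCP-AO) keep the pool id for the lifetime of the key
 * rather than allocating and releasing it around every packet.
 */
#if 0
static int example_hash_skb(const struct sk_buff *skb, unsigned int hdr_len,
			    u8 *digest_out)
{
	struct tcp_sigpool hp;
	int pool_id, err;

	/* Slow path, sleeping context: set up (or reuse) a pool. */
	pool_id = tcp_sigpool_alloc_ahash("md5", 0);
	if (pool_id < 0)
		return pool_id;

	/* Fast path: disables BH and hands out a per-call ahash request. */
	err = tcp_sigpool_start(pool_id, &hp);
	if (err)
		goto out_release;

	err = crypto_ahash_init(hp.req);
	if (!err && tcp_sigpool_hash_skb_data(&hp, skb, hdr_len))
		err = -EINVAL;
	if (!err) {
		ahash_request_set_crypt(hp.req, NULL, digest_out, 0);
		err = crypto_ahash_final(hp.req);
	}
	tcp_sigpool_end(&hp);
out_release:
	/* Drops the reference taken by tcp_sigpool_alloc_ahash() above. */
	tcp_sigpool_release(pool_id);
	return err;
}
#endif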