1 // SPDX-License-Identifier: GPL-2.0-or-later 2 /* 3 * IP Payload Compression Protocol (IPComp) - RFC3173. 4 * 5 * Copyright (c) 2003 James Morris <jmorris@intercode.com.au> 6 * Copyright (c) 2003-2008 Herbert Xu <herbert@gondor.apana.org.au> 7 * 8 * Todo: 9 * - Tunable compression parameters. 10 * - Compression stats. 11 * - Adaptive compression. 12 */ 13 14 #include <linux/crypto.h> 15 #include <linux/err.h> 16 #include <linux/list.h> 17 #include <linux/module.h> 18 #include <linux/mutex.h> 19 #include <linux/percpu.h> 20 #include <linux/slab.h> 21 #include <linux/smp.h> 22 #include <linux/vmalloc.h> 23 #include <net/ip.h> 24 #include <net/ipcomp.h> 25 #include <net/xfrm.h> 26 27 struct ipcomp_tfms { 28 struct list_head list; 29 struct crypto_comp * __percpu *tfms; 30 int users; 31 }; 32 33 static DEFINE_MUTEX(ipcomp_resource_mutex); 34 static void * __percpu *ipcomp_scratches; 35 static int ipcomp_scratch_users; 36 static LIST_HEAD(ipcomp_tfms_list); 37 38 static int ipcomp_decompress(struct xfrm_state *x, struct sk_buff *skb) 39 { 40 struct ipcomp_data *ipcd = x->data; 41 const int plen = skb->len; 42 int dlen = IPCOMP_SCRATCH_SIZE; 43 const u8 *start = skb->data; 44 u8 *scratch = *this_cpu_ptr(ipcomp_scratches); 45 struct crypto_comp *tfm = *this_cpu_ptr(ipcd->tfms); 46 int err = crypto_comp_decompress(tfm, start, plen, scratch, &dlen); 47 int len; 48 49 if (err) 50 return err; 51 52 if (dlen < (plen + sizeof(struct ip_comp_hdr))) 53 return -EINVAL; 54 55 len = dlen - plen; 56 if (len > skb_tailroom(skb)) 57 len = skb_tailroom(skb); 58 59 __skb_put(skb, len); 60 61 len += plen; 62 skb_copy_to_linear_data(skb, scratch, len); 63 64 while ((scratch += len, dlen -= len) > 0) { 65 skb_frag_t *frag; 66 struct page *page; 67 68 if (WARN_ON(skb_shinfo(skb)->nr_frags >= MAX_SKB_FRAGS)) 69 return -EMSGSIZE; 70 71 frag = skb_shinfo(skb)->frags + skb_shinfo(skb)->nr_frags; 72 page = alloc_page(GFP_ATOMIC); 73 74 if (!page) 75 return -ENOMEM; 76 77 __skb_frag_set_page(frag, page); 78 79 len = PAGE_SIZE; 80 if (dlen < len) 81 len = dlen; 82 83 skb_frag_off_set(frag, 0); 84 skb_frag_size_set(frag, len); 85 memcpy(skb_frag_address(frag), scratch, len); 86 87 skb->truesize += len; 88 skb->data_len += len; 89 skb->len += len; 90 91 skb_shinfo(skb)->nr_frags++; 92 } 93 94 return 0; 95 } 96 97 int ipcomp_input(struct xfrm_state *x, struct sk_buff *skb) 98 { 99 int nexthdr; 100 int err = -ENOMEM; 101 struct ip_comp_hdr *ipch; 102 103 if (skb_linearize_cow(skb)) 104 goto out; 105 106 skb->ip_summed = CHECKSUM_NONE; 107 108 /* Remove ipcomp header and decompress original payload */ 109 ipch = (void *)skb->data; 110 nexthdr = ipch->nexthdr; 111 112 skb->transport_header = skb->network_header + sizeof(*ipch); 113 __skb_pull(skb, sizeof(*ipch)); 114 err = ipcomp_decompress(x, skb); 115 if (err) 116 goto out; 117 118 err = nexthdr; 119 120 out: 121 return err; 122 } 123 EXPORT_SYMBOL_GPL(ipcomp_input); 124 125 static int ipcomp_compress(struct xfrm_state *x, struct sk_buff *skb) 126 { 127 struct ipcomp_data *ipcd = x->data; 128 const int plen = skb->len; 129 int dlen = IPCOMP_SCRATCH_SIZE; 130 u8 *start = skb->data; 131 struct crypto_comp *tfm; 132 u8 *scratch; 133 int err; 134 135 local_bh_disable(); 136 scratch = *this_cpu_ptr(ipcomp_scratches); 137 tfm = *this_cpu_ptr(ipcd->tfms); 138 err = crypto_comp_compress(tfm, start, plen, scratch, &dlen); 139 if (err) 140 goto out; 141 142 if ((dlen + sizeof(struct ip_comp_hdr)) >= plen) { 143 err = -EMSGSIZE; 144 goto out; 145 } 146 147 memcpy(start + sizeof(struct ip_comp_hdr), scratch, dlen); 148 local_bh_enable(); 149 150 pskb_trim(skb, dlen + sizeof(struct ip_comp_hdr)); 151 return 0; 152 153 out: 154 local_bh_enable(); 155 return err; 156 } 157 158 int ipcomp_output(struct xfrm_state *x, struct sk_buff *skb) 159 { 160 int err; 161 struct ip_comp_hdr *ipch; 162 struct ipcomp_data *ipcd = x->data; 163 164 if (skb->len < ipcd->threshold) { 165 /* Don't bother compressing */ 166 goto out_ok; 167 } 168 169 if (skb_linearize_cow(skb)) 170 goto out_ok; 171 172 err = ipcomp_compress(x, skb); 173 174 if (err) { 175 goto out_ok; 176 } 177 178 /* Install ipcomp header, convert into ipcomp datagram. */ 179 ipch = ip_comp_hdr(skb); 180 ipch->nexthdr = *skb_mac_header(skb); 181 ipch->flags = 0; 182 ipch->cpi = htons((u16 )ntohl(x->id.spi)); 183 *skb_mac_header(skb) = IPPROTO_COMP; 184 out_ok: 185 skb_push(skb, -skb_network_offset(skb)); 186 return 0; 187 } 188 EXPORT_SYMBOL_GPL(ipcomp_output); 189 190 static void ipcomp_free_scratches(void) 191 { 192 int i; 193 void * __percpu *scratches; 194 195 if (--ipcomp_scratch_users) 196 return; 197 198 scratches = ipcomp_scratches; 199 if (!scratches) 200 return; 201 202 for_each_possible_cpu(i) 203 vfree(*per_cpu_ptr(scratches, i)); 204 205 free_percpu(scratches); 206 ipcomp_scratches = NULL; 207 } 208 209 static void * __percpu *ipcomp_alloc_scratches(void) 210 { 211 void * __percpu *scratches; 212 int i; 213 214 if (ipcomp_scratch_users++) 215 return ipcomp_scratches; 216 217 scratches = alloc_percpu(void *); 218 if (!scratches) 219 return NULL; 220 221 ipcomp_scratches = scratches; 222 223 for_each_possible_cpu(i) { 224 void *scratch; 225 226 scratch = vmalloc_node(IPCOMP_SCRATCH_SIZE, cpu_to_node(i)); 227 if (!scratch) 228 return NULL; 229 *per_cpu_ptr(scratches, i) = scratch; 230 } 231 232 return scratches; 233 } 234 235 static void ipcomp_free_tfms(struct crypto_comp * __percpu *tfms) 236 { 237 struct ipcomp_tfms *pos; 238 int cpu; 239 240 list_for_each_entry(pos, &ipcomp_tfms_list, list) { 241 if (pos->tfms == tfms) 242 break; 243 } 244 245 WARN_ON(list_entry_is_head(pos, &ipcomp_tfms_list, list)); 246 247 if (--pos->users) 248 return; 249 250 list_del(&pos->list); 251 kfree(pos); 252 253 if (!tfms) 254 return; 255 256 for_each_possible_cpu(cpu) { 257 struct crypto_comp *tfm = *per_cpu_ptr(tfms, cpu); 258 crypto_free_comp(tfm); 259 } 260 free_percpu(tfms); 261 } 262 263 static struct crypto_comp * __percpu *ipcomp_alloc_tfms(const char *alg_name) 264 { 265 struct ipcomp_tfms *pos; 266 struct crypto_comp * __percpu *tfms; 267 int cpu; 268 269 270 list_for_each_entry(pos, &ipcomp_tfms_list, list) { 271 struct crypto_comp *tfm; 272 273 /* This can be any valid CPU ID so we don't need locking. */ 274 tfm = this_cpu_read(*pos->tfms); 275 276 if (!strcmp(crypto_comp_name(tfm), alg_name)) { 277 pos->users++; 278 return pos->tfms; 279 } 280 } 281 282 pos = kmalloc(sizeof(*pos), GFP_KERNEL); 283 if (!pos) 284 return NULL; 285 286 pos->users = 1; 287 INIT_LIST_HEAD(&pos->list); 288 list_add(&pos->list, &ipcomp_tfms_list); 289 290 pos->tfms = tfms = alloc_percpu(struct crypto_comp *); 291 if (!tfms) 292 goto error; 293 294 for_each_possible_cpu(cpu) { 295 struct crypto_comp *tfm = crypto_alloc_comp(alg_name, 0, 296 CRYPTO_ALG_ASYNC); 297 if (IS_ERR(tfm)) 298 goto error; 299 *per_cpu_ptr(tfms, cpu) = tfm; 300 } 301 302 return tfms; 303 304 error: 305 ipcomp_free_tfms(tfms); 306 return NULL; 307 } 308 309 static void ipcomp_free_data(struct ipcomp_data *ipcd) 310 { 311 if (ipcd->tfms) 312 ipcomp_free_tfms(ipcd->tfms); 313 ipcomp_free_scratches(); 314 } 315 316 void ipcomp_destroy(struct xfrm_state *x) 317 { 318 struct ipcomp_data *ipcd = x->data; 319 if (!ipcd) 320 return; 321 xfrm_state_delete_tunnel(x); 322 mutex_lock(&ipcomp_resource_mutex); 323 ipcomp_free_data(ipcd); 324 mutex_unlock(&ipcomp_resource_mutex); 325 kfree(ipcd); 326 } 327 EXPORT_SYMBOL_GPL(ipcomp_destroy); 328 329 int ipcomp_init_state(struct xfrm_state *x, struct netlink_ext_ack *extack) 330 { 331 int err; 332 struct ipcomp_data *ipcd; 333 struct xfrm_algo_desc *calg_desc; 334 335 err = -EINVAL; 336 if (!x->calg) { 337 NL_SET_ERR_MSG(extack, "Missing required compression algorithm"); 338 goto out; 339 } 340 341 if (x->encap) { 342 NL_SET_ERR_MSG(extack, "IPComp is not compatible with encapsulation"); 343 goto out; 344 } 345 346 err = -ENOMEM; 347 ipcd = kzalloc(sizeof(*ipcd), GFP_KERNEL); 348 if (!ipcd) 349 goto out; 350 351 mutex_lock(&ipcomp_resource_mutex); 352 if (!ipcomp_alloc_scratches()) 353 goto error; 354 355 ipcd->tfms = ipcomp_alloc_tfms(x->calg->alg_name); 356 if (!ipcd->tfms) 357 goto error; 358 mutex_unlock(&ipcomp_resource_mutex); 359 360 calg_desc = xfrm_calg_get_byname(x->calg->alg_name, 0); 361 BUG_ON(!calg_desc); 362 ipcd->threshold = calg_desc->uinfo.comp.threshold; 363 x->data = ipcd; 364 err = 0; 365 out: 366 return err; 367 368 error: 369 ipcomp_free_data(ipcd); 370 mutex_unlock(&ipcomp_resource_mutex); 371 kfree(ipcd); 372 goto out; 373 } 374 EXPORT_SYMBOL_GPL(ipcomp_init_state); 375 376 MODULE_LICENSE("GPL"); 377 MODULE_DESCRIPTION("IP Payload Compression Protocol (IPComp) - RFC3173"); 378 MODULE_AUTHOR("James Morris <jmorris@intercode.com.au>"); 379