// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (C) 2024, SUSE LLC
 *
 * Authors: Enzo Matsumiya <ematsumiya@suse.de>
 *
 * This file implements I/O compression support for SMB2 messages (SMB 3.1.1 only).
 * See compress/ for implementation details of each algorithm.
 *
 * References:
 * MS-SMB2 "3.1.4.4 Compressing the Message"
 * MS-SMB2 "3.1.5.3 Decompressing the Chained Message"
 * MS-XCA - for details of the supported algorithms
 */
#include <linux/slab.h>
#include <linux/kernel.h>
#include <linux/uio.h>
#include <linux/sort.h>

#include "cifsglob.h"
#include "../common/smb2pdu.h"
#include "cifsproto.h"
#include "smb2proto.h"

#include "compress/lz77.h"
#include "compress.h"

/*
 * The heuristics below try to determine data compressibility.
 *
 * Derived from fs/btrfs/compression.c, changing coding style, some parameters, and removing
 * unused parts.
 *
 * See that file for a better and more detailed explanation of the calculations.
 *
 * The algorithms are run on a collected sample of the input (uncompressed) data.
 * The sample is formed of 2K reads at PAGE_SIZE intervals, with a maximum size of 4M.
 *
 * Parsing the sample goes from "low-hanging fruit" (fastest algorithms, likely compressible)
 * to "needs more analysis" (likely incompressible).
 */

struct bucket {
	unsigned int count;
};
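
/*
 * A rough sketch of the fixed-point arithmetic in has_low_entropy() below: ilog2(pow4(x))
 * approximates 4 * log2(x), so each loop iteration adds roughly 4 * p * log2(slen / p), and
 * sum / slen ends up at about four times the Shannon entropy in bits per byte.  max_entropy
 * (8 * ilog2(16) == 32) is that same 4x scale applied to the 8 bits/byte maximum, so the
 * final comparison reads as "estimated entropy <= 65% of the theoretical maximum".
 *
 * E.g. for a hypothetical 2K sample: 2048 identical bytes give sum == 0 (0% of max, low
 * entropy), while an even spread over all 256 byte values (count 8 each) gives
 * sum / slen == 8 * 256 * (44 - 12) / 2048 == 32 (100% of max, not worth compressing).
 */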

/**
 * has_low_entropy() - Compute Shannon entropy of the sampled data.
 * @bkt:	Byte counts of the sample, sorted in descending order.
 * @slen:	Size of the sample.
 *
 * Return: true if the level (percentage of the number of bits that would be required to
 *	   compress the data) is below the minimum threshold.
 *
 * Note:
 * There _is_ an entropy level here that's > 65 (minimum threshold) that would indicate a
 * possibility of compression, but compressing, or even further analysing, it would waste so many
 * resources that it's simply not worth it.
 *
 * Also, Shannon entropy is the last heuristic computed; if we got this far and ended up
 * with uncertainty, just stay on the safe side and call it incompressible.
 */
static bool has_low_entropy(struct bucket *bkt, size_t slen)
{
	const size_t threshold = 65, max_entropy = 8 * ilog2(16);
	size_t i, p, p2, len, sum = 0;

#define pow4(n) ((n) * (n) * (n) * (n))
	len = ilog2(pow4(slen));

	for (i = 0; i < 256 && bkt[i].count > 0; i++) {
		p = bkt[i].count;
		p2 = ilog2(pow4(p));
		sum += p * (len - p2);
	}

	sum /= slen;

	return ((sum * 100 / max_entropy) <= threshold);
}

#define BYTE_DIST_BAD	0
#define BYTE_DIST_GOOD	1
#define BYTE_DIST_MAYBE	2
/**
 * calc_byte_distribution() - Compute byte distribution on the sampled data.
 * @bkt:	Byte counts of the sample, sorted in descending order.
 * @slen:	Size of the sample.
 *
 * Return:
 * BYTE_DIST_BAD:	A "hard no" for compression -- a computed uniform distribution of
 *			the bytes (e.g. random or encrypted data).
 * BYTE_DIST_GOOD:	High probability (normal (Gaussian) distribution) of the data being
 *			compressible.
 * BYTE_DIST_MAYBE:	When the computed byte distribution falls between the low and high
 *			thresholds.  has_low_entropy() should be used for a final decision.
 */
static int calc_byte_distribution(struct bucket *bkt, size_t slen)
{
	const size_t low = 64, high = 200, threshold = slen * 90 / 100;
	size_t sum = 0;
	int i;

	for (i = 0; i < low; i++)
		sum += bkt[i].count;

	if (sum > threshold)
		return BYTE_DIST_BAD;

	for (; i < high && bkt[i].count > 0; i++) {
		sum += bkt[i].count;
		if (sum > threshold)
			break;
	}

	if (i <= low)
		return BYTE_DIST_GOOD;

	if (i >= high)
		return BYTE_DIST_BAD;

	return BYTE_DIST_MAYBE;
}

static bool is_mostly_ascii(const struct bucket *bkt)
{
	size_t count = 0;
	int i;

	for (i = 0; i < 256; i++)
		if (bkt[i].count > 0)
			/* More than 64 distinct byte values; treat as not mostly ASCII. */
			if (++count > 64)
				return false;

	return true;
}

static bool has_repeated_data(const u8 *sample, size_t len)
{
	size_t s = len / 2;

	return (!memcmp(&sample[0], &sample[s], s));
}

static int cmp_bkt(const void *_a, const void *_b)
{
	const struct bucket *a = _a, *b = _b;

	/* Reverse sort. */
	if (a->count > b->count)
		return -1;

	return 1;
}

/*
 * Collect some 2K samples with 2K gaps between.
 */
static int collect_sample(const struct iov_iter *source, ssize_t max, u8 *sample)
{
	struct iov_iter iter = *source;
	size_t s = 0;

	while (iov_iter_count(&iter) >= SZ_2K) {
		size_t part = umin(umin(iov_iter_count(&iter), SZ_2K), max);
		size_t n;

		n = copy_from_iter(sample + s, part, &iter);
		if (n != part)
			return -EFAULT;

		s += n;
		max -= n;

		if (iov_iter_count(&iter) < PAGE_SIZE - SZ_2K)
			break;

		iov_iter_advance(&iter, SZ_2K);
	}

	return s;
}

/**
 * is_compressible() - Determines if a chunk of data is compressible.
 * @data: Iterator containing uncompressed data.
 *
 * Return: true if @data is compressible, false otherwise.
 *
 * Tests show that this function is quite reliable in predicting data compressibility,
 * matching close to 1:1 with the behaviour of LZ77 compression successes and failures.
 */
static bool is_compressible(const struct iov_iter *data)
{
	const size_t read_size = SZ_2K, bkt_size = 256, max = SZ_4M;
	struct bucket *bkt = NULL;
	size_t len;
	u8 *sample;
	bool ret = false;
	int i;

	/* Preventive double check -- already checked in should_compress(). */
	len = iov_iter_count(data);
	if (unlikely(len < read_size))
		return ret;

	if (len - read_size > max)
		len = max;

	sample = kvzalloc(len, GFP_KERNEL);
	if (!sample) {
		WARN_ON_ONCE(1);

		return ret;
	}

	/*
	 * Sample 2K bytes per page of the uncompressed data.
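	 * collect_sample() copies up to 2K at a time and then skips ahead another 2K, so this
	 * should amount to roughly the first 2K of every 4K of the data on 4K-page systems,
	 * capped at @len bytes in total.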
	 */
	i = collect_sample(data, len, sample);
	if (i <= 0) {
		WARN_ON_ONCE(1);

		goto out;
	}

	len = i;
	ret = true;

	if (has_repeated_data(sample, len))
		goto out;

	bkt = kcalloc(bkt_size, sizeof(*bkt), GFP_KERNEL);
	if (!bkt) {
		WARN_ON_ONCE(1);
		ret = false;

		goto out;
	}

	for (i = 0; i < len; i++)
		bkt[sample[i]].count++;

	if (is_mostly_ascii(bkt))
		goto out;

	/* Sort in descending order */
	sort(bkt, bkt_size, sizeof(*bkt), cmp_bkt, NULL);

	i = calc_byte_distribution(bkt, len);
	if (i != BYTE_DIST_MAYBE) {
		ret = !!i;

		goto out;
	}

	ret = has_low_entropy(bkt, len);
out:
	kvfree(sample);
	kfree(bkt);

	return ret;
}

bool should_compress(const struct cifs_tcon *tcon, const struct smb_rqst *rq)
{
	const struct smb2_hdr *shdr = rq->rq_iov->iov_base;

	if (unlikely(!tcon || !tcon->ses || !tcon->ses->server))
		return false;

	if (!tcon->ses->server->compression.enabled)
		return false;

	if (!(tcon->share_flags & SMB2_SHAREFLAG_COMPRESS_DATA))
		return false;

	if (shdr->Command == SMB2_WRITE) {
		const struct smb2_write_req *wreq = rq->rq_iov->iov_base;

		if (le32_to_cpu(wreq->Length) < SMB_COMPRESS_MIN_LEN)
			return false;

		return is_compressible(&rq->rq_iter);
	}

	return (shdr->Command == SMB2_READ);
}

int smb_compress(struct TCP_Server_Info *server, struct smb_rqst *rq, compress_send_fn send_fn)
{
	struct iov_iter iter;
	u32 slen, dlen;
	void *src, *dst = NULL;
	int ret;

	if (!server || !rq || !rq->rq_iov || !rq->rq_iov->iov_base)
		return -EINVAL;

	if (rq->rq_iov->iov_len != sizeof(struct smb2_write_req))
		return -EINVAL;

	slen = iov_iter_count(&rq->rq_iter);
	src = kvzalloc(slen, GFP_KERNEL);
	if (!src) {
		ret = -ENOMEM;
		goto err_free;
	}

	/* Keep the original iter intact. */
	iter = rq->rq_iter;

	if (!copy_from_iter_full(src, slen, &iter)) {
		ret = -EIO;
		goto err_free;
	}

	/*
	 * This is just overprovisioning, as the algorithm will error out if @dst reaches 7/8
	 * of @slen.
	 */
	dlen = slen;
	dst = kvzalloc(dlen, GFP_KERNEL);
	if (!dst) {
		ret = -ENOMEM;
		goto err_free;
	}

	ret = lz77_compress(src, slen, dst, &dlen);
	if (!ret) {
		struct smb2_compression_hdr hdr = { 0 };
		struct smb_rqst comp_rq = { .rq_nvec = 3, };
		struct kvec iov[3];

		hdr.ProtocolId = SMB2_COMPRESSION_TRANSFORM_ID;
		hdr.OriginalCompressedSegmentSize = cpu_to_le32(slen);
		hdr.CompressionAlgorithm = SMB3_COMPRESS_LZ77;
		hdr.Flags = SMB2_COMPRESSION_FLAG_NONE;
		hdr.Offset = cpu_to_le32(rq->rq_iov[0].iov_len);

		iov[0].iov_base = &hdr;
		iov[0].iov_len = sizeof(hdr);
		iov[1] = rq->rq_iov[0];
		iov[2].iov_base = dst;
		iov[2].iov_len = dlen;

		comp_rq.rq_iov = iov;

		ret = send_fn(server, 1, &comp_rq);
	} else if (ret == -EMSGSIZE || dlen >= slen) {
		ret = send_fn(server, 1, rq);
	}
err_free:
	kvfree(dst);
	kvfree(src);

	return ret;
}