xref: /linux/fs/smb/client/compress.c (revision 13d68a16430312fc21990f48326366eb73891202)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright (C) 2024, SUSE LLC
4  *
5  * Authors: Enzo Matsumiya <ematsumiya@suse.de>
6  *
7  * This file implements I/O compression support for SMB2 messages (SMB 3.1.1 only).
8  * See compress/ for implementation details of each algorithm.
9  *
10  * References:
11  * MS-SMB2 "3.1.4.4 Compressing the Message"
12  * MS-SMB2 "3.1.5.3 Decompressing the Chained Message"
13  * MS-XCA - for details of the supported algorithms
14  */
15 #include <linux/slab.h>
16 #include <linux/kernel.h>
17 #include <linux/uio.h>
18 #include <linux/sort.h>
19 
20 #include "cifsglob.h"
21 #include "../common/smb2pdu.h"
22 #include "cifsproto.h"
23 #include "smb2proto.h"
24 
25 #include "compress/lz77.h"
26 #include "compress.h"
27 
28 /*
29  * The heuristic_*() functions below try to determine data compressibility.
30  *
31  * Derived from fs/btrfs/compression.c, changing coding style, some parameters, and removing
32  * unused parts.
33  *
34  * Read that file for better and more detailed explanation of the calculations.
35  *
 * The algorithms are run on a collected sample of the input (uncompressed) data.
 * The sample is formed of 2K reads at PAGE_SIZE intervals, with a maximum size of 4M.
 *
 * Parsing the sample goes from "low-hanging fruit" (fastest algorithms, likely compressible)
 * to "needs more analysis" (likely uncompressible).
41  */
42 
43 struct bucket {
44 	unsigned int count;
45 };
46 
47 /**
48  * has_low_entropy() - Compute Shannon entropy of the sampled data.
49  * @bkt:	Bytes counts of the sample.
50  * @slen:	Size of the sample.
51  *
52  * Return: true if the level (percentage of number of bits that would be required to
53  *	   compress the data) is below the minimum threshold.
54  *
55  * Note:
56  * There _is_ an entropy level here that's > 65 (minimum threshold) that would indicate a
57  * possibility of compression, but compressing, or even further analysing, it would waste so much
58  * resources that it's simply not worth it.
59  *
60  * Also Shannon entropy is the last computed heuristic; if we got this far and ended up
61  * with uncertainty, just stay on the safe side and call it uncompressible.
62  */
63 static bool has_low_entropy(struct bucket *bkt, size_t slen)
64 {
65 	const size_t threshold = 65, max_entropy = 8 * ilog2(16);
66 	size_t i, p, p2, len, sum = 0;
67 
68 #define pow4(n) (n * n * n * n)
69 	len = ilog2(pow4(slen));
70 
71 	for (i = 0; i < 256 && bkt[i].count > 0; i++) {
72 		p = bkt[i].count;
73 		p2 = ilog2(pow4(p));
74 		sum += p * (len - p2);
75 	}
76 
77 	sum /= slen;
78 
79 	return ((sum * 100 / max_entropy) <= threshold);
80 }
81 
82 #define BYTE_DIST_BAD		0
83 #define BYTE_DIST_GOOD		1
84 #define BYTE_DIST_MAYBE		2
85 /**
86  * calc_byte_distribution() - Compute byte distribution on the sampled data.
87  * @bkt:	Byte counts of the sample.
88  * @slen:	Size of the sample.
89  *
90  * Return:
91  * BYTE_DIST_BAD:	A "hard no" for compression -- a computed uniform distribution of
92  *			the bytes (e.g. random or encrypted data).
93  * BYTE_DIST_GOOD:	High probability (normal (Gaussian) distribution) of the data being
94  *			compressible.
95  * BYTE_DIST_MAYBE:	When computed byte distribution resulted in "low > n < high"
96  *			grounds.  has_low_entropy() should be used for a final decision.
97  */
98 static int calc_byte_distribution(struct bucket *bkt, size_t slen)
99 {
100 	const size_t low = 64, high = 200, threshold = slen * 90 / 100;
101 	size_t sum = 0;
102 	int i;
103 
104 	for (i = 0; i < low; i++)
105 		sum += bkt[i].count;
106 
107 	if (sum > threshold)
108 		return BYTE_DIST_BAD;
109 
110 	for (; i < high && bkt[i].count > 0; i++) {
111 		sum += bkt[i].count;
112 		if (sum > threshold)
113 			break;
114 	}
115 
116 	if (i <= low)
117 		return BYTE_DIST_GOOD;
118 
119 	if (i >= high)
120 		return BYTE_DIST_BAD;
121 
122 	return BYTE_DIST_MAYBE;
123 }
124 
125 static bool is_mostly_ascii(const struct bucket *bkt)
126 {
127 	size_t count = 0;
128 	int i;
129 
130 	for (i = 0; i < 256; i++)
131 		if (bkt[i].count > 0)
132 			/* Too many non-ASCII (0-63) bytes. */
133 			if (++count > 64)
134 				return false;
135 
136 	return true;
137 }
138 
139 static bool has_repeated_data(const u8 *sample, size_t len)
140 {
141 	size_t s = len / 2;
142 
143 	return (!memcmp(&sample[0], &sample[s], s));
144 }
145 
146 static int cmp_bkt(const void *_a, const void *_b)
147 {
148 	const struct bucket *a = _a, *b = _b;
149 
150 	/* Reverse sort. */
151 	if (a->count > b->count)
152 		return -1;
153 
154 	return 1;
155 }
156 
157 /*
158  * TODO:
159  * Support other iter types, if required.
160  * Only ITER_XARRAY is supported for now.
161  */
162 static int collect_sample(const struct iov_iter *iter, ssize_t max, u8 *sample)
163 {
164 	struct folio *folios[16], *folio;
165 	unsigned int nr, i, j, npages;
166 	loff_t start = iter->xarray_start + iter->iov_offset;
167 	pgoff_t last, index = start / PAGE_SIZE;
168 	size_t len, off, foff;
169 	ssize_t ret = 0;
170 	void *p;
171 	int s = 0;
172 
173 	last = (start + max - 1) / PAGE_SIZE;
174 	do {
175 		nr = xa_extract(iter->xarray, (void **)folios, index, last, ARRAY_SIZE(folios),
176 				XA_PRESENT);
177 		if (nr == 0)
178 			return -EIO;
179 
180 		for (i = 0; i < nr; i++) {
181 			folio = folios[i];
182 			npages = folio_nr_pages(folio);
183 			foff = start - folio_pos(folio);
184 			off = foff % PAGE_SIZE;
185 
186 			for (j = foff / PAGE_SIZE; j < npages; j++) {
187 				size_t len2;
188 
189 				len = min_t(size_t, max, PAGE_SIZE - off);
190 				len2 = min_t(size_t, len, SZ_2K);
191 
192 				p = kmap_local_page(folio_page(folio, j));
193 				memcpy(&sample[s], p, len2);
194 				kunmap_local(p);
195 
196 				if (ret < 0)
197 					return ret;
198 
199 				s += len2;
200 
201 				if (len2 < SZ_2K || s >= max - SZ_2K)
202 					return s;
203 
204 				max -= len;
205 				if (max <= 0)
206 					return s;
207 
208 				start += len;
209 				off = 0;
210 				index++;
211 			}
212 		}
213 	} while (nr == ARRAY_SIZE(folios));
214 
215 	return s;
216 }
217 
218 /**
219  * is_compressible() - Determines if a chunk of data is compressible.
220  * @data: Iterator containing uncompressed data.
221  *
222  * Return: true if @data is compressible, false otherwise.
223  *
224  * Tests shows that this function is quite reliable in predicting data compressibility,
225  * matching close to 1:1 with the behaviour of LZ77 compression success and failures.
226  */
227 static bool is_compressible(const struct iov_iter *data)
228 {
229 	const size_t read_size = SZ_2K, bkt_size = 256, max = SZ_4M;
230 	struct bucket *bkt = NULL;
231 	size_t len;
232 	u8 *sample;
233 	bool ret = false;
234 	int i;
235 
236 	/* Preventive double check -- already checked in should_compress(). */
237 	len = iov_iter_count(data);
238 	if (unlikely(len < read_size))
239 		return ret;
240 
241 	if (len - read_size > max)
242 		len = max;
243 
244 	sample = kvzalloc(len, GFP_KERNEL);
245 	if (!sample) {
246 		WARN_ON_ONCE(1);
247 
248 		return ret;
249 	}
250 
251 	/* Sample 2K bytes per page of the uncompressed data. */
252 	i = collect_sample(data, len, sample);
253 	if (i <= 0) {
254 		WARN_ON_ONCE(1);
255 
256 		goto out;
257 	}
258 
259 	len = i;
260 	ret = true;
261 
262 	if (has_repeated_data(sample, len))
263 		goto out;
264 
265 	bkt = kcalloc(bkt_size, sizeof(*bkt), GFP_KERNEL);
266 	if (!bkt) {
267 		WARN_ON_ONCE(1);
268 		ret = false;
269 
270 		goto out;
271 	}
272 
273 	for (i = 0; i < len; i++)
274 		bkt[sample[i]].count++;
275 
276 	if (is_mostly_ascii(bkt))
277 		goto out;
278 
279 	/* Sort in descending order */
280 	sort(bkt, bkt_size, sizeof(*bkt), cmp_bkt, NULL);
281 
282 	i = calc_byte_distribution(bkt, len);
283 	if (i != BYTE_DIST_MAYBE) {
284 		ret = !!i;
285 
286 		goto out;
287 	}
288 
289 	ret = has_low_entropy(bkt, len);
290 out:
291 	kvfree(sample);
292 	kfree(bkt);
293 
294 	return ret;
295 }
296 
297 bool should_compress(const struct cifs_tcon *tcon, const struct smb_rqst *rq)
298 {
299 	const struct smb2_hdr *shdr = rq->rq_iov->iov_base;
300 
301 	if (unlikely(!tcon || !tcon->ses || !tcon->ses->server))
302 		return false;
303 
304 	if (!tcon->ses->server->compression.enabled)
305 		return false;
306 
307 	if (!(tcon->share_flags & SMB2_SHAREFLAG_COMPRESS_DATA))
308 		return false;
309 
310 	if (shdr->Command == SMB2_WRITE) {
311 		const struct smb2_write_req *wreq = rq->rq_iov->iov_base;
312 
313 		if (le32_to_cpu(wreq->Length) < SMB_COMPRESS_MIN_LEN)
314 			return false;
315 
316 		return is_compressible(&rq->rq_iter);
317 	}
318 
319 	return (shdr->Command == SMB2_READ);
320 }
321 
322 int smb_compress(struct TCP_Server_Info *server, struct smb_rqst *rq, compress_send_fn send_fn)
323 {
324 	struct iov_iter iter;
325 	u32 slen, dlen;
326 	void *src, *dst = NULL;
327 	int ret;
328 
329 	if (!server || !rq || !rq->rq_iov || !rq->rq_iov->iov_base)
330 		return -EINVAL;
331 
332 	if (rq->rq_iov->iov_len != sizeof(struct smb2_write_req))
333 		return -EINVAL;
334 
335 	slen = iov_iter_count(&rq->rq_iter);
336 	src = kvzalloc(slen, GFP_KERNEL);
337 	if (!src) {
338 		ret = -ENOMEM;
339 		goto err_free;
340 	}
341 
342 	/* Keep the original iter intact. */
343 	iter = rq->rq_iter;
344 
345 	if (!copy_from_iter_full(src, slen, &iter)) {
346 		ret = -EIO;
347 		goto err_free;
348 	}
349 
350 	/*
351 	 * This is just overprovisioning, as the algorithm will error out if @dst reaches 7/8
352 	 * of @slen.
353 	 */
354 	dlen = slen;
355 	dst = kvzalloc(dlen, GFP_KERNEL);
356 	if (!dst) {
357 		ret = -ENOMEM;
358 		goto err_free;
359 	}
360 
361 	ret = lz77_compress(src, slen, dst, &dlen);
362 	if (!ret) {
363 		struct smb2_compression_hdr hdr = { 0 };
364 		struct smb_rqst comp_rq = { .rq_nvec = 3, };
365 		struct kvec iov[3];
366 
367 		hdr.ProtocolId = SMB2_COMPRESSION_TRANSFORM_ID;
368 		hdr.OriginalCompressedSegmentSize = cpu_to_le32(slen);
369 		hdr.CompressionAlgorithm = SMB3_COMPRESS_LZ77;
370 		hdr.Flags = SMB2_COMPRESSION_FLAG_NONE;
371 		hdr.Offset = cpu_to_le32(rq->rq_iov[0].iov_len);
372 
373 		iov[0].iov_base = &hdr;
374 		iov[0].iov_len = sizeof(hdr);
375 		iov[1] = rq->rq_iov[0];
376 		iov[2].iov_base = dst;
377 		iov[2].iov_len = dlen;
378 
379 		comp_rq.rq_iov = iov;
380 
381 		ret = send_fn(server, 1, &comp_rq);
382 	} else if (ret == -EMSGSIZE || dlen >= slen) {
383 		ret = send_fn(server, 1, rq);
384 	}
385 err_free:
386 	kvfree(dst);
387 	kvfree(src);
388 
389 	return ret;
390 }
391