xref: /linux/fs/smb/client/compress.c (revision bba2c3615bd6cfee7456d1130f2e6b01b3f4e9ba)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright (C) 2024, SUSE LLC
4  *
5  * Authors: Enzo Matsumiya <ematsumiya@suse.de>
6  *
7  * This file implements I/O compression support for SMB2 messages (SMB 3.1.1 only).
8  * See compress/ for implementation details of each algorithm.
9  *
10  * References:
11  * MS-SMB2 "3.1.4.4 Compressing the Message"
12  * MS-SMB2 "3.1.5.3 Decompressing the Chained Message"
13  * MS-XCA - for details of the supported algorithms
14  */
15 #include <linux/slab.h>
16 #include <linux/kernel.h>
17 #include <linux/uio.h>
18 #include <linux/sort.h>
19 
20 #include "cifsglob.h"
21 #include "../common/smb2pdu.h"
22 #include "cifsproto.h"
23 #include "smb2proto.h"
24 
25 #include "../common/compress/lz77.h"
26 #include "compress.h"
27 
28 /*
29  * The heuristic_*() functions below try to determine data compressibility.
30  *
31  * Derived from fs/btrfs/compression.c, changing coding style, some parameters, and removing
32  * unused parts.
33  *
34  * Read that file for better and more detailed explanation of the calculations.
35  *
36  * The algorithms are run on a collected sample of the input (uncompressed) data.
37  * The sample is formed of 2K reads in PAGE_SIZE intervals, with a maximum size of 4M.
38  *
39  * Parsing the sample goes from "low-hanging fruits" (fastest algorithms, likely compressible)
40  * to "need more analysis" (likely uncompressible).
41  */
42 
/*
 * struct bucket - Occurrence counter for one byte value of the sample.
 * @count:	Number of times the byte value was seen in the sample.
 *
 * The heuristics use an array of 256 of these, initially indexed by byte value.
 */
struct bucket {
	unsigned int count;
};
46 
/*
 * pow4() - Raise @n to the 4th power, saturating instead of wrapping around.
 * @n:	Base value.
 *
 * The result is fed to ilog2() by the entropy heuristic.  A wrapped product could be 0
 * (for which ilog2() is undefined) or an arbitrary small number, so clamp the result when
 * it would not fit in a size_t.
 *
 * Return: n^4, or SIZE_MAX if n^4 overflows a size_t.
 */
static inline size_t pow4(size_t n)
{
	size_t sq;

	/* n * n would overflow. */
	if (n && n > SIZE_MAX / n)
		return SIZE_MAX;

	sq = n * n;

	/* (n^2) * (n^2) would overflow. */
	if (sq && sq > SIZE_MAX / sq)
		return SIZE_MAX;

	return sq * sq;
}
51 
52 /*
53  * has_low_entropy() - Compute Shannon entropy of the sampled data.
54  * @bkt:	Bytes counts of the sample.
55  * @slen:	Size of the sample.
56  *
57  * Return: true if the level (percentage of number of bits that would be required to
58  *	   compress the data) is below the minimum threshold.
59  *
60  * Note:
61  * There _is_ an entropy level here that's > 65 (minimum threshold) that would indicate a
62  * possibility of compression, but compressing, or even further analysing, it would waste so much
63  * resources that it's simply not worth it.
64  *
65  * Also Shannon entropy is the last computed heuristic; if we got this far and ended up
66  * with uncertainty, just stay on the safe side and call it uncompressible.
67  */
68 static bool has_low_entropy(struct bucket *bkt, size_t slen)
69 {
70 	const size_t threshold = 65, max_entropy = 8 * ilog2(16);
71 	size_t i, p, p2, len, sum = 0;
72 
73 	len = ilog2(pow4(slen));
74 
75 	for (i = 0; i < 256 && bkt[i].count > 0; i++) {
76 		p = bkt[i].count;
77 		p2 = ilog2(pow4(p));
78 		sum += p * (len - p2);
79 	}
80 
81 	sum /= slen;
82 
83 	return ((sum * 100 / max_entropy) <= threshold);
84 }
85 
86 #define BYTE_DIST_BAD		0
87 #define BYTE_DIST_GOOD		1
88 #define BYTE_DIST_MAYBE		2
89 /*
90  * calc_byte_distribution() - Compute byte distribution on the sampled data.
91  * @bkt:	Byte counts of the sample.
92  * @slen:	Size of the sample.
93  *
94  * Return:
95  * BYTE_DIST_BAD:	A "hard no" for compression -- a computed uniform distribution of
96  *			the bytes (e.g. random or encrypted data).
97  * BYTE_DIST_GOOD:	High probability (normal (Gaussian) distribution) of the data being
98  *			compressible.
99  * BYTE_DIST_MAYBE:	When computed byte distribution resulted in "low > n < high"
100  *			grounds.  has_low_entropy() should be used for a final decision.
101  */
102 static int calc_byte_distribution(struct bucket *bkt, size_t slen)
103 {
104 	const size_t low = 64, high = 200, threshold = slen * 90 / 100;
105 	size_t sum = 0;
106 	int i;
107 
108 	for (i = 0; i < low; i++)
109 		sum += bkt[i].count;
110 
111 	if (sum > threshold)
112 		return BYTE_DIST_BAD;
113 
114 	for (; i < high && bkt[i].count > 0; i++) {
115 		sum += bkt[i].count;
116 		if (sum > threshold)
117 			break;
118 	}
119 
120 	if (i <= low)
121 		return BYTE_DIST_GOOD;
122 
123 	if (i >= high)
124 		return BYTE_DIST_BAD;
125 
126 	return BYTE_DIST_MAYBE;
127 }
128 
129 static bool is_mostly_ascii(const struct bucket *bkt)
130 {
131 	size_t count = 0;
132 	int i;
133 
134 	for (i = 0; i < 256; i++)
135 		if (bkt[i].count > 0)
136 			/* Too many non-ASCII (0-63) bytes. */
137 			if (++count > 64)
138 				return false;
139 
140 	return true;
141 }
142 
143 static bool has_repeated_data(const u8 *sample, size_t len)
144 {
145 	size_t s = len / 2;
146 
147 	return (!memcmp(&sample[0], &sample[s], s));
148 }
149 
150 static int cmp_bkt(const void *_a, const void *_b)
151 {
152 	const struct bucket *a = _a, *b = _b;
153 
154 	/* Reverse sort. */
155 	if (a->count > b->count)
156 		return -1;
157 
158 	return 1;
159 }
160 
161 /*
162  * Collect some 2K samples with 2K gaps between.
163  */
164 static int collect_sample(const struct iov_iter *source, ssize_t max, u8 *sample)
165 {
166 	struct iov_iter iter = *source;
167 	size_t s = 0;
168 
169 	while (iov_iter_count(&iter) >= SZ_2K) {
170 		size_t part = umin(umin(iov_iter_count(&iter), SZ_2K), max);
171 		size_t n;
172 
173 		n = copy_from_iter(sample + s, part, &iter);
174 		if (n != part)
175 			return -EFAULT;
176 
177 		s += n;
178 		max -= n;
179 
180 		if (iov_iter_count(&iter) < PAGE_SIZE - SZ_2K)
181 			break;
182 
183 		iov_iter_advance(&iter, SZ_2K);
184 	}
185 
186 	return s;
187 }
188 
189 /*
190  * is_compressible() - Determines if a chunk of data is compressible.
191  * @data: Iterator containing uncompressed data.
192  *
193  * Return: true if @data is compressible, false otherwise.
194  *
195  * Tests shows that this function is quite reliable in predicting data compressibility,
196  * matching close to 1:1 with the behaviour of LZ77 compression success and failures.
197  */
198 static bool is_compressible(const struct iov_iter *data)
199 {
200 	const size_t read_size = SZ_2K, bkt_size = 256, max = SZ_4M;
201 	struct bucket *bkt = NULL;
202 	size_t len;
203 	u8 *sample;
204 	bool ret = false;
205 	int i;
206 
207 	/* Preventive double check -- already checked in should_compress(). */
208 	len = iov_iter_count(data);
209 	if (unlikely(len < read_size))
210 		return ret;
211 
212 	if (len - read_size > max)
213 		len = max;
214 
215 	sample = kvzalloc(len, GFP_KERNEL);
216 	if (!sample) {
217 		WARN_ON_ONCE(1);
218 
219 		return ret;
220 	}
221 
222 	/* Sample 2K bytes per page of the uncompressed data. */
223 	i = collect_sample(data, len, sample);
224 	if (i <= 0) {
225 		WARN_ON_ONCE(1);
226 
227 		goto out;
228 	}
229 
230 	len = i;
231 	ret = true;
232 
233 	if (has_repeated_data(sample, len))
234 		goto out;
235 
236 	bkt = kzalloc_objs(*bkt, bkt_size);
237 	if (!bkt) {
238 		WARN_ON_ONCE(1);
239 		ret = false;
240 
241 		goto out;
242 	}
243 
244 	for (i = 0; i < len; i++)
245 		bkt[sample[i]].count++;
246 
247 	if (is_mostly_ascii(bkt))
248 		goto out;
249 
250 	/* Sort in descending order */
251 	sort(bkt, bkt_size, sizeof(*bkt), cmp_bkt, NULL);
252 
253 	i = calc_byte_distribution(bkt, len);
254 	if (i != BYTE_DIST_MAYBE) {
255 		ret = !!i;
256 
257 		goto out;
258 	}
259 
260 	ret = has_low_entropy(bkt, len);
261 out:
262 	kvfree(sample);
263 	kfree(bkt);
264 
265 	return ret;
266 }
267 
268 /*
269  * should_compress() - Determines if a request (write) or the response to a
270  *		       request (read) should be compressed.
271  * @tcon: tcon of the request is being sent to
272  * @rqst: request to evaluate
273  *
274  * Return: true iff:
275  * - compression was successfully negotiated with server
276  * - server has enabled compression for the share
277  * - it's a read or write request
278  * - (write only) request length is >= SMB_COMPRESS_MIN_LEN
279  * - (write only) is_compressible() returns 1
280  *
281  * Return false otherwise.
282  */
283 bool should_compress(const struct cifs_tcon *tcon, const struct smb_rqst *rq)
284 {
285 	const struct smb2_hdr *shdr = rq->rq_iov->iov_base;
286 
287 	if (unlikely(!tcon || !tcon->ses || !tcon->ses->server))
288 		return false;
289 
290 	if (!tcon->ses->server->compression.enabled)
291 		return false;
292 
293 	if (!(tcon->share_flags & SMB2_SHAREFLAG_COMPRESS_DATA))
294 		return false;
295 
296 	if (shdr->Command == SMB2_WRITE) {
297 		const struct smb2_write_req *wreq = rq->rq_iov->iov_base;
298 
299 		if (le32_to_cpu(wreq->Length) < SMB_COMPRESS_MIN_LEN)
300 			return false;
301 
302 		return is_compressible(&rq->rq_iter);
303 	}
304 
305 	return (shdr->Command == SMB2_READ);
306 }
307 
308 int smb_compress(struct TCP_Server_Info *server, struct smb_rqst *rq, compress_send_fn send_fn)
309 {
310 	struct iov_iter iter;
311 	u32 slen, dlen;
312 	void *src, *dst = NULL;
313 	int ret;
314 
315 	if (!server || !rq || !rq->rq_iov || !rq->rq_iov->iov_base)
316 		return -EINVAL;
317 
318 	if (rq->rq_iov->iov_len != sizeof(struct smb2_write_req))
319 		return -EINVAL;
320 
321 	slen = iov_iter_count(&rq->rq_iter);
322 	src = kvzalloc(slen, GFP_KERNEL);
323 	if (!src) {
324 		ret = -ENOMEM;
325 		goto err_free;
326 	}
327 
328 	/* Keep the original iter intact. */
329 	iter = rq->rq_iter;
330 
331 	if (!copy_from_iter_full(src, slen, &iter)) {
332 		ret = smb_EIO(smb_eio_trace_compress_copy);
333 		goto err_free;
334 	}
335 
336 	dlen = smb_lz77_compressed_alloc_size(slen);
337 	dst = kvzalloc(dlen, GFP_KERNEL);
338 	if (!dst) {
339 		ret = -ENOMEM;
340 		goto err_free;
341 	}
342 
343 	ret = smb_lz77_compress(src, slen, dst, &dlen);
344 	if (!ret) {
345 		struct smb2_compression_hdr hdr = { 0 };
346 		struct smb_rqst comp_rq = { .rq_nvec = 3, };
347 		struct kvec iov[3];
348 
349 		hdr.ProtocolId = SMB2_COMPRESSION_TRANSFORM_ID;
350 		hdr.OriginalCompressedSegmentSize = cpu_to_le32(slen);
351 		hdr.CompressionAlgorithm = SMB3_COMPRESS_LZ77;
352 		hdr.Flags = SMB2_COMPRESSION_FLAG_NONE;
353 		hdr.Offset = cpu_to_le32(rq->rq_iov[0].iov_len);
354 
355 		iov[0].iov_base = &hdr;
356 		iov[0].iov_len = sizeof(hdr);
357 		iov[1] = rq->rq_iov[0];
358 		iov[2].iov_base = dst;
359 		iov[2].iov_len = dlen;
360 
361 		comp_rq.rq_iov = iov;
362 
363 		ret = send_fn(server, 1, &comp_rq);
364 	} else if (ret == -EMSGSIZE || dlen >= slen) {
365 		ret = send_fn(server, 1, rq);
366 	}
367 err_free:
368 	kvfree(dst);
369 	kvfree(src);
370 
371 	return ret;
372 }
373