xref: /linux/fs/smb/client/compress/lz77.c (revision add452d09a38c7a7c44aea55c1015392cebf9fa7)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright (C) 2024, SUSE LLC
4  *
5  * Authors: Enzo Matsumiya <ematsumiya@suse.de>
6  *
7  * Implementation of the LZ77 "plain" compression algorithm, as per MS-XCA spec.
8  */
9 #include <linux/slab.h>
10 #include <linux/sizes.h>
11 #include <linux/count_zeros.h>
12 #include <linux/unaligned.h>
13 
14 #include "lz77.h"
15 
/*
 * Compression parameters.
 *
 * LZ77_MATCH_MIN_LEN:  candidates shorter than this are emitted as a literal byte instead of a
 *                      match.
 * LZ77_MATCH_MIN_DIST: smallest match distance (not referenced by the code in this file; the
 *                      main loop only rejects a distance of 0).
 * LZ77_MATCH_MAX_DIST: candidates at this distance or farther are not considered.
 * LZ77_HASH_LOG:       number of hash bits kept when indexing the hash table.
 * LZ77_HASH_SIZE:      number of entries in the hash table.
 * LZ77_STEP_SIZE:      bytes loaded per step by the main loop and by the match length scan.
 */
#define LZ77_MATCH_MIN_LEN	4
#define LZ77_MATCH_MIN_DIST	1
#define LZ77_MATCH_MAX_DIST	SZ_1K
#define LZ77_HASH_LOG		15
#define LZ77_HASH_SIZE		(1 << LZ77_HASH_LOG)
#define LZ77_STEP_SIZE		sizeof(u64)
25 
26 static __always_inline u8 lz77_read8(const u8 *ptr)
27 {
28 	return get_unaligned(ptr);
29 }
30 
31 static __always_inline u64 lz77_read64(const u64 *ptr)
32 {
33 	return get_unaligned(ptr);
34 }
35 
36 static __always_inline void lz77_write8(u8 *ptr, u8 v)
37 {
38 	put_unaligned(v, ptr);
39 }
40 
41 static __always_inline void lz77_write16(u16 *ptr, u16 v)
42 {
43 	put_unaligned_le16(v, ptr);
44 }
45 
46 static __always_inline void lz77_write32(u32 *ptr, u32 v)
47 {
48 	put_unaligned_le32(v, ptr);
49 }
50 
51 static __always_inline u32 lz77_match_len(const void *wnd, const void *cur, const void *end)
52 {
53 	const void *start = cur;
54 	u64 diff;
55 
56 	/* Safe for a do/while because otherwise we wouldn't reach here from the main loop. */
57 	do {
58 		diff = lz77_read64(cur) ^ lz77_read64(wnd);
59 		if (!diff) {
60 			cur += LZ77_STEP_SIZE;
61 			wnd += LZ77_STEP_SIZE;
62 
63 			continue;
64 		}
65 
66 		/* This computes the number of common bytes in @diff. */
67 		cur += count_trailing_zeros(diff) >> 3;
68 
69 		return (cur - start);
70 	} while (likely(cur + LZ77_STEP_SIZE < end));
71 
72 	while (cur < end && lz77_read8(cur++) == lz77_read8(wnd++))
73 		;
74 
75 	return (cur - start);
76 }
77 
78 static __always_inline void *lz77_write_match(void *dst, void **nib, u32 dist, u32 len)
79 {
80 	len -= 3;
81 	dist--;
82 	dist <<= 3;
83 
84 	if (len < 7) {
85 		lz77_write16(dst, dist + len);
86 
87 		return dst + 2;
88 	}
89 
90 	dist |= 7;
91 	lz77_write16(dst, dist);
92 	dst += 2;
93 	len -= 7;
94 
95 	if (!*nib) {
96 		lz77_write8(dst, umin(len, 15));
97 		*nib = dst;
98 		dst++;
99 	} else {
100 		u8 *b = *nib;
101 
102 		lz77_write8(b, *b | umin(len, 15) << 4);
103 		*nib = NULL;
104 	}
105 
106 	if (len < 15)
107 		return dst;
108 
109 	len -= 15;
110 	if (len < 255) {
111 		lz77_write8(dst, len);
112 
113 		return dst + 1;
114 	}
115 
116 	lz77_write8(dst, 0xff);
117 	dst++;
118 	len += 7 + 15;
119 	if (len <= 0xffff) {
120 		lz77_write16(dst, len);
121 
122 		return dst + 2;
123 	}
124 
125 	lz77_write16(dst, 0);
126 	dst += 2;
127 	lz77_write32(dst, len);
128 
129 	return dst + 4;
130 }
131 
132 noinline int lz77_compress(const void *src, u32 slen, void *dst, u32 *dlen)
133 {
134 	const void *srcp, *end;
135 	void *dstp, *nib, *flag_pos;
136 	u32 flag_count = 0;
137 	long flag = 0;
138 	u64 *htable;
139 
140 	srcp = src;
141 	end = src + slen;
142 	dstp = dst;
143 	nib = NULL;
144 	flag_pos = dstp;
145 	dstp += 4;
146 
147 	htable = kvcalloc(LZ77_HASH_SIZE, sizeof(*htable), GFP_KERNEL);
148 	if (!htable)
149 		return -ENOMEM;
150 
151 	/* Main loop. */
152 	do {
153 		u32 dist, len = 0;
154 		const void *wnd;
155 		u64 hash;
156 
157 		hash = ((lz77_read64(srcp) << 24) * 889523592379ULL) >> (64 - LZ77_HASH_LOG);
158 		wnd = src + htable[hash];
159 		htable[hash] = srcp - src;
160 		dist = srcp - wnd;
161 
162 		if (dist && dist < LZ77_MATCH_MAX_DIST)
163 			len = lz77_match_len(wnd, srcp, end);
164 
165 		if (len < LZ77_MATCH_MIN_LEN) {
166 			lz77_write8(dstp, lz77_read8(srcp));
167 
168 			dstp++;
169 			srcp++;
170 
171 			flag <<= 1;
172 			flag_count++;
173 			if (flag_count == 32) {
174 				lz77_write32(flag_pos, flag);
175 				flag_count = 0;
176 				flag_pos = dstp;
177 				dstp += 4;
178 			}
179 
180 			continue;
181 		}
182 
183 		/*
184 		 * Bail out if @dstp reached >= 7/8 of @slen -- already compressed badly, not worth
185 		 * going further.
186 		 */
187 		if (unlikely(dstp - dst >= slen - (slen >> 3))) {
188 			*dlen = slen;
189 			goto out;
190 		}
191 
192 		dstp = lz77_write_match(dstp, &nib, dist, len);
193 		srcp += len;
194 
195 		flag = (flag << 1) | 1;
196 		flag_count++;
197 		if (flag_count == 32) {
198 			lz77_write32(flag_pos, flag);
199 			flag_count = 0;
200 			flag_pos = dstp;
201 			dstp += 4;
202 		}
203 	} while (likely(srcp + LZ77_STEP_SIZE < end));
204 
205 	while (srcp < end) {
206 		u32 c = umin(end - srcp, 32 - flag_count);
207 
208 		memcpy(dstp, srcp, c);
209 
210 		dstp += c;
211 		srcp += c;
212 
213 		flag <<= c;
214 		flag_count += c;
215 		if (flag_count == 32) {
216 			lz77_write32(flag_pos, flag);
217 			flag_count = 0;
218 			flag_pos = dstp;
219 			dstp += 4;
220 		}
221 	}
222 
223 	flag <<= (32 - flag_count);
224 	flag |= (1 << (32 - flag_count)) - 1;
225 	lz77_write32(flag_pos, flag);
226 
227 	*dlen = dstp - dst;
228 out:
229 	kvfree(htable);
230 
231 	if (*dlen < slen)
232 		return 0;
233 
234 	return -EMSGSIZE;
235 }
236