1 // SPDX-License-Identifier: GPL-2.0-only 2 /* 3 * Copyright (C) 2024, SUSE LLC 4 * 5 * Authors: Enzo Matsumiya <ematsumiya@suse.de> 6 * 7 * Implementation of the LZ77 "plain" compression algorithm, as per MS-XCA spec. 8 */ 9 #include <linux/slab.h> 10 #include <linux/sizes.h> 11 #include <linux/count_zeros.h> 12 #include <linux/unaligned.h> 13 14 #include "lz77.h" 15 16 /* 17 * Compression parameters. 18 */ 19 #define LZ77_MATCH_MIN_LEN 4 20 #define LZ77_MATCH_MIN_DIST 1 21 #define LZ77_MATCH_MAX_DIST SZ_1K 22 #define LZ77_HASH_LOG 15 23 #define LZ77_HASH_SIZE (1 << LZ77_HASH_LOG) 24 #define LZ77_STEP_SIZE sizeof(u64) 25 26 static __always_inline u8 lz77_read8(const u8 *ptr) 27 { 28 return get_unaligned(ptr); 29 } 30 31 static __always_inline u64 lz77_read64(const u64 *ptr) 32 { 33 return get_unaligned(ptr); 34 } 35 36 static __always_inline void lz77_write8(u8 *ptr, u8 v) 37 { 38 put_unaligned(v, ptr); 39 } 40 41 static __always_inline void lz77_write16(u16 *ptr, u16 v) 42 { 43 put_unaligned_le16(v, ptr); 44 } 45 46 static __always_inline void lz77_write32(u32 *ptr, u32 v) 47 { 48 put_unaligned_le32(v, ptr); 49 } 50 51 static __always_inline u32 lz77_match_len(const void *wnd, const void *cur, const void *end) 52 { 53 const void *start = cur; 54 u64 diff; 55 56 /* Safe for a do/while because otherwise we wouldn't reach here from the main loop. */ 57 do { 58 diff = lz77_read64(cur) ^ lz77_read64(wnd); 59 if (!diff) { 60 cur += LZ77_STEP_SIZE; 61 wnd += LZ77_STEP_SIZE; 62 63 continue; 64 } 65 66 /* This computes the number of common bytes in @diff. */ 67 cur += count_trailing_zeros(diff) >> 3; 68 69 return (cur - start); 70 } while (likely(cur + LZ77_STEP_SIZE < end)); 71 72 while (cur < end && lz77_read8(cur++) == lz77_read8(wnd++)) 73 ; 74 75 return (cur - start); 76 } 77 78 static __always_inline void *lz77_write_match(void *dst, void **nib, u32 dist, u32 len) 79 { 80 len -= 3; 81 dist--; 82 dist <<= 3; 83 84 if (len < 7) { 85 lz77_write16(dst, dist + len); 86 87 return dst + 2; 88 } 89 90 dist |= 7; 91 lz77_write16(dst, dist); 92 dst += 2; 93 len -= 7; 94 95 if (!*nib) { 96 lz77_write8(dst, umin(len, 15)); 97 *nib = dst; 98 dst++; 99 } else { 100 u8 *b = *nib; 101 102 lz77_write8(b, *b | umin(len, 15) << 4); 103 *nib = NULL; 104 } 105 106 if (len < 15) 107 return dst; 108 109 len -= 15; 110 if (len < 255) { 111 lz77_write8(dst, len); 112 113 return dst + 1; 114 } 115 116 lz77_write8(dst, 0xff); 117 dst++; 118 len += 7 + 15; 119 if (len <= 0xffff) { 120 lz77_write16(dst, len); 121 122 return dst + 2; 123 } 124 125 lz77_write16(dst, 0); 126 dst += 2; 127 lz77_write32(dst, len); 128 129 return dst + 4; 130 } 131 132 noinline int lz77_compress(const void *src, u32 slen, void *dst, u32 *dlen) 133 { 134 const void *srcp, *end; 135 void *dstp, *nib, *flag_pos; 136 u32 flag_count = 0; 137 long flag = 0; 138 u64 *htable; 139 140 srcp = src; 141 end = src + slen; 142 dstp = dst; 143 nib = NULL; 144 flag_pos = dstp; 145 dstp += 4; 146 147 htable = kvcalloc(LZ77_HASH_SIZE, sizeof(*htable), GFP_KERNEL); 148 if (!htable) 149 return -ENOMEM; 150 151 /* Main loop. */ 152 do { 153 u32 dist, len = 0; 154 const void *wnd; 155 u64 hash; 156 157 hash = ((lz77_read64(srcp) << 24) * 889523592379ULL) >> (64 - LZ77_HASH_LOG); 158 wnd = src + htable[hash]; 159 htable[hash] = srcp - src; 160 dist = srcp - wnd; 161 162 if (dist && dist < LZ77_MATCH_MAX_DIST) 163 len = lz77_match_len(wnd, srcp, end); 164 165 if (len < LZ77_MATCH_MIN_LEN) { 166 lz77_write8(dstp, lz77_read8(srcp)); 167 168 dstp++; 169 srcp++; 170 171 flag <<= 1; 172 flag_count++; 173 if (flag_count == 32) { 174 lz77_write32(flag_pos, flag); 175 flag_count = 0; 176 flag_pos = dstp; 177 dstp += 4; 178 } 179 180 continue; 181 } 182 183 /* 184 * Bail out if @dstp reached >= 7/8 of @slen -- already compressed badly, not worth 185 * going further. 186 */ 187 if (unlikely(dstp - dst >= slen - (slen >> 3))) { 188 *dlen = slen; 189 goto out; 190 } 191 192 dstp = lz77_write_match(dstp, &nib, dist, len); 193 srcp += len; 194 195 flag = (flag << 1) | 1; 196 flag_count++; 197 if (flag_count == 32) { 198 lz77_write32(flag_pos, flag); 199 flag_count = 0; 200 flag_pos = dstp; 201 dstp += 4; 202 } 203 } while (likely(srcp + LZ77_STEP_SIZE < end)); 204 205 while (srcp < end) { 206 u32 c = umin(end - srcp, 32 - flag_count); 207 208 memcpy(dstp, srcp, c); 209 210 dstp += c; 211 srcp += c; 212 213 flag <<= c; 214 flag_count += c; 215 if (flag_count == 32) { 216 lz77_write32(flag_pos, flag); 217 flag_count = 0; 218 flag_pos = dstp; 219 dstp += 4; 220 } 221 } 222 223 flag <<= (32 - flag_count); 224 flag |= (1 << (32 - flag_count)) - 1; 225 lz77_write32(flag_pos, flag); 226 227 *dlen = dstp - dst; 228 out: 229 kvfree(htable); 230 231 if (*dlen < slen) 232 return 0; 233 234 return -EMSGSIZE; 235 } 236