1 // SPDX-License-Identifier: GPL-2.0-only 2 // Copyright (C) 2019-2020 Arm Ltd. 3 4 #include <linux/compiler.h> 5 #include <linux/export.h> 6 #include <linux/kasan-checks.h> 7 #include <linux/kernel.h> 8 9 #include <net/checksum.h> 10 11 static u64 accumulate(u64 sum, u64 data) 12 { 13 sum += data; 14 if (sum < data) 15 sum += 1; 16 return sum; 17 } 18 19 /* 20 * We over-read the buffer and this makes KASAN unhappy. Instead, disable 21 * instrumentation and call kasan explicitly. 22 */ 23 unsigned int __no_sanitize_address do_csum(const unsigned char *buff, int len) 24 { 25 unsigned int offset, shift, sum; 26 const u64 *ptr; 27 u64 data, sum64 = 0; 28 29 if (unlikely(len <= 0)) 30 return 0; 31 32 offset = (unsigned long)buff & 7; 33 /* 34 * This is to all intents and purposes safe, since rounding down cannot 35 * result in a different page or cache line being accessed, and @buff 36 * should absolutely not be pointing to anything read-sensitive. We do, 37 * however, have to be careful not to piss off KASAN, which means using 38 * unchecked reads to accommodate the head and tail, for which we'll 39 * compensate with an explicit check up-front. 40 */ 41 kasan_check_read(buff, len); 42 ptr = (u64 *)(buff - offset); 43 len = len + offset - 8; 44 45 /* 46 * Head: zero out any excess leading bytes. Shifting back by the same 47 * amount should be at least as fast as any other way of handling the 48 * odd/even alignment, and means we can ignore it until the very end. 49 */ 50 shift = offset * 8; 51 data = *ptr++; 52 data = (data >> shift) << shift; 53 54 /* 55 * Body: straightforward aligned loads from here on (the paired loads 56 * underlying the quadword type still only need dword alignment). The 57 * main loop strictly excludes the tail, so the second loop will always 58 * run at least once. 59 */ 60 while (unlikely(len > 64)) { 61 __uint128_t tmp1, tmp2, tmp3, tmp4; 62 63 tmp1 = *(__uint128_t *)ptr; 64 tmp2 = *(__uint128_t *)(ptr + 2); 65 tmp3 = *(__uint128_t *)(ptr + 4); 66 tmp4 = *(__uint128_t *)(ptr + 6); 67 68 len -= 64; 69 ptr += 8; 70 71 /* This is the "don't dump the carry flag into a GPR" idiom */ 72 tmp1 += (tmp1 >> 64) | (tmp1 << 64); 73 tmp2 += (tmp2 >> 64) | (tmp2 << 64); 74 tmp3 += (tmp3 >> 64) | (tmp3 << 64); 75 tmp4 += (tmp4 >> 64) | (tmp4 << 64); 76 tmp1 = ((tmp1 >> 64) << 64) | (tmp2 >> 64); 77 tmp1 += (tmp1 >> 64) | (tmp1 << 64); 78 tmp3 = ((tmp3 >> 64) << 64) | (tmp4 >> 64); 79 tmp3 += (tmp3 >> 64) | (tmp3 << 64); 80 tmp1 = ((tmp1 >> 64) << 64) | (tmp3 >> 64); 81 tmp1 += (tmp1 >> 64) | (tmp1 << 64); 82 tmp1 = ((tmp1 >> 64) << 64) | sum64; 83 tmp1 += (tmp1 >> 64) | (tmp1 << 64); 84 sum64 = tmp1 >> 64; 85 } 86 while (len > 8) { 87 __uint128_t tmp; 88 89 sum64 = accumulate(sum64, data); 90 tmp = *(__uint128_t *)ptr; 91 92 len -= 16; 93 ptr += 2; 94 95 data = tmp >> 64; 96 sum64 = accumulate(sum64, tmp); 97 } 98 if (len > 0) { 99 sum64 = accumulate(sum64, data); 100 data = *ptr; 101 len -= 8; 102 } 103 /* 104 * Tail: zero any over-read bytes similarly to the head, again 105 * preserving odd/even alignment. 106 */ 107 shift = len * -8; 108 data = (data << shift) >> shift; 109 sum64 = accumulate(sum64, data); 110 111 /* Finally, folding */ 112 sum64 += (sum64 >> 32) | (sum64 << 32); 113 sum = sum64 >> 32; 114 sum += (sum >> 16) | (sum << 16); 115 if (offset & 1) 116 return (u16)swab32(sum); 117 118 return sum >> 16; 119 } 120 121 __sum16 csum_ipv6_magic(const struct in6_addr *saddr, 122 const struct in6_addr *daddr, 123 __u32 len, __u8 proto, __wsum csum) 124 { 125 __uint128_t src, dst; 126 u64 sum = (__force u64)csum; 127 128 src = *(const __uint128_t *)saddr->s6_addr; 129 dst = *(const __uint128_t *)daddr->s6_addr; 130 131 sum += (__force u32)htonl(len); 132 sum += (u32)proto << 24; 133 src += (src >> 64) | (src << 64); 134 dst += (dst >> 64) | (dst << 64); 135 136 sum = accumulate(sum, src >> 64); 137 sum = accumulate(sum, dst >> 64); 138 139 sum += ((sum >> 32) | (sum << 32)); 140 return csum_fold((__force __wsum)(sum >> 32)); 141 } 142 EXPORT_SYMBOL(csum_ipv6_magic); 143