// SPDX-License-Identifier: GPL-2.0-only
// Copyright (C) 2019-2020 Arm Ltd.

#include <linux/compiler.h>
#include <linux/kasan-checks.h>
#include <linux/kernel.h>

#include <net/checksum.h>

/* 64-bit accumulate with end-around carry (ones' complement addition). */
static u64 accumulate(u64 sum, u64 data)
{
	sum += data;
	if (sum < data)
		sum += 1;
	return sum;
}

/*
 * We over-read the buffer and this makes KASAN unhappy. Instead, disable
 * instrumentation and call kasan explicitly.
 */
unsigned int __no_sanitize_address do_csum(const unsigned char *buff, int len)
{
	unsigned int offset, shift, sum;
	const u64 *ptr;
	u64 data, sum64 = 0;

	if (unlikely(len == 0))
		return 0;

	offset = (unsigned long)buff & 7;
	/*
	 * This is to all intents and purposes safe, since rounding down cannot
	 * result in a different page or cache line being accessed, and @buff
	 * should absolutely not be pointing to anything read-sensitive. We do,
	 * however, have to be careful not to piss off KASAN, which means using
	 * unchecked reads to accommodate the head and tail, for which we'll
	 * compensate with an explicit check up-front.
	 */
	kasan_check_read(buff, len);
	ptr = (u64 *)(buff - offset);
	len = len + offset - 8;

	/*
	 * Head: zero out any excess leading bytes. Shifting back by the same
	 * amount should be at least as fast as any other way of handling the
	 * odd/even alignment, and means we can ignore it until the very end.
	 */
	shift = offset * 8;
	data = *ptr++;
	data = (data >> shift) << shift;

	/*
	 * Body: straightforward aligned loads from here on (the paired loads
	 * underlying the quadword type still only need dword alignment). The
	 * main loop strictly excludes the tail, so the second loop will always
	 * run at least once.
	 */
	while (unlikely(len > 64)) {
		__uint128_t tmp1, tmp2, tmp3, tmp4;

		tmp1 = *(__uint128_t *)ptr;
		tmp2 = *(__uint128_t *)(ptr + 2);
		tmp3 = *(__uint128_t *)(ptr + 4);
		tmp4 = *(__uint128_t *)(ptr + 6);

		len -= 64;
		ptr += 8;

		/* This is the "don't dump the carry flag into a GPR" idiom */
		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
		tmp2 += (tmp2 >> 64) | (tmp2 << 64);
		tmp3 += (tmp3 >> 64) | (tmp3 << 64);
		tmp4 += (tmp4 >> 64) | (tmp4 << 64);
		tmp1 = ((tmp1 >> 64) << 64) | (tmp2 >> 64);
		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
		tmp3 = ((tmp3 >> 64) << 64) | (tmp4 >> 64);
		tmp3 += (tmp3 >> 64) | (tmp3 << 64);
		tmp1 = ((tmp1 >> 64) << 64) | (tmp3 >> 64);
		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
		tmp1 = ((tmp1 >> 64) << 64) | sum64;
		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
		sum64 = tmp1 >> 64;
	}
	while (len > 8) {
		__uint128_t tmp;

		sum64 = accumulate(sum64, data);
		tmp = *(__uint128_t *)ptr;

		len -= 16;
		ptr += 2;

		data = tmp >> 64;
		sum64 = accumulate(sum64, tmp);
	}
	if (len > 0) {
		sum64 = accumulate(sum64, data);
		data = *ptr;
		len -= 8;
	}
	/*
	 * Tail: zero any over-read bytes similarly to the head, again
	 * preserving odd/even alignment.
	 */
	shift = len * -8;
	data = (data << shift) >> shift;
	sum64 = accumulate(sum64, data);

	/* Finally, folding */
	sum64 += (sum64 >> 32) | (sum64 << 32);
	sum = sum64 >> 32;
	sum += (sum >> 16) | (sum << 16);
	if (offset & 1)
		return (u16)swab32(sum);

	return sum >> 16;
}
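
/*
 * Illustration only (not part of the original file): a minimal userspace
 * sketch of the folding done at the end of do_csum() above. It checks that
 * accumulating 16-bit words into a 64-bit sum with end-around carry and then
 * folding 64 -> 32 -> 16 bits gives the same result as a plain RFC 1071
 * style 16-bit ones' complement sum. The helper names (acc64, fold_sketch,
 * naive_sketch) are made up for the example.
 *
 *	#include <assert.h>
 *	#include <stdint.h>
 *
 *	static uint64_t acc64(uint64_t sum, uint64_t data)
 *	{
 *		sum += data;
 *		return sum + (sum < data);	// end-around carry, as in accumulate()
 *	}
 *
 *	static uint16_t fold_sketch(const uint16_t *w, int n)
 *	{
 *		uint64_t sum64 = 0;
 *		uint32_t sum;
 *
 *		for (int i = 0; i < n; i++)
 *			sum64 = acc64(sum64, w[i]);
 *		sum64 += (sum64 >> 32) | (sum64 << 32);	// fold 64 -> 32 bits
 *		sum = sum64 >> 32;
 *		sum += (sum >> 16) | (sum << 16);	// fold 32 -> 16 bits
 *		return sum >> 16;
 *	}
 *
 *	static uint16_t naive_sketch(const uint16_t *w, int n)
 *	{
 *		uint32_t sum = 0;
 *
 *		for (int i = 0; i < n; i++)
 *			sum += w[i];
 *		while (sum >> 16)			// propagate carries
 *			sum = (sum & 0xffff) + (sum >> 16);
 *		return sum;
 *	}
 *
 *	int main(void)
 *	{
 *		uint16_t w[] = { 0xffff, 0x1234, 0xfedc, 0x0001 };
 *
 *		assert(fold_sketch(w, 4) == naive_sketch(w, 4));
 *		return 0;
 *	}
 */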

/*
 * IPv6 pseudo-header checksum (RFC 8200, section 8.1): fold the source and
 * destination addresses, the upper-layer payload length and the next-header
 * value into the running checksum @csum.
 */
__sum16 csum_ipv6_magic(const struct in6_addr *saddr,
			const struct in6_addr *daddr,
			__u32 len, __u8 proto, __wsum csum)
{
	__uint128_t src, dst;
	u64 sum = (__force u64)csum;

	src = *(const __uint128_t *)saddr->s6_addr;
	dst = *(const __uint128_t *)daddr->s6_addr;

	sum += (__force u32)htonl(len);
	sum += (u32)proto << 24;
	src += (src >> 64) | (src << 64);
	dst += (dst >> 64) | (dst << 64);

	sum = accumulate(sum, src >> 64);
	sum = accumulate(sum, dst >> 64);

	sum += ((sum >> 32) | (sum << 32));
	return csum_fold((__force __wsum)(sum >> 32));
}
EXPORT_SYMBOL(csum_ipv6_magic);
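
/*
 * Illustration only (not part of the original file): a userspace sketch of
 * the 128-bit rotate-and-add used in the main loop of do_csum() and on the
 * addresses in csum_ipv6_magic() above. After adding the 64-bit rotation of
 * x to itself, the top 64 bits of x hold the end-around-carry sum of the two
 * original halves, which accumulate() can then fold into the running sum.
 * Assumes a compiler providing __uint128_t; add_ones_complement() is a
 * made-up helper name.
 *
 *	#include <assert.h>
 *	#include <stdint.h>
 *
 *	static uint64_t add_ones_complement(uint64_t a, uint64_t b)
 *	{
 *		uint64_t s = a + b;
 *		return s + (s < a);	// end-around carry
 *	}
 *
 *	int main(void)
 *	{
 *		uint64_t hi = 0xffffffff00000000ULL;
 *		uint64_t lo = 0x00000001ffffffffULL;
 *		__uint128_t x = ((__uint128_t)hi << 64) | lo;
 *
 *		x += (x >> 64) | (x << 64);	// the idiom from the loop above
 *		assert((uint64_t)(x >> 64) == add_ones_complement(hi, lo));
 *		return 0;
 *	}
 */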