xref: /linux/arch/loongarch/lib/csum.c (revision 25489a4f556414445d342951615178368ee45cde)
1 // SPDX-License-Identifier: GPL-2.0-only
2 // Copyright (C) 2019-2020 Arm Ltd.
3 
4 #include <linux/compiler.h>
5 #include <linux/export.h>
6 #include <linux/kasan-checks.h>
7 #include <linux/kernel.h>
8 
9 #include <net/checksum.h>
10 
11 static u64 accumulate(u64 sum, u64 data)
12 {
13 	sum += data;
14 	if (sum < data)
15 		sum += 1;
16 	return sum;
17 }
18 
19 /*
20  * We over-read the buffer and this makes KASAN unhappy. Instead, disable
21  * instrumentation and call kasan explicitly.
22  */
23 unsigned int __no_sanitize_address do_csum(const unsigned char *buff, int len)
24 {
25 	unsigned int offset, shift, sum;
26 	const u64 *ptr;
27 	u64 data, sum64 = 0;
28 
29 	if (unlikely(len <= 0))
30 		return 0;
31 
32 	offset = (unsigned long)buff & 7;
33 	/*
34 	 * This is to all intents and purposes safe, since rounding down cannot
35 	 * result in a different page or cache line being accessed, and @buff
36 	 * should absolutely not be pointing to anything read-sensitive. We do,
37 	 * however, have to be careful not to piss off KASAN, which means using
38 	 * unchecked reads to accommodate the head and tail, for which we'll
39 	 * compensate with an explicit check up-front.
40 	 */
41 	kasan_check_read(buff, len);
42 	ptr = (u64 *)(buff - offset);
43 	len = len + offset - 8;
44 
45 	/*
46 	 * Head: zero out any excess leading bytes. Shifting back by the same
47 	 * amount should be at least as fast as any other way of handling the
48 	 * odd/even alignment, and means we can ignore it until the very end.
49 	 */
50 	shift = offset * 8;
51 	data = *ptr++;
52 	data = (data >> shift) << shift;
53 
54 	/*
55 	 * Body: straightforward aligned loads from here on (the paired loads
56 	 * underlying the quadword type still only need dword alignment). The
57 	 * main loop strictly excludes the tail, so the second loop will always
58 	 * run at least once.
59 	 */
60 	while (unlikely(len > 64)) {
61 		__uint128_t tmp1, tmp2, tmp3, tmp4;
62 
63 		tmp1 = *(__uint128_t *)ptr;
64 		tmp2 = *(__uint128_t *)(ptr + 2);
65 		tmp3 = *(__uint128_t *)(ptr + 4);
66 		tmp4 = *(__uint128_t *)(ptr + 6);
67 
68 		len -= 64;
69 		ptr += 8;
70 
71 		/* This is the "don't dump the carry flag into a GPR" idiom */
72 		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
73 		tmp2 += (tmp2 >> 64) | (tmp2 << 64);
74 		tmp3 += (tmp3 >> 64) | (tmp3 << 64);
75 		tmp4 += (tmp4 >> 64) | (tmp4 << 64);
76 		tmp1 = ((tmp1 >> 64) << 64) | (tmp2 >> 64);
77 		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
78 		tmp3 = ((tmp3 >> 64) << 64) | (tmp4 >> 64);
79 		tmp3 += (tmp3 >> 64) | (tmp3 << 64);
80 		tmp1 = ((tmp1 >> 64) << 64) | (tmp3 >> 64);
81 		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
82 		tmp1 = ((tmp1 >> 64) << 64) | sum64;
83 		tmp1 += (tmp1 >> 64) | (tmp1 << 64);
84 		sum64 = tmp1 >> 64;
85 	}
86 	while (len > 8) {
87 		__uint128_t tmp;
88 
89 		sum64 = accumulate(sum64, data);
90 		tmp = *(__uint128_t *)ptr;
91 
92 		len -= 16;
93 		ptr += 2;
94 
95 		data = tmp >> 64;
96 		sum64 = accumulate(sum64, tmp);
97 	}
98 	if (len > 0) {
99 		sum64 = accumulate(sum64, data);
100 		data = *ptr;
101 		len -= 8;
102 	}
103 	/*
104 	 * Tail: zero any over-read bytes similarly to the head, again
105 	 * preserving odd/even alignment.
106 	 */
107 	shift = len * -8;
108 	data = (data << shift) >> shift;
109 	sum64 = accumulate(sum64, data);
110 
111 	/* Finally, folding */
112 	sum64 += (sum64 >> 32) | (sum64 << 32);
113 	sum = sum64 >> 32;
114 	sum += (sum >> 16) | (sum << 16);
115 	if (offset & 1)
116 		return (u16)swab32(sum);
117 
118 	return sum >> 16;
119 }
120 
121 __sum16 csum_ipv6_magic(const struct in6_addr *saddr,
122 			const struct in6_addr *daddr,
123 			__u32 len, __u8 proto, __wsum csum)
124 {
125 	__uint128_t src, dst;
126 	u64 sum = (__force u64)csum;
127 
128 	src = *(const __uint128_t *)saddr->s6_addr;
129 	dst = *(const __uint128_t *)daddr->s6_addr;
130 
131 	sum += (__force u32)htonl(len);
132 	sum += (u32)proto << 24;
133 	src += (src >> 64) | (src << 64);
134 	dst += (dst >> 64) | (dst << 64);
135 
136 	sum = accumulate(sum, src >> 64);
137 	sum = accumulate(sum, dst >> 64);
138 
139 	sum += ((sum >> 32) | (sum << 32));
140 	return csum_fold((__force __wsum)(sum >> 32));
141 }
142 EXPORT_SYMBOL(csum_ipv6_magic);
143