1 /*
2 * Copyright (c) 2010, Oracle and/or its affiliates. All rights reserved.
3 */
4
5 /*
6 * This file contains code imported from the OFED rds source file message.c
7 * Oracle elects to have and use the contents of message.c under and governed
8 * by the OpenIB.org BSD license (see below for full license text). However,
9 * the following notice accompanied the original version of this file:
10 */
11
12 /*
13 * Copyright (c) 2006 Oracle. All rights reserved.
14 *
15 * This software is available to you under a choice of one of two
16 * licenses. You may choose to be licensed under the terms of the GNU
17 * General Public License (GPL) Version 2, available from the file
18 * COPYING in the main directory of this source tree, or the
19 * OpenIB.org BSD license below:
20 *
21 * Redistribution and use in source and binary forms, with or
22 * without modification, are permitted provided that the following
23 * conditions are met:
24 *
25 * - Redistributions of source code must retain the above
26 * copyright notice, this list of conditions and the following
27 * disclaimer.
28 *
29 * - Redistributions in binary form must reproduce the above
30 * copyright notice, this list of conditions and the following
31 * disclaimer in the documentation and/or other materials
32 * provided with the distribution.
33 *
34 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
35 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
36 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
37 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
38 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
39 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
40 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
41 * SOFTWARE.
42 *
43 */
44 #include <sys/rds.h>
45 #include <sys/containerof.h>
46
47 #include <sys/ib/clients/rdsv3/rdsv3.h>
48 #include <sys/ib/clients/rdsv3/rdma.h>
49 #include <sys/ib/clients/rdsv3/rdsv3_debug.h>
50
51 #ifndef __lock_lint
52 static unsigned int rdsv3_exthdr_size[__RDSV3_EXTHDR_MAX] = {
53 [RDSV3_EXTHDR_NONE] = 0,
54 [RDSV3_EXTHDR_VERSION] = sizeof (struct rdsv3_ext_header_version),
55 [RDSV3_EXTHDR_RDMA] = sizeof (struct rdsv3_ext_header_rdma),
56 [RDSV3_EXTHDR_RDMA_DEST] = sizeof (struct rdsv3_ext_header_rdma_dest),
57 };
58 #else
59 static unsigned int rdsv3_exthdr_size[__RDSV3_EXTHDR_MAX] = {
60 0,
61 sizeof (struct rdsv3_ext_header_version),
62 sizeof (struct rdsv3_ext_header_rdma),
63 sizeof (struct rdsv3_ext_header_rdma_dest),
64 };
65 #endif
66
67 void
rdsv3_message_addref(struct rdsv3_message * rm)68 rdsv3_message_addref(struct rdsv3_message *rm)
69 {
70 RDSV3_DPRINTF5("rdsv3_message_addref", "addref rm %p ref %d",
71 rm, atomic_get(&rm->m_refcount));
72 atomic_inc_32(&rm->m_refcount);
73 }
74
75 /*
76 * This relies on dma_map_sg() not touching sg[].page during merging.
77 */
78 static void
rdsv3_message_purge(struct rdsv3_message * rm)79 rdsv3_message_purge(struct rdsv3_message *rm)
80 {
81 unsigned long i;
82
83 RDSV3_DPRINTF4("rdsv3_message_purge", "Enter(rm: %p)", rm);
84
85 if (test_bit(RDSV3_MSG_PAGEVEC, &rm->m_flags))
86 return;
87
88 for (i = 0; i < rm->m_nents; i++) {
89 RDSV3_DPRINTF5("rdsv3_message_purge", "putting data page %p\n",
90 (void *)rdsv3_sg_page(&rm->m_sg[i]));
91 /* XXX will have to put_page for page refs */
92 kmem_free(rdsv3_sg_page(&rm->m_sg[i]),
93 rdsv3_sg_len(&rm->m_sg[i]));
94 }
95
96 if (rm->m_rdma_op)
97 rdsv3_rdma_free_op(rm->m_rdma_op);
98 if (rm->m_rdma_mr) {
99 struct rdsv3_mr *mr = rm->m_rdma_mr;
100 if (mr->r_refcount == 0) {
101 RDSV3_DPRINTF4("rdsv3_message_purge ASSERT 0",
102 "rm %p mr %p", rm, mr);
103 return;
104 }
105 if (mr->r_refcount == 0xdeadbeef) {
106 RDSV3_DPRINTF4("rdsv3_message_purge ASSERT deadbeef",
107 "rm %p mr %p", rm, mr);
108 return;
109 }
110 if (atomic_dec_and_test(&mr->r_refcount)) {
111 rm->m_rdma_mr = NULL;
112 __rdsv3_put_mr_final(mr);
113 }
114 }
115
116 RDSV3_DPRINTF4("rdsv3_message_purge", "Return(rm: %p)", rm);
117
118 }
119
120 void
rdsv3_message_put(struct rdsv3_message * rm)121 rdsv3_message_put(struct rdsv3_message *rm)
122 {
123 RDSV3_DPRINTF5("rdsv3_message_put",
124 "put rm %p ref %d\n", rm, atomic_get(&rm->m_refcount));
125
126 if (atomic_dec_and_test(&rm->m_refcount)) {
127 ASSERT(!list_link_active(&rm->m_sock_item));
128 ASSERT(!list_link_active(&rm->m_conn_item));
129 rdsv3_message_purge(rm);
130
131 kmem_free(rm, sizeof (struct rdsv3_message) +
132 (rm->m_nents * sizeof (struct rdsv3_scatterlist)));
133 }
134 }
135
136 void
rdsv3_message_inc_free(struct rdsv3_incoming * inc)137 rdsv3_message_inc_free(struct rdsv3_incoming *inc)
138 {
139 struct rdsv3_message *rm =
140 __containerof(inc, struct rdsv3_message, m_inc);
141 rdsv3_message_put(rm);
142 }
143
144 void
rdsv3_message_populate_header(struct rdsv3_header * hdr,uint16_be_t sport,uint16_be_t dport,uint64_t seq)145 rdsv3_message_populate_header(struct rdsv3_header *hdr, uint16_be_t sport,
146 uint16_be_t dport, uint64_t seq)
147 {
148 hdr->h_flags = 0;
149 hdr->h_sport = sport;
150 hdr->h_dport = dport;
151 hdr->h_sequence = htonll(seq);
152 hdr->h_exthdr[0] = RDSV3_EXTHDR_NONE;
153 }
154
155 int
rdsv3_message_add_extension(struct rdsv3_header * hdr,unsigned int type,const void * data,unsigned int len)156 rdsv3_message_add_extension(struct rdsv3_header *hdr,
157 unsigned int type, const void *data, unsigned int len)
158 {
159 unsigned int ext_len = sizeof (uint8_t) + len;
160 unsigned char *dst;
161
162 RDSV3_DPRINTF4("rdsv3_message_add_extension", "Enter");
163
164 /* For now, refuse to add more than one extension header */
165 if (hdr->h_exthdr[0] != RDSV3_EXTHDR_NONE)
166 return (0);
167
168 if (type >= __RDSV3_EXTHDR_MAX ||
169 len != rdsv3_exthdr_size[type])
170 return (0);
171
172 if (ext_len >= RDSV3_HEADER_EXT_SPACE)
173 return (0);
174 dst = hdr->h_exthdr;
175
176 *dst++ = type;
177 (void) memcpy(dst, data, len);
178
179 dst[len] = RDSV3_EXTHDR_NONE;
180
181 RDSV3_DPRINTF4("rdsv3_message_add_extension", "Return");
182 return (1);
183 }
184
185 /*
186 * If a message has extension headers, retrieve them here.
187 * Call like this:
188 *
189 * unsigned int pos = 0;
190 *
191 * while (1) {
192 * buflen = sizeof(buffer);
193 * type = rdsv3_message_next_extension(hdr, &pos, buffer, &buflen);
194 * if (type == RDSV3_EXTHDR_NONE)
195 * break;
196 * ...
197 * }
198 */
199 int
rdsv3_message_next_extension(struct rdsv3_header * hdr,unsigned int * pos,void * buf,unsigned int * buflen)200 rdsv3_message_next_extension(struct rdsv3_header *hdr,
201 unsigned int *pos, void *buf, unsigned int *buflen)
202 {
203 unsigned int offset, ext_type, ext_len;
204 uint8_t *src = hdr->h_exthdr;
205
206 RDSV3_DPRINTF4("rdsv3_message_next_extension", "Enter");
207
208 offset = *pos;
209 if (offset >= RDSV3_HEADER_EXT_SPACE)
210 goto none;
211
212 /*
213 * Get the extension type and length. For now, the
214 * length is implied by the extension type.
215 */
216 ext_type = src[offset++];
217
218 if (ext_type == RDSV3_EXTHDR_NONE || ext_type >= __RDSV3_EXTHDR_MAX)
219 goto none;
220 ext_len = rdsv3_exthdr_size[ext_type];
221 if (offset + ext_len > RDSV3_HEADER_EXT_SPACE)
222 goto none;
223
224 *pos = offset + ext_len;
225 if (ext_len < *buflen)
226 *buflen = ext_len;
227 (void) memcpy(buf, src + offset, *buflen);
228 return (ext_type);
229
230 none:
231 *pos = RDSV3_HEADER_EXT_SPACE;
232 *buflen = 0;
233 return (RDSV3_EXTHDR_NONE);
234 }
235
236 int
rdsv3_message_add_version_extension(struct rdsv3_header * hdr,unsigned int version)237 rdsv3_message_add_version_extension(struct rdsv3_header *hdr,
238 unsigned int version)
239 {
240 struct rdsv3_ext_header_version ext_hdr;
241
242 ext_hdr.h_version = htonl(version);
243 return (rdsv3_message_add_extension(hdr, RDSV3_EXTHDR_VERSION,
244 &ext_hdr, sizeof (ext_hdr)));
245 }
246
247 int
rdsv3_message_get_version_extension(struct rdsv3_header * hdr,unsigned int * version)248 rdsv3_message_get_version_extension(struct rdsv3_header *hdr,
249 unsigned int *version)
250 {
251 struct rdsv3_ext_header_version ext_hdr;
252 unsigned int pos = 0, len = sizeof (ext_hdr);
253
254 RDSV3_DPRINTF4("rdsv3_message_get_version_extension", "Enter");
255
256 /*
257 * We assume the version extension is the only one present
258 */
259 if (rdsv3_message_next_extension(hdr, &pos, &ext_hdr, &len) !=
260 RDSV3_EXTHDR_VERSION)
261 return (0);
262 *version = ntohl(ext_hdr.h_version);
263 return (1);
264 }
265
266 int
rdsv3_message_add_rdma_dest_extension(struct rdsv3_header * hdr,uint32_t r_key,uint32_t offset)267 rdsv3_message_add_rdma_dest_extension(struct rdsv3_header *hdr, uint32_t r_key,
268 uint32_t offset)
269 {
270 struct rdsv3_ext_header_rdma_dest ext_hdr;
271
272 ext_hdr.h_rdma_rkey = htonl(r_key);
273 ext_hdr.h_rdma_offset = htonl(offset);
274 return (rdsv3_message_add_extension(hdr, RDSV3_EXTHDR_RDMA_DEST,
275 &ext_hdr, sizeof (ext_hdr)));
276 }
277
278 struct rdsv3_message *
rdsv3_message_alloc(unsigned int nents,int gfp)279 rdsv3_message_alloc(unsigned int nents, int gfp)
280 {
281 struct rdsv3_message *rm;
282
283 RDSV3_DPRINTF4("rdsv3_message_alloc", "Enter(nents: %d)", nents);
284
285 rm = kmem_zalloc(sizeof (struct rdsv3_message) +
286 (nents * sizeof (struct rdsv3_scatterlist)), gfp);
287 if (!rm)
288 goto out;
289
290 rm->m_refcount = 1;
291 list_link_init(&rm->m_sock_item);
292 list_link_init(&rm->m_conn_item);
293 mutex_init(&rm->m_rs_lock, NULL, MUTEX_DRIVER, NULL);
294 rdsv3_init_waitqueue(&rm->m_flush_wait);
295
296 RDSV3_DPRINTF4("rdsv3_message_alloc", "Return(rm: %p)", rm);
297 out:
298 return (rm);
299 }
300
301 struct rdsv3_message *
rdsv3_message_map_pages(unsigned long * page_addrs,unsigned int total_len)302 rdsv3_message_map_pages(unsigned long *page_addrs, unsigned int total_len)
303 {
304 struct rdsv3_message *rm;
305 unsigned int i;
306
307 RDSV3_DPRINTF4("rdsv3_message_map_pages", "Enter(len: %d)", total_len);
308
309 #ifndef __lock_lint
310 rm = rdsv3_message_alloc(ceil(total_len, PAGE_SIZE), KM_NOSLEEP);
311 #else
312 rm = NULL;
313 #endif
314 if (rm == NULL)
315 return (ERR_PTR(-ENOMEM));
316
317 set_bit(RDSV3_MSG_PAGEVEC, &rm->m_flags);
318 rm->m_inc.i_hdr.h_len = htonl(total_len);
319 #ifndef __lock_lint
320 rm->m_nents = ceil(total_len, PAGE_SIZE);
321 #else
322 rm->m_nents = 0;
323 #endif
324
325 for (i = 0; i < rm->m_nents; ++i) {
326 rdsv3_sg_set_page(&rm->m_sg[i],
327 page_addrs[i],
328 PAGE_SIZE, 0);
329 }
330
331 return (rm);
332 }
333
334 struct rdsv3_message *
rdsv3_message_copy_from_user(struct uio * uiop,size_t total_len)335 rdsv3_message_copy_from_user(struct uio *uiop,
336 size_t total_len)
337 {
338 struct rdsv3_message *rm;
339 struct rdsv3_scatterlist *sg;
340 int ret;
341
342 RDSV3_DPRINTF4("rdsv3_message_copy_from_user", "Enter: %d", total_len);
343
344 #ifndef __lock_lint
345 rm = rdsv3_message_alloc(ceil(total_len, PAGE_SIZE), KM_NOSLEEP);
346 #else
347 rm = NULL;
348 #endif
349 if (rm == NULL) {
350 ret = -ENOMEM;
351 goto out;
352 }
353
354 rm->m_inc.i_hdr.h_len = htonl(total_len);
355
356 /*
357 * now allocate and copy in the data payload.
358 */
359 sg = rm->m_sg;
360
361 while (total_len) {
362 if (rdsv3_sg_page(sg) == NULL) {
363 ret = rdsv3_page_remainder_alloc(sg, total_len, 0);
364 if (ret)
365 goto out;
366 rm->m_nents++;
367 }
368
369 ret = uiomove(rdsv3_sg_page(sg), rdsv3_sg_len(sg), UIO_WRITE,
370 uiop);
371 if (ret) {
372 RDSV3_DPRINTF2("rdsv3_message_copy_from_user",
373 "uiomove failed");
374 ret = -ret;
375 goto out;
376 }
377
378 total_len -= rdsv3_sg_len(sg);
379 sg++;
380 }
381 ret = 0;
382 out:
383 if (ret) {
384 if (rm)
385 rdsv3_message_put(rm);
386 rm = ERR_PTR(ret);
387 }
388 return (rm);
389 }
390
391 int
rdsv3_message_inc_copy_to_user(struct rdsv3_incoming * inc,uio_t * uiop,size_t size)392 rdsv3_message_inc_copy_to_user(struct rdsv3_incoming *inc,
393 uio_t *uiop, size_t size)
394 {
395 struct rdsv3_message *rm;
396 struct rdsv3_scatterlist *sg;
397 unsigned long to_copy;
398 unsigned long vec_off;
399 int copied;
400 int ret;
401 uint32_t len;
402
403 rm = __containerof(inc, struct rdsv3_message, m_inc);
404 len = ntohl(rm->m_inc.i_hdr.h_len);
405
406 RDSV3_DPRINTF4("rdsv3_message_inc_copy_to_user",
407 "Enter(rm: %p, len: %d)", rm, len);
408
409 sg = rm->m_sg;
410 vec_off = 0;
411 copied = 0;
412
413 while (copied < size && copied < len) {
414
415 to_copy = min(len - copied, sg->length - vec_off);
416 to_copy = min(size - copied, to_copy);
417
418 RDSV3_DPRINTF5("rdsv3_message_inc_copy_to_user",
419 "copying %lu bytes to user iov %p from sg [%p, %u] + %lu\n",
420 to_copy, uiop,
421 rdsv3_sg_page(sg), sg->length, vec_off);
422
423 ret = uiomove(rdsv3_sg_page(sg), to_copy, UIO_READ, uiop);
424 if (ret)
425 break;
426
427 vec_off += to_copy;
428 copied += to_copy;
429
430 if (vec_off == sg->length) {
431 vec_off = 0;
432 sg++;
433 }
434 }
435
436 return (copied);
437 }
438
439 /*
440 * If the message is still on the send queue, wait until the transport
441 * is done with it. This is particularly important for RDMA operations.
442 */
443 /* ARGSUSED */
444 void
rdsv3_message_wait(struct rdsv3_message * rm)445 rdsv3_message_wait(struct rdsv3_message *rm)
446 {
447 rdsv3_wait_event(&rm->m_flush_wait,
448 !test_bit(RDSV3_MSG_MAPPED, &rm->m_flags));
449 }
450
451 void
rdsv3_message_unmapped(struct rdsv3_message * rm)452 rdsv3_message_unmapped(struct rdsv3_message *rm)
453 {
454 clear_bit(RDSV3_MSG_MAPPED, &rm->m_flags);
455 rdsv3_wake_up_all(&rm->m_flush_wait);
456 }
457