xref: /linux/fs/smb/smbdirect/rw.c (revision 0fc8f6200d2313278fbf4539bbab74677c685531)
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  *   Copyright (C) 2017, Microsoft Corporation.
4  *   Copyright (C) 2018, LG Electronics.
5  *   Copyright (c) 2025, Stefan Metzmacher
6  */
7 
8 #include "internal.h"
9 
/*
 * Reserve @credits RDMA read/write credits on @sc.
 *
 * Thin wrapper around smbdirect_socket_wait_for_credits() bound to the
 * rw_io credit counter and wait queue; waiting is only meaningful while
 * the socket is in SMBDIRECT_SOCKET_CONNECTED state, and -ENOTCONN is
 * the error reported once it leaves that state.  Returns a negative
 * errno on failure (presumably 0 on success — matches the helper's
 * contract, defined elsewhere).
 */
static int smbdirect_connection_wait_for_rw_credits(struct smbdirect_socket *sc,
						    int credits)
{
	return smbdirect_socket_wait_for_credits(sc,
						 SMBDIRECT_SOCKET_CONNECTED,
						 -ENOTCONN,
						 &sc->rw_io.credits.wait_queue,
						 &sc->rw_io.credits.count,
						 credits);
}
20 
21 static int smbdirect_connection_calc_rw_credits(struct smbdirect_socket *sc,
22 						const void *buf,
23 						size_t len)
24 {
25 	return DIV_ROUND_UP(smbdirect_get_buf_page_count(buf, len),
26 			    sc->rw_io.credits.num_pages);
27 }
28 
29 static int smbdirect_connection_rdma_get_sg_list(void *buf,
30 						 size_t size,
31 						 struct scatterlist *sg_list,
32 						 size_t nentries)
33 {
34 	bool high = is_vmalloc_addr(buf);
35 	struct page *page;
36 	size_t offset, len;
37 	int i = 0;
38 
39 	if (size == 0 || nentries < smbdirect_get_buf_page_count(buf, size))
40 		return -EINVAL;
41 
42 	offset = offset_in_page(buf);
43 	buf -= offset;
44 	while (size > 0) {
45 		len = min_t(size_t, PAGE_SIZE - offset, size);
46 		if (high)
47 			page = vmalloc_to_page(buf);
48 		else
49 			page = kmap_to_page(buf);
50 
51 		if (!sg_list)
52 			return -EINVAL;
53 		sg_set_page(sg_list, page, len, offset);
54 		sg_list = sg_next(sg_list);
55 
56 		buf += PAGE_SIZE;
57 		size -= len;
58 		offset = 0;
59 		i++;
60 	}
61 
62 	return i;
63 }
64 
/*
 * Release one rw I/O message: destroy the rdma_rw context (undoing the
 * DMA mapping for @dir), free the chained scatterlist, then the message
 * itself.  Teardown is the exact reverse of the setup order in
 * smbdirect_connection_rdma_xmit().
 */
static void smbdirect_connection_rw_io_free(struct smbdirect_rw_io *msg,
					    enum dma_data_direction dir)
{
	struct smbdirect_socket *sc = msg->socket;

	rdma_rw_ctx_destroy(&msg->rdma_ctx,
			    sc->ib.qp,
			    sc->ib.qp->port,
			    msg->sgt.sgl,
			    msg->sgt.nents,
			    dir);
	sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
	kfree(msg);
}
79 
/*
 * Shared completion handler for RDMA read and write work requests.
 *
 * On error, record -EIO in the message and schedule socket cleanup —
 * unless the error is IB_WC_WR_FLUSH_ERR, i.e. the WR was flushed
 * because teardown is already in progress.  The waiter in
 * smbdirect_connection_rdma_xmit() is always woken via the completion,
 * success or failure.
 *
 * @dir is currently unused; it only documents the transfer direction
 * at the two call sites below.
 */
static void smbdirect_connection_rdma_rw_done(struct ib_cq *cq, struct ib_wc *wc,
					      enum dma_data_direction dir)
{
	struct smbdirect_rw_io *msg =
		container_of(wc->wr_cqe, struct smbdirect_rw_io, cqe);
	struct smbdirect_socket *sc = msg->socket;

	if (wc->status != IB_WC_SUCCESS) {
		msg->error = -EIO;
		pr_err("read/write error. opcode = %d, status = %s(%d)\n",
		       wc->opcode, ib_wc_status_msg(wc->status), wc->status);
		if (wc->status != IB_WC_WR_FLUSH_ERR)
			smbdirect_socket_schedule_cleanup(sc, msg->error);
	}

	complete(msg->completion);
}
97 
/* ib_cqe::done trampoline for RDMA read completions. */
static void smbdirect_connection_rdma_read_done(struct ib_cq *cq, struct ib_wc *wc)
{
	smbdirect_connection_rdma_rw_done(cq, wc, DMA_FROM_DEVICE);
}
102 
/* ib_cqe::done trampoline for RDMA write completions. */
static void smbdirect_connection_rdma_write_done(struct ib_cq *cq, struct ib_wc *wc)
{
	smbdirect_connection_rdma_rw_done(cq, wc, DMA_TO_DEVICE);
}
107 
108 int smbdirect_connection_rdma_xmit(struct smbdirect_socket *sc,
109 				   void *buf, size_t buf_len,
110 				   struct smbdirect_buffer_descriptor_v1 *desc,
111 				   size_t desc_len,
112 				   bool is_read)
113 {
114 	const struct smbdirect_socket_parameters *sp = &sc->parameters;
115 	enum dma_data_direction direction = is_read ? DMA_FROM_DEVICE : DMA_TO_DEVICE;
116 	struct smbdirect_rw_io *msg, *next_msg;
117 	size_t i;
118 	int ret;
119 	DECLARE_COMPLETION_ONSTACK(completion);
120 	struct ib_send_wr *first_wr;
121 	LIST_HEAD(msg_list);
122 	u8 *desc_buf;
123 	int credits_needed;
124 	size_t desc_buf_len, desc_num = 0;
125 
126 	if (sc->status != SMBDIRECT_SOCKET_CONNECTED)
127 		return -ENOTCONN;
128 
129 	if (buf_len > sp->max_read_write_size)
130 		return -EINVAL;
131 
132 	/* calculate needed credits */
133 	credits_needed = 0;
134 	desc_buf = buf;
135 	for (i = 0; i < desc_len / sizeof(*desc); i++) {
136 		if (!buf_len)
137 			break;
138 
139 		desc_buf_len = le32_to_cpu(desc[i].length);
140 		if (!desc_buf_len)
141 			return -EINVAL;
142 
143 		if (desc_buf_len > buf_len) {
144 			desc_buf_len = buf_len;
145 			desc[i].length = cpu_to_le32(desc_buf_len);
146 			buf_len = 0;
147 		}
148 
149 		credits_needed += smbdirect_connection_calc_rw_credits(sc,
150 								       desc_buf,
151 								       desc_buf_len);
152 		desc_buf += desc_buf_len;
153 		buf_len -= desc_buf_len;
154 		desc_num++;
155 	}
156 
157 	smbdirect_log_rdma_rw(sc, SMBDIRECT_LOG_INFO,
158 		"RDMA %s, len %zu, needed credits %d\n",
159 		str_read_write(is_read), buf_len, credits_needed);
160 
161 	ret = smbdirect_connection_wait_for_rw_credits(sc, credits_needed);
162 	if (ret < 0)
163 		return ret;
164 
165 	/* build rdma_rw_ctx for each descriptor */
166 	desc_buf = buf;
167 	for (i = 0; i < desc_num; i++) {
168 		size_t page_count;
169 
170 		msg = kzalloc_flex(*msg, sg_list, SG_CHUNK_SIZE,
171 				   sc->rw_io.mem.gfp_mask);
172 		if (!msg) {
173 			ret = -ENOMEM;
174 			goto out;
175 		}
176 
177 		desc_buf_len = le32_to_cpu(desc[i].length);
178 		page_count = smbdirect_get_buf_page_count(desc_buf, desc_buf_len);
179 
180 		msg->socket = sc;
181 		msg->cqe.done = is_read ?
182 			smbdirect_connection_rdma_read_done :
183 			smbdirect_connection_rdma_write_done;
184 		msg->completion = &completion;
185 
186 		msg->sgt.sgl = &msg->sg_list[0];
187 		ret = sg_alloc_table_chained(&msg->sgt,
188 					     page_count,
189 					     msg->sg_list,
190 					     SG_CHUNK_SIZE);
191 		if (ret) {
192 			ret = -ENOMEM;
193 			goto free_msg;
194 		}
195 
196 		ret = smbdirect_connection_rdma_get_sg_list(desc_buf,
197 							    desc_buf_len,
198 							    msg->sgt.sgl,
199 							    msg->sgt.orig_nents);
200 		if (ret < 0)
201 			goto free_table;
202 
203 		ret = rdma_rw_ctx_init(&msg->rdma_ctx,
204 				       sc->ib.qp,
205 				       sc->ib.qp->port,
206 				       msg->sgt.sgl,
207 				       page_count,
208 				       0,
209 				       le64_to_cpu(desc[i].offset),
210 				       le32_to_cpu(desc[i].token),
211 				       direction);
212 		if (ret < 0) {
213 			pr_err("failed to init rdma_rw_ctx: %d\n", ret);
214 			goto free_table;
215 		}
216 
217 		list_add_tail(&msg->list, &msg_list);
218 		desc_buf += desc_buf_len;
219 	}
220 
221 	/* concatenate work requests of rdma_rw_ctxs */
222 	first_wr = NULL;
223 	list_for_each_entry_reverse(msg, &msg_list, list) {
224 		first_wr = rdma_rw_ctx_wrs(&msg->rdma_ctx,
225 					   sc->ib.qp,
226 					   sc->ib.qp->port,
227 					   &msg->cqe,
228 					   first_wr);
229 	}
230 
231 	ret = ib_post_send(sc->ib.qp, first_wr, NULL);
232 	if (ret) {
233 		pr_err("failed to post send wr for RDMA R/W: %d\n", ret);
234 		goto out;
235 	}
236 
237 	msg = list_last_entry(&msg_list, struct smbdirect_rw_io, list);
238 	wait_for_completion(&completion);
239 	ret = msg->error;
240 out:
241 	list_for_each_entry_safe(msg, next_msg, &msg_list, list) {
242 		list_del(&msg->list);
243 		smbdirect_connection_rw_io_free(msg, direction);
244 	}
245 	atomic_add(credits_needed, &sc->rw_io.credits.count);
246 	wake_up(&sc->rw_io.credits.wait_queue);
247 	return ret;
248 
249 free_table:
250 	sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE);
251 free_msg:
252 	kfree(msg);
253 	goto out;
254 }
255 __SMBDIRECT_EXPORT_SYMBOL__(smbdirect_connection_rdma_xmit);
256