1 // SPDX-License-Identifier: GPL-2.0-or-later 2 /* 3 * Copyright (C) 2017, Microsoft Corporation. 4 * Copyright (C) 2018, LG Electronics. 5 * Copyright (c) 2025, Stefan Metzmacher 6 */ 7 8 #include "internal.h" 9 10 static int smbdirect_connection_wait_for_rw_credits(struct smbdirect_socket *sc, 11 int credits) 12 { 13 return smbdirect_socket_wait_for_credits(sc, 14 SMBDIRECT_SOCKET_CONNECTED, 15 -ENOTCONN, 16 &sc->rw_io.credits.wait_queue, 17 &sc->rw_io.credits.count, 18 credits); 19 } 20 21 static int smbdirect_connection_calc_rw_credits(struct smbdirect_socket *sc, 22 const void *buf, 23 size_t len) 24 { 25 return DIV_ROUND_UP(smbdirect_get_buf_page_count(buf, len), 26 sc->rw_io.credits.num_pages); 27 } 28 29 static int smbdirect_connection_rdma_get_sg_list(void *buf, 30 size_t size, 31 struct scatterlist *sg_list, 32 size_t nentries) 33 { 34 bool high = is_vmalloc_addr(buf); 35 struct page *page; 36 size_t offset, len; 37 int i = 0; 38 39 if (size == 0 || nentries < smbdirect_get_buf_page_count(buf, size)) 40 return -EINVAL; 41 42 offset = offset_in_page(buf); 43 buf -= offset; 44 while (size > 0) { 45 len = min_t(size_t, PAGE_SIZE - offset, size); 46 if (high) 47 page = vmalloc_to_page(buf); 48 else 49 page = kmap_to_page(buf); 50 51 if (!sg_list) 52 return -EINVAL; 53 sg_set_page(sg_list, page, len, offset); 54 sg_list = sg_next(sg_list); 55 56 buf += PAGE_SIZE; 57 size -= len; 58 offset = 0; 59 i++; 60 } 61 62 return i; 63 } 64 65 static void smbdirect_connection_rw_io_free(struct smbdirect_rw_io *msg, 66 enum dma_data_direction dir) 67 { 68 struct smbdirect_socket *sc = msg->socket; 69 70 rdma_rw_ctx_destroy(&msg->rdma_ctx, 71 sc->ib.qp, 72 sc->ib.qp->port, 73 msg->sgt.sgl, 74 msg->sgt.nents, 75 dir); 76 sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE); 77 kfree(msg); 78 } 79 80 static void smbdirect_connection_rdma_rw_done(struct ib_cq *cq, struct ib_wc *wc, 81 enum dma_data_direction dir) 82 { 83 struct smbdirect_rw_io *msg = 84 container_of(wc->wr_cqe, struct smbdirect_rw_io, cqe); 85 struct smbdirect_socket *sc = msg->socket; 86 87 if (wc->status != IB_WC_SUCCESS) { 88 msg->error = -EIO; 89 pr_err("read/write error. opcode = %d, status = %s(%d)\n", 90 wc->opcode, ib_wc_status_msg(wc->status), wc->status); 91 if (wc->status != IB_WC_WR_FLUSH_ERR) 92 smbdirect_socket_schedule_cleanup(sc, msg->error); 93 } 94 95 complete(msg->completion); 96 } 97 98 static void smbdirect_connection_rdma_read_done(struct ib_cq *cq, struct ib_wc *wc) 99 { 100 smbdirect_connection_rdma_rw_done(cq, wc, DMA_FROM_DEVICE); 101 } 102 103 static void smbdirect_connection_rdma_write_done(struct ib_cq *cq, struct ib_wc *wc) 104 { 105 smbdirect_connection_rdma_rw_done(cq, wc, DMA_TO_DEVICE); 106 } 107 108 int smbdirect_connection_rdma_xmit(struct smbdirect_socket *sc, 109 void *buf, size_t buf_len, 110 struct smbdirect_buffer_descriptor_v1 *desc, 111 size_t desc_len, 112 bool is_read) 113 { 114 const struct smbdirect_socket_parameters *sp = &sc->parameters; 115 enum dma_data_direction direction = is_read ? DMA_FROM_DEVICE : DMA_TO_DEVICE; 116 struct smbdirect_rw_io *msg, *next_msg; 117 size_t i; 118 int ret; 119 DECLARE_COMPLETION_ONSTACK(completion); 120 struct ib_send_wr *first_wr; 121 LIST_HEAD(msg_list); 122 u8 *desc_buf; 123 int credits_needed; 124 size_t desc_buf_len, desc_num = 0; 125 126 if (sc->status != SMBDIRECT_SOCKET_CONNECTED) 127 return -ENOTCONN; 128 129 if (buf_len > sp->max_read_write_size) 130 return -EINVAL; 131 132 /* calculate needed credits */ 133 credits_needed = 0; 134 desc_buf = buf; 135 for (i = 0; i < desc_len / sizeof(*desc); i++) { 136 if (!buf_len) 137 break; 138 139 desc_buf_len = le32_to_cpu(desc[i].length); 140 if (!desc_buf_len) 141 return -EINVAL; 142 143 if (desc_buf_len > buf_len) { 144 desc_buf_len = buf_len; 145 desc[i].length = cpu_to_le32(desc_buf_len); 146 buf_len = 0; 147 } 148 149 credits_needed += smbdirect_connection_calc_rw_credits(sc, 150 desc_buf, 151 desc_buf_len); 152 desc_buf += desc_buf_len; 153 buf_len -= desc_buf_len; 154 desc_num++; 155 } 156 157 smbdirect_log_rdma_rw(sc, SMBDIRECT_LOG_INFO, 158 "RDMA %s, len %zu, needed credits %d\n", 159 str_read_write(is_read), buf_len, credits_needed); 160 161 ret = smbdirect_connection_wait_for_rw_credits(sc, credits_needed); 162 if (ret < 0) 163 return ret; 164 165 /* build rdma_rw_ctx for each descriptor */ 166 desc_buf = buf; 167 for (i = 0; i < desc_num; i++) { 168 size_t page_count; 169 170 msg = kzalloc_flex(*msg, sg_list, SG_CHUNK_SIZE, 171 sc->rw_io.mem.gfp_mask); 172 if (!msg) { 173 ret = -ENOMEM; 174 goto out; 175 } 176 177 desc_buf_len = le32_to_cpu(desc[i].length); 178 page_count = smbdirect_get_buf_page_count(desc_buf, desc_buf_len); 179 180 msg->socket = sc; 181 msg->cqe.done = is_read ? 182 smbdirect_connection_rdma_read_done : 183 smbdirect_connection_rdma_write_done; 184 msg->completion = &completion; 185 186 msg->sgt.sgl = &msg->sg_list[0]; 187 ret = sg_alloc_table_chained(&msg->sgt, 188 page_count, 189 msg->sg_list, 190 SG_CHUNK_SIZE); 191 if (ret) { 192 ret = -ENOMEM; 193 goto free_msg; 194 } 195 196 ret = smbdirect_connection_rdma_get_sg_list(desc_buf, 197 desc_buf_len, 198 msg->sgt.sgl, 199 msg->sgt.orig_nents); 200 if (ret < 0) 201 goto free_table; 202 203 ret = rdma_rw_ctx_init(&msg->rdma_ctx, 204 sc->ib.qp, 205 sc->ib.qp->port, 206 msg->sgt.sgl, 207 page_count, 208 0, 209 le64_to_cpu(desc[i].offset), 210 le32_to_cpu(desc[i].token), 211 direction); 212 if (ret < 0) { 213 pr_err("failed to init rdma_rw_ctx: %d\n", ret); 214 goto free_table; 215 } 216 217 list_add_tail(&msg->list, &msg_list); 218 desc_buf += desc_buf_len; 219 } 220 221 /* concatenate work requests of rdma_rw_ctxs */ 222 first_wr = NULL; 223 list_for_each_entry_reverse(msg, &msg_list, list) { 224 first_wr = rdma_rw_ctx_wrs(&msg->rdma_ctx, 225 sc->ib.qp, 226 sc->ib.qp->port, 227 &msg->cqe, 228 first_wr); 229 } 230 231 ret = ib_post_send(sc->ib.qp, first_wr, NULL); 232 if (ret) { 233 pr_err("failed to post send wr for RDMA R/W: %d\n", ret); 234 goto out; 235 } 236 237 msg = list_last_entry(&msg_list, struct smbdirect_rw_io, list); 238 wait_for_completion(&completion); 239 ret = msg->error; 240 out: 241 list_for_each_entry_safe(msg, next_msg, &msg_list, list) { 242 list_del(&msg->list); 243 smbdirect_connection_rw_io_free(msg, direction); 244 } 245 atomic_add(credits_needed, &sc->rw_io.credits.count); 246 wake_up(&sc->rw_io.credits.wait_queue); 247 return ret; 248 249 free_table: 250 sg_free_table_chained(&msg->sgt, SG_CHUNK_SIZE); 251 free_msg: 252 kfree(msg); 253 goto out; 254 } 255 __SMBDIRECT_EXPORT_SYMBOL__(smbdirect_connection_rdma_xmit); 256