1 // SPDX-License-Identifier: GPL-2.0-or-later 2 /* 3 * Copyright (C) 2016 Namjae Jeon <linkinjeon@kernel.org> 4 * Copyright (C) 2018 Samsung Electronics Co., Ltd. 5 */ 6 7 #include <linux/freezer.h> 8 9 #include "smb_common.h" 10 #include "server.h" 11 #include "auth.h" 12 #include "connection.h" 13 #include "transport_tcp.h" 14 15 #define IFACE_STATE_DOWN BIT(0) 16 #define IFACE_STATE_CONFIGURED BIT(1) 17 18 static atomic_t active_num_conn; 19 20 struct interface { 21 struct task_struct *ksmbd_kthread; 22 struct socket *ksmbd_socket; 23 struct list_head entry; 24 char *name; 25 struct mutex sock_release_lock; 26 int state; 27 }; 28 29 static LIST_HEAD(iface_list); 30 31 static int bind_additional_ifaces; 32 33 struct tcp_transport { 34 struct ksmbd_transport transport; 35 struct socket *sock; 36 struct kvec *iov; 37 unsigned int nr_iov; 38 }; 39 40 static const struct ksmbd_transport_ops ksmbd_tcp_transport_ops; 41 42 static void tcp_stop_kthread(struct task_struct *kthread); 43 static struct interface *alloc_iface(char *ifname); 44 45 #define KSMBD_TRANS(t) (&(t)->transport) 46 #define TCP_TRANS(t) ((struct tcp_transport *)container_of(t, \ 47 struct tcp_transport, transport)) 48 49 static inline void ksmbd_tcp_nodelay(struct socket *sock) 50 { 51 tcp_sock_set_nodelay(sock->sk); 52 } 53 54 static inline void ksmbd_tcp_reuseaddr(struct socket *sock) 55 { 56 sock_set_reuseaddr(sock->sk); 57 } 58 59 static inline void ksmbd_tcp_rcv_timeout(struct socket *sock, s64 secs) 60 { 61 if (secs && secs < MAX_SCHEDULE_TIMEOUT / HZ - 1) 62 WRITE_ONCE(sock->sk->sk_rcvtimeo, secs * HZ); 63 else 64 WRITE_ONCE(sock->sk->sk_rcvtimeo, MAX_SCHEDULE_TIMEOUT); 65 } 66 67 static inline void ksmbd_tcp_snd_timeout(struct socket *sock, s64 secs) 68 { 69 sock_set_sndtimeo(sock->sk, secs); 70 } 71 72 static struct tcp_transport *alloc_transport(struct socket *client_sk) 73 { 74 struct tcp_transport *t; 75 struct ksmbd_conn *conn; 76 77 t = kzalloc(sizeof(*t), KSMBD_DEFAULT_GFP); 78 if (!t) 79 return NULL; 80 t->sock = client_sk; 81 82 conn = ksmbd_conn_alloc(); 83 if (!conn) { 84 kfree(t); 85 return NULL; 86 } 87 88 conn->transport = KSMBD_TRANS(t); 89 KSMBD_TRANS(t)->conn = conn; 90 KSMBD_TRANS(t)->ops = &ksmbd_tcp_transport_ops; 91 return t; 92 } 93 94 static void ksmbd_tcp_free_transport(struct ksmbd_transport *kt) 95 { 96 struct tcp_transport *t = TCP_TRANS(kt); 97 98 sock_release(t->sock); 99 kfree(t->iov); 100 kfree(t); 101 } 102 103 static void free_transport(struct tcp_transport *t) 104 { 105 kernel_sock_shutdown(t->sock, SHUT_RDWR); 106 ksmbd_conn_free(KSMBD_TRANS(t)->conn); 107 } 108 109 /** 110 * kvec_array_init() - initialize a IO vector segment 111 * @new: IO vector to be initialized 112 * @iov: base IO vector 113 * @nr_segs: number of segments in base iov 114 * @bytes: total iovec length so far for read 115 * 116 * Return: Number of IO segments 117 */ 118 static unsigned int kvec_array_init(struct kvec *new, struct kvec *iov, 119 unsigned int nr_segs, size_t bytes) 120 { 121 size_t base = 0; 122 123 while (bytes || !iov->iov_len) { 124 int copy = min(bytes, iov->iov_len); 125 126 bytes -= copy; 127 base += copy; 128 if (iov->iov_len == base) { 129 iov++; 130 nr_segs--; 131 base = 0; 132 } 133 } 134 135 memcpy(new, iov, sizeof(*iov) * nr_segs); 136 new->iov_base += base; 137 new->iov_len -= base; 138 return nr_segs; 139 } 140 141 /** 142 * get_conn_iovec() - get connection iovec for reading from socket 143 * @t: TCP transport instance 144 * @nr_segs: number of segments in iov 145 * 146 * Return: return existing or newly allocate iovec 147 */ 148 static struct kvec *get_conn_iovec(struct tcp_transport *t, unsigned int nr_segs) 149 { 150 struct kvec *new_iov; 151 152 if (t->iov && nr_segs <= t->nr_iov) 153 return t->iov; 154 155 /* not big enough -- allocate a new one and release the old */ 156 new_iov = kmalloc_array(nr_segs, sizeof(*new_iov), KSMBD_DEFAULT_GFP); 157 if (new_iov) { 158 kfree(t->iov); 159 t->iov = new_iov; 160 t->nr_iov = nr_segs; 161 } 162 return new_iov; 163 } 164 165 static unsigned short ksmbd_tcp_get_port(const struct sockaddr *sa) 166 { 167 switch (sa->sa_family) { 168 case AF_INET: 169 return ntohs(((struct sockaddr_in *)sa)->sin_port); 170 case AF_INET6: 171 return ntohs(((struct sockaddr_in6 *)sa)->sin6_port); 172 } 173 return 0; 174 } 175 176 /** 177 * ksmbd_tcp_new_connection() - create a new tcp session on mount 178 * @client_sk: socket associated with new connection 179 * 180 * whenever a new connection is requested, create a conn thread 181 * (session thread) to handle new incoming smb requests from the connection 182 * 183 * Return: 0 on success, otherwise error 184 */ 185 static int ksmbd_tcp_new_connection(struct socket *client_sk) 186 { 187 struct sockaddr *csin; 188 int rc = 0; 189 struct tcp_transport *t; 190 struct task_struct *handler; 191 192 t = alloc_transport(client_sk); 193 if (!t) { 194 sock_release(client_sk); 195 return -ENOMEM; 196 } 197 198 csin = KSMBD_TCP_PEER_SOCKADDR(KSMBD_TRANS(t)->conn); 199 if (kernel_getpeername(client_sk, csin) < 0) { 200 pr_err("client ip resolution failed\n"); 201 rc = -EINVAL; 202 goto out_error; 203 } 204 205 handler = kthread_run(ksmbd_conn_handler_loop, 206 KSMBD_TRANS(t)->conn, 207 "ksmbd:%u", 208 ksmbd_tcp_get_port(csin)); 209 if (IS_ERR(handler)) { 210 pr_err("cannot start conn thread\n"); 211 rc = PTR_ERR(handler); 212 free_transport(t); 213 } 214 return rc; 215 216 out_error: 217 free_transport(t); 218 return rc; 219 } 220 221 /** 222 * ksmbd_kthread_fn() - listen to new SMB connections and callback server 223 * @p: arguments to forker thread 224 * 225 * Return: 0 on success, error number otherwise 226 */ 227 static int ksmbd_kthread_fn(void *p) 228 { 229 struct socket *client_sk = NULL; 230 struct interface *iface = (struct interface *)p; 231 int ret; 232 233 while (!kthread_should_stop()) { 234 mutex_lock(&iface->sock_release_lock); 235 if (!iface->ksmbd_socket) { 236 mutex_unlock(&iface->sock_release_lock); 237 break; 238 } 239 ret = kernel_accept(iface->ksmbd_socket, &client_sk, 240 SOCK_NONBLOCK); 241 mutex_unlock(&iface->sock_release_lock); 242 if (ret) { 243 if (ret == -EAGAIN) 244 /* check for new connections every 100 msecs */ 245 schedule_timeout_interruptible(HZ / 10); 246 continue; 247 } 248 249 if (server_conf.max_connections && 250 atomic_inc_return(&active_num_conn) >= server_conf.max_connections) { 251 pr_info_ratelimited("Limit the maximum number of connections(%u)\n", 252 atomic_read(&active_num_conn)); 253 atomic_dec(&active_num_conn); 254 sock_release(client_sk); 255 continue; 256 } 257 258 ksmbd_debug(CONN, "connect success: accepted new connection\n"); 259 client_sk->sk->sk_rcvtimeo = KSMBD_TCP_RECV_TIMEOUT; 260 client_sk->sk->sk_sndtimeo = KSMBD_TCP_SEND_TIMEOUT; 261 262 ksmbd_tcp_new_connection(client_sk); 263 } 264 265 ksmbd_debug(CONN, "releasing socket\n"); 266 return 0; 267 } 268 269 /** 270 * ksmbd_tcp_run_kthread() - start forker thread 271 * @iface: pointer to struct interface 272 * 273 * start forker thread(ksmbd/0) at module init time to listen 274 * on port 445 for new SMB connection requests. It creates per connection 275 * server threads(ksmbd/x) 276 * 277 * Return: 0 on success or error number 278 */ 279 static int ksmbd_tcp_run_kthread(struct interface *iface) 280 { 281 int rc; 282 struct task_struct *kthread; 283 284 kthread = kthread_run(ksmbd_kthread_fn, (void *)iface, "ksmbd-%s", 285 iface->name); 286 if (IS_ERR(kthread)) { 287 rc = PTR_ERR(kthread); 288 return rc; 289 } 290 iface->ksmbd_kthread = kthread; 291 292 return 0; 293 } 294 295 /** 296 * ksmbd_tcp_readv() - read data from socket in given iovec 297 * @t: TCP transport instance 298 * @iov_orig: base IO vector 299 * @nr_segs: number of segments in base iov 300 * @to_read: number of bytes to read from socket 301 * @max_retries: maximum retry count 302 * 303 * Return: on success return number of bytes read from socket, 304 * otherwise return error number 305 */ 306 static int ksmbd_tcp_readv(struct tcp_transport *t, struct kvec *iov_orig, 307 unsigned int nr_segs, unsigned int to_read, 308 int max_retries) 309 { 310 int length = 0; 311 int total_read; 312 unsigned int segs; 313 struct msghdr ksmbd_msg; 314 struct kvec *iov; 315 struct ksmbd_conn *conn = KSMBD_TRANS(t)->conn; 316 317 iov = get_conn_iovec(t, nr_segs); 318 if (!iov) 319 return -ENOMEM; 320 321 ksmbd_msg.msg_control = NULL; 322 ksmbd_msg.msg_controllen = 0; 323 324 for (total_read = 0; to_read; total_read += length, to_read -= length) { 325 try_to_freeze(); 326 327 if (!ksmbd_conn_alive(conn)) { 328 total_read = -ESHUTDOWN; 329 break; 330 } 331 segs = kvec_array_init(iov, iov_orig, nr_segs, total_read); 332 333 length = kernel_recvmsg(t->sock, &ksmbd_msg, 334 iov, segs, to_read, 0); 335 336 if (length == -EINTR) { 337 total_read = -ESHUTDOWN; 338 break; 339 } else if (ksmbd_conn_need_reconnect(conn)) { 340 total_read = -EAGAIN; 341 break; 342 } else if (length == -ERESTARTSYS || length == -EAGAIN) { 343 /* 344 * If max_retries is negative, Allow unlimited 345 * retries to keep connection with inactive sessions. 346 */ 347 if (max_retries == 0) { 348 total_read = length; 349 break; 350 } else if (max_retries > 0) { 351 max_retries--; 352 } 353 354 usleep_range(1000, 2000); 355 length = 0; 356 continue; 357 } else if (length <= 0) { 358 total_read = length; 359 break; 360 } 361 } 362 return total_read; 363 } 364 365 /** 366 * ksmbd_tcp_read() - read data from socket in given buffer 367 * @t: TCP transport instance 368 * @buf: buffer to store read data from socket 369 * @to_read: number of bytes to read from socket 370 * @max_retries: number of retries if reading from socket fails 371 * 372 * Return: on success return number of bytes read from socket, 373 * otherwise return error number 374 */ 375 static int ksmbd_tcp_read(struct ksmbd_transport *t, char *buf, 376 unsigned int to_read, int max_retries) 377 { 378 struct kvec iov; 379 380 iov.iov_base = buf; 381 iov.iov_len = to_read; 382 383 return ksmbd_tcp_readv(TCP_TRANS(t), &iov, 1, to_read, max_retries); 384 } 385 386 static int ksmbd_tcp_writev(struct ksmbd_transport *t, struct kvec *iov, 387 int nvecs, int size, bool need_invalidate, 388 unsigned int remote_key) 389 390 { 391 struct msghdr smb_msg = {.msg_flags = MSG_NOSIGNAL}; 392 393 return kernel_sendmsg(TCP_TRANS(t)->sock, &smb_msg, iov, nvecs, size); 394 } 395 396 static void ksmbd_tcp_disconnect(struct ksmbd_transport *t) 397 { 398 free_transport(TCP_TRANS(t)); 399 if (server_conf.max_connections) 400 atomic_dec(&active_num_conn); 401 } 402 403 static void tcp_destroy_socket(struct socket *ksmbd_socket) 404 { 405 int ret; 406 407 if (!ksmbd_socket) 408 return; 409 410 /* set zero to timeout */ 411 ksmbd_tcp_rcv_timeout(ksmbd_socket, 0); 412 ksmbd_tcp_snd_timeout(ksmbd_socket, 0); 413 414 ret = kernel_sock_shutdown(ksmbd_socket, SHUT_RDWR); 415 if (ret) 416 pr_err("Failed to shutdown socket: %d\n", ret); 417 sock_release(ksmbd_socket); 418 } 419 420 /** 421 * create_socket - create socket for ksmbd/0 422 * @iface: interface to bind the created socket to 423 * 424 * Return: 0 on success, error number otherwise 425 */ 426 static int create_socket(struct interface *iface) 427 { 428 int ret; 429 struct sockaddr_in6 sin6; 430 struct sockaddr_in sin; 431 struct socket *ksmbd_socket; 432 bool ipv4 = false; 433 434 ret = sock_create(PF_INET6, SOCK_STREAM, IPPROTO_TCP, &ksmbd_socket); 435 if (ret) { 436 if (ret != -EAFNOSUPPORT) 437 pr_err("Can't create socket for ipv6, fallback to ipv4: %d\n", ret); 438 ret = sock_create(PF_INET, SOCK_STREAM, IPPROTO_TCP, 439 &ksmbd_socket); 440 if (ret) { 441 pr_err("Can't create socket for ipv4: %d\n", ret); 442 goto out_clear; 443 } 444 445 sin.sin_family = PF_INET; 446 sin.sin_addr.s_addr = htonl(INADDR_ANY); 447 sin.sin_port = htons(server_conf.tcp_port); 448 ipv4 = true; 449 } else { 450 sin6.sin6_family = PF_INET6; 451 sin6.sin6_addr = in6addr_any; 452 sin6.sin6_port = htons(server_conf.tcp_port); 453 454 lock_sock(ksmbd_socket->sk); 455 ksmbd_socket->sk->sk_ipv6only = false; 456 release_sock(ksmbd_socket->sk); 457 } 458 459 ksmbd_tcp_nodelay(ksmbd_socket); 460 ksmbd_tcp_reuseaddr(ksmbd_socket); 461 462 ret = sock_setsockopt(ksmbd_socket, 463 SOL_SOCKET, 464 SO_BINDTODEVICE, 465 KERNEL_SOCKPTR(iface->name), 466 strlen(iface->name)); 467 if (ret != -ENODEV && ret < 0) { 468 pr_err("Failed to set SO_BINDTODEVICE: %d\n", ret); 469 goto out_error; 470 } 471 472 if (ipv4) 473 ret = kernel_bind(ksmbd_socket, (struct sockaddr *)&sin, 474 sizeof(sin)); 475 else 476 ret = kernel_bind(ksmbd_socket, (struct sockaddr *)&sin6, 477 sizeof(sin6)); 478 if (ret) { 479 pr_err("Failed to bind socket: %d\n", ret); 480 goto out_error; 481 } 482 483 ksmbd_socket->sk->sk_rcvtimeo = KSMBD_TCP_RECV_TIMEOUT; 484 ksmbd_socket->sk->sk_sndtimeo = KSMBD_TCP_SEND_TIMEOUT; 485 486 ret = kernel_listen(ksmbd_socket, KSMBD_SOCKET_BACKLOG); 487 if (ret) { 488 pr_err("Port listen() error: %d\n", ret); 489 goto out_error; 490 } 491 492 iface->ksmbd_socket = ksmbd_socket; 493 ret = ksmbd_tcp_run_kthread(iface); 494 if (ret) { 495 pr_err("Can't start ksmbd main kthread: %d\n", ret); 496 goto out_error; 497 } 498 iface->state = IFACE_STATE_CONFIGURED; 499 500 return 0; 501 502 out_error: 503 tcp_destroy_socket(ksmbd_socket); 504 out_clear: 505 iface->ksmbd_socket = NULL; 506 return ret; 507 } 508 509 struct interface *ksmbd_find_netdev_name_iface_list(char *netdev_name) 510 { 511 struct interface *iface; 512 513 list_for_each_entry(iface, &iface_list, entry) 514 if (!strcmp(iface->name, netdev_name)) 515 return iface; 516 return NULL; 517 } 518 519 static int ksmbd_netdev_event(struct notifier_block *nb, unsigned long event, 520 void *ptr) 521 { 522 struct net_device *netdev = netdev_notifier_info_to_dev(ptr); 523 struct interface *iface; 524 int ret; 525 526 switch (event) { 527 case NETDEV_UP: 528 if (netif_is_bridge_port(netdev)) 529 return NOTIFY_OK; 530 531 iface = ksmbd_find_netdev_name_iface_list(netdev->name); 532 if (iface && iface->state == IFACE_STATE_DOWN) { 533 ksmbd_debug(CONN, "netdev-up event: netdev(%s) is going up\n", 534 iface->name); 535 ret = create_socket(iface); 536 if (ret) 537 return NOTIFY_OK; 538 } 539 if (!iface && bind_additional_ifaces) { 540 iface = alloc_iface(kstrdup(netdev->name, KSMBD_DEFAULT_GFP)); 541 if (!iface) 542 return NOTIFY_OK; 543 ksmbd_debug(CONN, "netdev-up event: netdev(%s) is going up\n", 544 iface->name); 545 ret = create_socket(iface); 546 if (ret) 547 break; 548 } 549 break; 550 case NETDEV_DOWN: 551 iface = ksmbd_find_netdev_name_iface_list(netdev->name); 552 if (iface && iface->state == IFACE_STATE_CONFIGURED) { 553 ksmbd_debug(CONN, "netdev-down event: netdev(%s) is going down\n", 554 iface->name); 555 tcp_stop_kthread(iface->ksmbd_kthread); 556 iface->ksmbd_kthread = NULL; 557 mutex_lock(&iface->sock_release_lock); 558 tcp_destroy_socket(iface->ksmbd_socket); 559 iface->ksmbd_socket = NULL; 560 mutex_unlock(&iface->sock_release_lock); 561 562 iface->state = IFACE_STATE_DOWN; 563 break; 564 } 565 break; 566 } 567 568 return NOTIFY_DONE; 569 } 570 571 static struct notifier_block ksmbd_netdev_notifier = { 572 .notifier_call = ksmbd_netdev_event, 573 }; 574 575 int ksmbd_tcp_init(void) 576 { 577 register_netdevice_notifier(&ksmbd_netdev_notifier); 578 579 return 0; 580 } 581 582 static void tcp_stop_kthread(struct task_struct *kthread) 583 { 584 int ret; 585 586 if (!kthread) 587 return; 588 589 ret = kthread_stop(kthread); 590 if (ret) 591 pr_err("failed to stop forker thread\n"); 592 } 593 594 void ksmbd_tcp_destroy(void) 595 { 596 struct interface *iface, *tmp; 597 598 unregister_netdevice_notifier(&ksmbd_netdev_notifier); 599 600 list_for_each_entry_safe(iface, tmp, &iface_list, entry) { 601 list_del(&iface->entry); 602 kfree(iface->name); 603 kfree(iface); 604 } 605 } 606 607 static struct interface *alloc_iface(char *ifname) 608 { 609 struct interface *iface; 610 611 if (!ifname) 612 return NULL; 613 614 iface = kzalloc(sizeof(struct interface), KSMBD_DEFAULT_GFP); 615 if (!iface) { 616 kfree(ifname); 617 return NULL; 618 } 619 620 iface->name = ifname; 621 iface->state = IFACE_STATE_DOWN; 622 list_add(&iface->entry, &iface_list); 623 mutex_init(&iface->sock_release_lock); 624 return iface; 625 } 626 627 int ksmbd_tcp_set_interfaces(char *ifc_list, int ifc_list_sz) 628 { 629 int sz = 0; 630 631 if (!ifc_list_sz) { 632 bind_additional_ifaces = 1; 633 return 0; 634 } 635 636 while (ifc_list_sz > 0) { 637 if (!alloc_iface(kstrdup(ifc_list, KSMBD_DEFAULT_GFP))) 638 return -ENOMEM; 639 640 sz = strlen(ifc_list); 641 if (!sz) 642 break; 643 644 ifc_list += sz + 1; 645 ifc_list_sz -= (sz + 1); 646 } 647 648 bind_additional_ifaces = 0; 649 650 return 0; 651 } 652 653 static const struct ksmbd_transport_ops ksmbd_tcp_transport_ops = { 654 .read = ksmbd_tcp_read, 655 .writev = ksmbd_tcp_writev, 656 .disconnect = ksmbd_tcp_disconnect, 657 .free_transport = ksmbd_tcp_free_transport, 658 }; 659