1 // SPDX-License-Identifier: GPL-2.0-or-later 2 /* 3 * Copyright (C) 2016 Namjae Jeon <linkinjeon@kernel.org> 4 * Copyright (C) 2018 Samsung Electronics Co., Ltd. 5 */ 6 7 #include <linux/freezer.h> 8 9 #include "smb_common.h" 10 #include "server.h" 11 #include "auth.h" 12 #include "connection.h" 13 #include "transport_tcp.h" 14 15 #define IFACE_STATE_DOWN BIT(0) 16 #define IFACE_STATE_CONFIGURED BIT(1) 17 18 static atomic_t active_num_conn; 19 20 struct interface { 21 struct task_struct *ksmbd_kthread; 22 struct socket *ksmbd_socket; 23 struct list_head entry; 24 char *name; 25 struct mutex sock_release_lock; 26 int state; 27 }; 28 29 static LIST_HEAD(iface_list); 30 31 static int bind_additional_ifaces; 32 33 struct tcp_transport { 34 struct ksmbd_transport transport; 35 struct socket *sock; 36 struct kvec *iov; 37 unsigned int nr_iov; 38 }; 39 40 static const struct ksmbd_transport_ops ksmbd_tcp_transport_ops; 41 42 static void tcp_stop_kthread(struct task_struct *kthread); 43 static struct interface *alloc_iface(char *ifname); 44 45 #define KSMBD_TRANS(t) (&(t)->transport) 46 #define TCP_TRANS(t) ((struct tcp_transport *)container_of(t, \ 47 struct tcp_transport, transport)) 48 49 static inline void ksmbd_tcp_nodelay(struct socket *sock) 50 { 51 tcp_sock_set_nodelay(sock->sk); 52 } 53 54 static inline void ksmbd_tcp_reuseaddr(struct socket *sock) 55 { 56 sock_set_reuseaddr(sock->sk); 57 } 58 59 static inline void ksmbd_tcp_rcv_timeout(struct socket *sock, s64 secs) 60 { 61 lock_sock(sock->sk); 62 if (secs && secs < MAX_SCHEDULE_TIMEOUT / HZ - 1) 63 sock->sk->sk_rcvtimeo = secs * HZ; 64 else 65 sock->sk->sk_rcvtimeo = MAX_SCHEDULE_TIMEOUT; 66 release_sock(sock->sk); 67 } 68 69 static inline void ksmbd_tcp_snd_timeout(struct socket *sock, s64 secs) 70 { 71 sock_set_sndtimeo(sock->sk, secs); 72 } 73 74 static struct tcp_transport *alloc_transport(struct socket *client_sk) 75 { 76 struct tcp_transport *t; 77 struct ksmbd_conn *conn; 78 79 t = kzalloc(sizeof(*t), GFP_KERNEL); 80 if (!t) 81 return NULL; 82 t->sock = client_sk; 83 84 conn = ksmbd_conn_alloc(); 85 if (!conn) { 86 kfree(t); 87 return NULL; 88 } 89 90 conn->transport = KSMBD_TRANS(t); 91 KSMBD_TRANS(t)->conn = conn; 92 KSMBD_TRANS(t)->ops = &ksmbd_tcp_transport_ops; 93 return t; 94 } 95 96 static void free_transport(struct tcp_transport *t) 97 { 98 kernel_sock_shutdown(t->sock, SHUT_RDWR); 99 sock_release(t->sock); 100 t->sock = NULL; 101 102 ksmbd_conn_free(KSMBD_TRANS(t)->conn); 103 kfree(t->iov); 104 kfree(t); 105 } 106 107 /** 108 * kvec_array_init() - initialize a IO vector segment 109 * @new: IO vector to be initialized 110 * @iov: base IO vector 111 * @nr_segs: number of segments in base iov 112 * @bytes: total iovec length so far for read 113 * 114 * Return: Number of IO segments 115 */ 116 static unsigned int kvec_array_init(struct kvec *new, struct kvec *iov, 117 unsigned int nr_segs, size_t bytes) 118 { 119 size_t base = 0; 120 121 while (bytes || !iov->iov_len) { 122 int copy = min(bytes, iov->iov_len); 123 124 bytes -= copy; 125 base += copy; 126 if (iov->iov_len == base) { 127 iov++; 128 nr_segs--; 129 base = 0; 130 } 131 } 132 133 memcpy(new, iov, sizeof(*iov) * nr_segs); 134 new->iov_base += base; 135 new->iov_len -= base; 136 return nr_segs; 137 } 138 139 /** 140 * get_conn_iovec() - get connection iovec for reading from socket 141 * @t: TCP transport instance 142 * @nr_segs: number of segments in iov 143 * 144 * Return: return existing or newly allocate iovec 145 */ 146 static struct kvec *get_conn_iovec(struct tcp_transport *t, unsigned int nr_segs) 147 { 148 struct kvec *new_iov; 149 150 if (t->iov && nr_segs <= t->nr_iov) 151 return t->iov; 152 153 /* not big enough -- allocate a new one and release the old */ 154 new_iov = kmalloc_array(nr_segs, sizeof(*new_iov), GFP_KERNEL); 155 if (new_iov) { 156 kfree(t->iov); 157 t->iov = new_iov; 158 t->nr_iov = nr_segs; 159 } 160 return new_iov; 161 } 162 163 static unsigned short ksmbd_tcp_get_port(const struct sockaddr *sa) 164 { 165 switch (sa->sa_family) { 166 case AF_INET: 167 return ntohs(((struct sockaddr_in *)sa)->sin_port); 168 case AF_INET6: 169 return ntohs(((struct sockaddr_in6 *)sa)->sin6_port); 170 } 171 return 0; 172 } 173 174 /** 175 * ksmbd_tcp_new_connection() - create a new tcp session on mount 176 * @client_sk: socket associated with new connection 177 * 178 * whenever a new connection is requested, create a conn thread 179 * (session thread) to handle new incoming smb requests from the connection 180 * 181 * Return: 0 on success, otherwise error 182 */ 183 static int ksmbd_tcp_new_connection(struct socket *client_sk) 184 { 185 struct sockaddr *csin; 186 int rc = 0; 187 struct tcp_transport *t; 188 struct task_struct *handler; 189 190 t = alloc_transport(client_sk); 191 if (!t) { 192 sock_release(client_sk); 193 return -ENOMEM; 194 } 195 196 csin = KSMBD_TCP_PEER_SOCKADDR(KSMBD_TRANS(t)->conn); 197 if (kernel_getpeername(client_sk, csin) < 0) { 198 pr_err("client ip resolution failed\n"); 199 rc = -EINVAL; 200 goto out_error; 201 } 202 203 handler = kthread_run(ksmbd_conn_handler_loop, 204 KSMBD_TRANS(t)->conn, 205 "ksmbd:%u", 206 ksmbd_tcp_get_port(csin)); 207 if (IS_ERR(handler)) { 208 pr_err("cannot start conn thread\n"); 209 rc = PTR_ERR(handler); 210 free_transport(t); 211 } 212 return rc; 213 214 out_error: 215 free_transport(t); 216 return rc; 217 } 218 219 /** 220 * ksmbd_kthread_fn() - listen to new SMB connections and callback server 221 * @p: arguments to forker thread 222 * 223 * Return: 0 on success, error number otherwise 224 */ 225 static int ksmbd_kthread_fn(void *p) 226 { 227 struct socket *client_sk = NULL; 228 struct interface *iface = (struct interface *)p; 229 int ret; 230 231 while (!kthread_should_stop()) { 232 mutex_lock(&iface->sock_release_lock); 233 if (!iface->ksmbd_socket) { 234 mutex_unlock(&iface->sock_release_lock); 235 break; 236 } 237 ret = kernel_accept(iface->ksmbd_socket, &client_sk, 238 SOCK_NONBLOCK); 239 mutex_unlock(&iface->sock_release_lock); 240 if (ret) { 241 if (ret == -EAGAIN) 242 /* check for new connections every 100 msecs */ 243 schedule_timeout_interruptible(HZ / 10); 244 continue; 245 } 246 247 if (server_conf.max_connections && 248 atomic_inc_return(&active_num_conn) >= server_conf.max_connections) { 249 pr_info_ratelimited("Limit the maximum number of connections(%u)\n", 250 atomic_read(&active_num_conn)); 251 atomic_dec(&active_num_conn); 252 sock_release(client_sk); 253 continue; 254 } 255 256 ksmbd_debug(CONN, "connect success: accepted new connection\n"); 257 client_sk->sk->sk_rcvtimeo = KSMBD_TCP_RECV_TIMEOUT; 258 client_sk->sk->sk_sndtimeo = KSMBD_TCP_SEND_TIMEOUT; 259 260 ksmbd_tcp_new_connection(client_sk); 261 } 262 263 ksmbd_debug(CONN, "releasing socket\n"); 264 return 0; 265 } 266 267 /** 268 * ksmbd_tcp_run_kthread() - start forker thread 269 * @iface: pointer to struct interface 270 * 271 * start forker thread(ksmbd/0) at module init time to listen 272 * on port 445 for new SMB connection requests. It creates per connection 273 * server threads(ksmbd/x) 274 * 275 * Return: 0 on success or error number 276 */ 277 static int ksmbd_tcp_run_kthread(struct interface *iface) 278 { 279 int rc; 280 struct task_struct *kthread; 281 282 kthread = kthread_run(ksmbd_kthread_fn, (void *)iface, "ksmbd-%s", 283 iface->name); 284 if (IS_ERR(kthread)) { 285 rc = PTR_ERR(kthread); 286 return rc; 287 } 288 iface->ksmbd_kthread = kthread; 289 290 return 0; 291 } 292 293 /** 294 * ksmbd_tcp_readv() - read data from socket in given iovec 295 * @t: TCP transport instance 296 * @iov_orig: base IO vector 297 * @nr_segs: number of segments in base iov 298 * @to_read: number of bytes to read from socket 299 * @max_retries: maximum retry count 300 * 301 * Return: on success return number of bytes read from socket, 302 * otherwise return error number 303 */ 304 static int ksmbd_tcp_readv(struct tcp_transport *t, struct kvec *iov_orig, 305 unsigned int nr_segs, unsigned int to_read, 306 int max_retries) 307 { 308 int length = 0; 309 int total_read; 310 unsigned int segs; 311 struct msghdr ksmbd_msg; 312 struct kvec *iov; 313 struct ksmbd_conn *conn = KSMBD_TRANS(t)->conn; 314 315 iov = get_conn_iovec(t, nr_segs); 316 if (!iov) 317 return -ENOMEM; 318 319 ksmbd_msg.msg_control = NULL; 320 ksmbd_msg.msg_controllen = 0; 321 322 for (total_read = 0; to_read; total_read += length, to_read -= length) { 323 try_to_freeze(); 324 325 if (!ksmbd_conn_alive(conn)) { 326 total_read = -ESHUTDOWN; 327 break; 328 } 329 segs = kvec_array_init(iov, iov_orig, nr_segs, total_read); 330 331 length = kernel_recvmsg(t->sock, &ksmbd_msg, 332 iov, segs, to_read, 0); 333 334 if (length == -EINTR) { 335 total_read = -ESHUTDOWN; 336 break; 337 } else if (ksmbd_conn_need_reconnect(conn)) { 338 total_read = -EAGAIN; 339 break; 340 } else if (length == -ERESTARTSYS || length == -EAGAIN) { 341 /* 342 * If max_retries is negative, Allow unlimited 343 * retries to keep connection with inactive sessions. 344 */ 345 if (max_retries == 0) { 346 total_read = length; 347 break; 348 } else if (max_retries > 0) { 349 max_retries--; 350 } 351 352 usleep_range(1000, 2000); 353 length = 0; 354 continue; 355 } else if (length <= 0) { 356 total_read = length; 357 break; 358 } 359 } 360 return total_read; 361 } 362 363 /** 364 * ksmbd_tcp_read() - read data from socket in given buffer 365 * @t: TCP transport instance 366 * @buf: buffer to store read data from socket 367 * @to_read: number of bytes to read from socket 368 * @max_retries: number of retries if reading from socket fails 369 * 370 * Return: on success return number of bytes read from socket, 371 * otherwise return error number 372 */ 373 static int ksmbd_tcp_read(struct ksmbd_transport *t, char *buf, 374 unsigned int to_read, int max_retries) 375 { 376 struct kvec iov; 377 378 iov.iov_base = buf; 379 iov.iov_len = to_read; 380 381 return ksmbd_tcp_readv(TCP_TRANS(t), &iov, 1, to_read, max_retries); 382 } 383 384 static int ksmbd_tcp_writev(struct ksmbd_transport *t, struct kvec *iov, 385 int nvecs, int size, bool need_invalidate, 386 unsigned int remote_key) 387 388 { 389 struct msghdr smb_msg = {.msg_flags = MSG_NOSIGNAL}; 390 391 return kernel_sendmsg(TCP_TRANS(t)->sock, &smb_msg, iov, nvecs, size); 392 } 393 394 static void ksmbd_tcp_disconnect(struct ksmbd_transport *t) 395 { 396 free_transport(TCP_TRANS(t)); 397 if (server_conf.max_connections) 398 atomic_dec(&active_num_conn); 399 } 400 401 static void tcp_destroy_socket(struct socket *ksmbd_socket) 402 { 403 int ret; 404 405 if (!ksmbd_socket) 406 return; 407 408 /* set zero to timeout */ 409 ksmbd_tcp_rcv_timeout(ksmbd_socket, 0); 410 ksmbd_tcp_snd_timeout(ksmbd_socket, 0); 411 412 ret = kernel_sock_shutdown(ksmbd_socket, SHUT_RDWR); 413 if (ret) 414 pr_err("Failed to shutdown socket: %d\n", ret); 415 sock_release(ksmbd_socket); 416 } 417 418 /** 419 * create_socket - create socket for ksmbd/0 420 * @iface: interface to bind the created socket to 421 * 422 * Return: 0 on success, error number otherwise 423 */ 424 static int create_socket(struct interface *iface) 425 { 426 int ret; 427 struct sockaddr_in6 sin6; 428 struct sockaddr_in sin; 429 struct socket *ksmbd_socket; 430 bool ipv4 = false; 431 432 ret = sock_create(PF_INET6, SOCK_STREAM, IPPROTO_TCP, &ksmbd_socket); 433 if (ret) { 434 if (ret != -EAFNOSUPPORT) 435 pr_err("Can't create socket for ipv6, fallback to ipv4: %d\n", ret); 436 ret = sock_create(PF_INET, SOCK_STREAM, IPPROTO_TCP, 437 &ksmbd_socket); 438 if (ret) { 439 pr_err("Can't create socket for ipv4: %d\n", ret); 440 goto out_clear; 441 } 442 443 sin.sin_family = PF_INET; 444 sin.sin_addr.s_addr = htonl(INADDR_ANY); 445 sin.sin_port = htons(server_conf.tcp_port); 446 ipv4 = true; 447 } else { 448 sin6.sin6_family = PF_INET6; 449 sin6.sin6_addr = in6addr_any; 450 sin6.sin6_port = htons(server_conf.tcp_port); 451 452 lock_sock(ksmbd_socket->sk); 453 ksmbd_socket->sk->sk_ipv6only = false; 454 release_sock(ksmbd_socket->sk); 455 } 456 457 ksmbd_tcp_nodelay(ksmbd_socket); 458 ksmbd_tcp_reuseaddr(ksmbd_socket); 459 460 ret = sock_setsockopt(ksmbd_socket, 461 SOL_SOCKET, 462 SO_BINDTODEVICE, 463 KERNEL_SOCKPTR(iface->name), 464 strlen(iface->name)); 465 if (ret != -ENODEV && ret < 0) { 466 pr_err("Failed to set SO_BINDTODEVICE: %d\n", ret); 467 goto out_error; 468 } 469 470 if (ipv4) 471 ret = kernel_bind(ksmbd_socket, (struct sockaddr *)&sin, 472 sizeof(sin)); 473 else 474 ret = kernel_bind(ksmbd_socket, (struct sockaddr *)&sin6, 475 sizeof(sin6)); 476 if (ret) { 477 pr_err("Failed to bind socket: %d\n", ret); 478 goto out_error; 479 } 480 481 ksmbd_socket->sk->sk_rcvtimeo = KSMBD_TCP_RECV_TIMEOUT; 482 ksmbd_socket->sk->sk_sndtimeo = KSMBD_TCP_SEND_TIMEOUT; 483 484 ret = kernel_listen(ksmbd_socket, KSMBD_SOCKET_BACKLOG); 485 if (ret) { 486 pr_err("Port listen() error: %d\n", ret); 487 goto out_error; 488 } 489 490 iface->ksmbd_socket = ksmbd_socket; 491 ret = ksmbd_tcp_run_kthread(iface); 492 if (ret) { 493 pr_err("Can't start ksmbd main kthread: %d\n", ret); 494 goto out_error; 495 } 496 iface->state = IFACE_STATE_CONFIGURED; 497 498 return 0; 499 500 out_error: 501 tcp_destroy_socket(ksmbd_socket); 502 out_clear: 503 iface->ksmbd_socket = NULL; 504 return ret; 505 } 506 507 static int ksmbd_netdev_event(struct notifier_block *nb, unsigned long event, 508 void *ptr) 509 { 510 struct net_device *netdev = netdev_notifier_info_to_dev(ptr); 511 struct interface *iface; 512 int ret, found = 0; 513 514 switch (event) { 515 case NETDEV_UP: 516 if (netif_is_bridge_port(netdev)) 517 return NOTIFY_OK; 518 519 list_for_each_entry(iface, &iface_list, entry) { 520 if (!strcmp(iface->name, netdev->name)) { 521 found = 1; 522 if (iface->state != IFACE_STATE_DOWN) 523 break; 524 ret = create_socket(iface); 525 if (ret) 526 return NOTIFY_OK; 527 break; 528 } 529 } 530 if (!found && bind_additional_ifaces) { 531 iface = alloc_iface(kstrdup(netdev->name, GFP_KERNEL)); 532 if (!iface) 533 return NOTIFY_OK; 534 ret = create_socket(iface); 535 if (ret) 536 break; 537 } 538 break; 539 case NETDEV_DOWN: 540 list_for_each_entry(iface, &iface_list, entry) { 541 if (!strcmp(iface->name, netdev->name) && 542 iface->state == IFACE_STATE_CONFIGURED) { 543 tcp_stop_kthread(iface->ksmbd_kthread); 544 iface->ksmbd_kthread = NULL; 545 mutex_lock(&iface->sock_release_lock); 546 tcp_destroy_socket(iface->ksmbd_socket); 547 iface->ksmbd_socket = NULL; 548 mutex_unlock(&iface->sock_release_lock); 549 550 iface->state = IFACE_STATE_DOWN; 551 break; 552 } 553 } 554 break; 555 } 556 557 return NOTIFY_DONE; 558 } 559 560 static struct notifier_block ksmbd_netdev_notifier = { 561 .notifier_call = ksmbd_netdev_event, 562 }; 563 564 int ksmbd_tcp_init(void) 565 { 566 register_netdevice_notifier(&ksmbd_netdev_notifier); 567 568 return 0; 569 } 570 571 static void tcp_stop_kthread(struct task_struct *kthread) 572 { 573 int ret; 574 575 if (!kthread) 576 return; 577 578 ret = kthread_stop(kthread); 579 if (ret) 580 pr_err("failed to stop forker thread\n"); 581 } 582 583 void ksmbd_tcp_destroy(void) 584 { 585 struct interface *iface, *tmp; 586 587 unregister_netdevice_notifier(&ksmbd_netdev_notifier); 588 589 list_for_each_entry_safe(iface, tmp, &iface_list, entry) { 590 list_del(&iface->entry); 591 kfree(iface->name); 592 kfree(iface); 593 } 594 } 595 596 static struct interface *alloc_iface(char *ifname) 597 { 598 struct interface *iface; 599 600 if (!ifname) 601 return NULL; 602 603 iface = kzalloc(sizeof(struct interface), GFP_KERNEL); 604 if (!iface) { 605 kfree(ifname); 606 return NULL; 607 } 608 609 iface->name = ifname; 610 iface->state = IFACE_STATE_DOWN; 611 list_add(&iface->entry, &iface_list); 612 mutex_init(&iface->sock_release_lock); 613 return iface; 614 } 615 616 int ksmbd_tcp_set_interfaces(char *ifc_list, int ifc_list_sz) 617 { 618 int sz = 0; 619 620 if (!ifc_list_sz) { 621 struct net_device *netdev; 622 623 rtnl_lock(); 624 for_each_netdev(&init_net, netdev) { 625 if (netif_is_bridge_port(netdev)) 626 continue; 627 if (!alloc_iface(kstrdup(netdev->name, GFP_KERNEL))) { 628 rtnl_unlock(); 629 return -ENOMEM; 630 } 631 } 632 rtnl_unlock(); 633 bind_additional_ifaces = 1; 634 return 0; 635 } 636 637 while (ifc_list_sz > 0) { 638 if (!alloc_iface(kstrdup(ifc_list, GFP_KERNEL))) 639 return -ENOMEM; 640 641 sz = strlen(ifc_list); 642 if (!sz) 643 break; 644 645 ifc_list += sz + 1; 646 ifc_list_sz -= (sz + 1); 647 } 648 649 bind_additional_ifaces = 0; 650 651 return 0; 652 } 653 654 static const struct ksmbd_transport_ops ksmbd_tcp_transport_ops = { 655 .read = ksmbd_tcp_read, 656 .writev = ksmbd_tcp_writev, 657 .disconnect = ksmbd_tcp_disconnect, 658 }; 659