1 // SPDX-License-Identifier: GPL-2.0-or-later 2 /* AFS vlserver list management. 3 * 4 * Copyright (C) 2018 Red Hat, Inc. All Rights Reserved. 5 * Written by David Howells (dhowells@redhat.com) 6 */ 7 8 #include <linux/kernel.h> 9 #include <linux/slab.h> 10 #include "internal.h" 11 12 struct afs_vlserver *afs_alloc_vlserver(const char *name, size_t name_len, 13 unsigned short port) 14 { 15 struct afs_vlserver *vlserver; 16 static atomic_t debug_ids; 17 18 vlserver = kzalloc_flex(*vlserver, name, name_len + 1); 19 if (vlserver) { 20 refcount_set(&vlserver->ref, 1); 21 rwlock_init(&vlserver->lock); 22 init_waitqueue_head(&vlserver->probe_wq); 23 spin_lock_init(&vlserver->probe_lock); 24 vlserver->debug_id = atomic_inc_return(&debug_ids); 25 vlserver->rtt = UINT_MAX; 26 vlserver->name_len = name_len; 27 vlserver->service_id = VL_SERVICE; 28 vlserver->port = port; 29 memcpy(vlserver->name, name, name_len); 30 } 31 return vlserver; 32 } 33 34 static void afs_vlserver_rcu(struct rcu_head *rcu) 35 { 36 struct afs_vlserver *vlserver = container_of(rcu, struct afs_vlserver, rcu); 37 38 afs_put_addrlist(rcu_access_pointer(vlserver->addresses), 39 afs_alist_trace_put_vlserver); 40 kfree_rcu(vlserver, rcu); 41 } 42 43 void afs_put_vlserver(struct afs_net *net, struct afs_vlserver *vlserver) 44 { 45 if (vlserver && 46 refcount_dec_and_test(&vlserver->ref)) 47 call_rcu(&vlserver->rcu, afs_vlserver_rcu); 48 } 49 50 struct afs_vlserver_list *afs_alloc_vlserver_list(unsigned int nr_servers) 51 { 52 struct afs_vlserver_list *vllist; 53 54 vllist = kzalloc_flex(*vllist, servers, nr_servers); 55 if (vllist) { 56 refcount_set(&vllist->ref, 1); 57 rwlock_init(&vllist->lock); 58 } 59 60 return vllist; 61 } 62 63 void afs_put_vlserverlist(struct afs_net *net, struct afs_vlserver_list *vllist) 64 { 65 if (vllist) { 66 if (refcount_dec_and_test(&vllist->ref)) { 67 int i; 68 69 for (i = 0; i < vllist->nr_servers; i++) { 70 afs_put_vlserver(net, vllist->servers[i].server); 71 } 72 kfree_rcu(vllist, rcu); 73 } 74 } 75 } 76 77 static u16 afs_extract_le16(const u8 **_b) 78 { 79 u16 val; 80 81 val = (u16)*(*_b)++ << 0; 82 val |= (u16)*(*_b)++ << 8; 83 return val; 84 } 85 86 /* 87 * Build a VL server address list from a DNS queried server list. 88 */ 89 static struct afs_addr_list *afs_extract_vl_addrs(struct afs_net *net, 90 const u8 **_b, const u8 *end, 91 u8 nr_addrs, u16 port) 92 { 93 struct afs_addr_list *alist; 94 const u8 *b = *_b; 95 int ret = -EINVAL; 96 97 alist = afs_alloc_addrlist(nr_addrs); 98 if (!alist) 99 return ERR_PTR(-ENOMEM); 100 if (nr_addrs == 0) 101 return alist; 102 103 for (; nr_addrs > 0 && end - b >= nr_addrs; nr_addrs--) { 104 struct dns_server_list_v1_address hdr; 105 __be32 x[4]; 106 107 hdr.address_type = *b++; 108 109 switch (hdr.address_type) { 110 case DNS_ADDRESS_IS_IPV4: 111 if (end - b < 4) { 112 _leave(" = -EINVAL [short inet]"); 113 goto error; 114 } 115 memcpy(x, b, 4); 116 ret = afs_merge_fs_addr4(net, alist, x[0], port); 117 if (ret < 0) 118 goto error; 119 b += 4; 120 break; 121 122 case DNS_ADDRESS_IS_IPV6: 123 if (end - b < 16) { 124 _leave(" = -EINVAL [short inet6]"); 125 goto error; 126 } 127 memcpy(x, b, 16); 128 ret = afs_merge_fs_addr6(net, alist, x, port); 129 if (ret < 0) 130 goto error; 131 b += 16; 132 break; 133 134 default: 135 _leave(" = -EADDRNOTAVAIL [unknown af %u]", 136 hdr.address_type); 137 ret = -EADDRNOTAVAIL; 138 goto error; 139 } 140 } 141 142 /* Start with IPv6 if available. */ 143 if (alist->nr_ipv4 < alist->nr_addrs) 144 alist->preferred = alist->nr_ipv4; 145 146 *_b = b; 147 return alist; 148 149 error: 150 *_b = b; 151 afs_put_addrlist(alist, afs_alist_trace_put_parse_error); 152 return ERR_PTR(ret); 153 } 154 155 /* 156 * Build a VL server list from a DNS queried server list. 157 */ 158 struct afs_vlserver_list *afs_extract_vlserver_list(struct afs_cell *cell, 159 const void *buffer, 160 size_t buffer_size) 161 { 162 const struct dns_server_list_v1_header *hdr = buffer; 163 struct dns_server_list_v1_server bs; 164 struct afs_vlserver_list *vllist, *previous; 165 struct afs_addr_list *addrs; 166 struct afs_vlserver *server; 167 const u8 *b = buffer, *end = buffer + buffer_size; 168 int ret = -ENOMEM, nr_servers, i, j; 169 170 _enter(""); 171 172 /* Check that it's a server list, v1 */ 173 if (end - b < sizeof(*hdr) || 174 hdr->hdr.content != DNS_PAYLOAD_IS_SERVER_LIST || 175 hdr->hdr.version != 1) { 176 pr_notice("kAFS: Got DNS record [%u,%u] len %zu\n", 177 hdr->hdr.content, hdr->hdr.version, end - b); 178 ret = -EDESTADDRREQ; 179 goto dump; 180 } 181 182 nr_servers = hdr->nr_servers; 183 184 vllist = afs_alloc_vlserver_list(nr_servers); 185 if (!vllist) 186 return ERR_PTR(-ENOMEM); 187 188 vllist->source = (hdr->source < NR__dns_record_source) ? 189 hdr->source : NR__dns_record_source; 190 vllist->status = (hdr->status < NR__dns_lookup_status) ? 191 hdr->status : NR__dns_lookup_status; 192 193 read_lock(&cell->vl_servers_lock); 194 previous = afs_get_vlserverlist( 195 rcu_dereference_protected(cell->vl_servers, 196 lockdep_is_held(&cell->vl_servers_lock))); 197 read_unlock(&cell->vl_servers_lock); 198 199 b += sizeof(*hdr); 200 while (end - b >= sizeof(bs)) { 201 bs.name_len = afs_extract_le16(&b); 202 bs.priority = afs_extract_le16(&b); 203 bs.weight = afs_extract_le16(&b); 204 bs.port = afs_extract_le16(&b); 205 bs.source = *b++; 206 bs.status = *b++; 207 bs.protocol = *b++; 208 bs.nr_addrs = *b++; 209 210 _debug("extract %u %u %u %u %u %u %*.*s", 211 bs.name_len, bs.priority, bs.weight, 212 bs.port, bs.protocol, bs.nr_addrs, 213 bs.name_len, bs.name_len, b); 214 215 if (end - b < bs.name_len) 216 break; 217 218 ret = -EPROTONOSUPPORT; 219 if (bs.protocol == DNS_SERVER_PROTOCOL_UNSPECIFIED) { 220 bs.protocol = DNS_SERVER_PROTOCOL_UDP; 221 } else if (bs.protocol != DNS_SERVER_PROTOCOL_UDP) { 222 _leave(" = [proto %u]", bs.protocol); 223 goto error; 224 } 225 226 if (bs.port == 0) 227 bs.port = AFS_VL_PORT; 228 if (bs.source > NR__dns_record_source) 229 bs.source = NR__dns_record_source; 230 if (bs.status > NR__dns_lookup_status) 231 bs.status = NR__dns_lookup_status; 232 233 /* See if we can update an old server record */ 234 server = NULL; 235 for (i = 0; i < previous->nr_servers; i++) { 236 struct afs_vlserver *p = previous->servers[i].server; 237 238 if (p->name_len == bs.name_len && 239 p->port == bs.port && 240 strncasecmp(b, p->name, bs.name_len) == 0) { 241 server = afs_get_vlserver(p); 242 break; 243 } 244 } 245 246 if (!server) { 247 ret = -ENOMEM; 248 server = afs_alloc_vlserver(b, bs.name_len, bs.port); 249 if (!server) 250 goto error; 251 } 252 253 b += bs.name_len; 254 255 /* Extract the addresses - note that we can't skip this as we 256 * have to advance the payload pointer. 257 */ 258 addrs = afs_extract_vl_addrs(cell->net, &b, end, bs.nr_addrs, bs.port); 259 if (IS_ERR(addrs)) { 260 ret = PTR_ERR(addrs); 261 goto error_2; 262 } 263 264 if (vllist->nr_servers >= nr_servers) { 265 _debug("skip %u >= %u", vllist->nr_servers, nr_servers); 266 afs_put_addrlist(addrs, afs_alist_trace_put_parse_empty); 267 afs_put_vlserver(cell->net, server); 268 continue; 269 } 270 271 addrs->source = bs.source; 272 addrs->status = bs.status; 273 274 if (addrs->nr_addrs == 0) { 275 afs_put_addrlist(addrs, afs_alist_trace_put_parse_empty); 276 if (!rcu_access_pointer(server->addresses)) { 277 afs_put_vlserver(cell->net, server); 278 continue; 279 } 280 } else { 281 struct afs_addr_list *old = addrs; 282 283 write_lock(&server->lock); 284 old = rcu_replace_pointer(server->addresses, old, 285 lockdep_is_held(&server->lock)); 286 write_unlock(&server->lock); 287 afs_put_addrlist(old, afs_alist_trace_put_vlserver_old); 288 } 289 290 291 /* TODO: Might want to check for duplicates */ 292 293 /* Insertion-sort by priority and weight */ 294 for (j = 0; j < vllist->nr_servers; j++) { 295 if (bs.priority < vllist->servers[j].priority) 296 break; /* Lower preferable */ 297 if (bs.priority == vllist->servers[j].priority && 298 bs.weight > vllist->servers[j].weight) 299 break; /* Higher preferable */ 300 } 301 302 if (j < vllist->nr_servers) { 303 memmove(vllist->servers + j + 1, 304 vllist->servers + j, 305 (vllist->nr_servers - j) * sizeof(struct afs_vlserver_entry)); 306 } 307 308 clear_bit(AFS_VLSERVER_FL_PROBED, &server->flags); 309 310 vllist->servers[j].priority = bs.priority; 311 vllist->servers[j].weight = bs.weight; 312 vllist->servers[j].server = server; 313 vllist->nr_servers++; 314 } 315 316 if (b != end) { 317 _debug("parse error %zd", b - end); 318 goto error; 319 } 320 321 afs_put_vlserverlist(cell->net, previous); 322 _leave(" = ok [%u]", vllist->nr_servers); 323 return vllist; 324 325 error_2: 326 afs_put_vlserver(cell->net, server); 327 error: 328 afs_put_vlserverlist(cell->net, vllist); 329 afs_put_vlserverlist(cell->net, previous); 330 dump: 331 if (ret != -ENOMEM) { 332 printk(KERN_DEBUG "DNS: at %zu\n", (const void *)b - buffer); 333 print_hex_dump_bytes("DNS: ", DUMP_PREFIX_NONE, buffer, buffer_size); 334 } 335 return ERR_PTR(ret); 336 } 337