1 // SPDX-License-Identifier: GPL-2.0-or-later 2 /* AFS vlserver list management. 3 * 4 * Copyright (C) 2018 Red Hat, Inc. All Rights Reserved. 5 * Written by David Howells (dhowells@redhat.com) 6 */ 7 8 #include <linux/kernel.h> 9 #include <linux/slab.h> 10 #include "internal.h" 11 12 struct afs_vlserver *afs_alloc_vlserver(const char *name, size_t name_len, 13 unsigned short port) 14 { 15 struct afs_vlserver *vlserver; 16 static atomic_t debug_ids; 17 18 vlserver = kzalloc(struct_size(vlserver, name, name_len + 1), 19 GFP_KERNEL); 20 if (vlserver) { 21 refcount_set(&vlserver->ref, 1); 22 rwlock_init(&vlserver->lock); 23 init_waitqueue_head(&vlserver->probe_wq); 24 spin_lock_init(&vlserver->probe_lock); 25 vlserver->debug_id = atomic_inc_return(&debug_ids); 26 vlserver->rtt = UINT_MAX; 27 vlserver->name_len = name_len; 28 vlserver->service_id = VL_SERVICE; 29 vlserver->port = port; 30 memcpy(vlserver->name, name, name_len); 31 } 32 return vlserver; 33 } 34 35 static void afs_vlserver_rcu(struct rcu_head *rcu) 36 { 37 struct afs_vlserver *vlserver = container_of(rcu, struct afs_vlserver, rcu); 38 39 afs_put_addrlist(rcu_access_pointer(vlserver->addresses), 40 afs_alist_trace_put_vlserver); 41 kfree_rcu(vlserver, rcu); 42 } 43 44 void afs_put_vlserver(struct afs_net *net, struct afs_vlserver *vlserver) 45 { 46 if (vlserver && 47 refcount_dec_and_test(&vlserver->ref)) 48 call_rcu(&vlserver->rcu, afs_vlserver_rcu); 49 } 50 51 struct afs_vlserver_list *afs_alloc_vlserver_list(unsigned int nr_servers) 52 { 53 struct afs_vlserver_list *vllist; 54 55 vllist = kzalloc(struct_size(vllist, servers, nr_servers), GFP_KERNEL); 56 if (vllist) { 57 refcount_set(&vllist->ref, 1); 58 rwlock_init(&vllist->lock); 59 } 60 61 return vllist; 62 } 63 64 void afs_put_vlserverlist(struct afs_net *net, struct afs_vlserver_list *vllist) 65 { 66 if (vllist) { 67 if (refcount_dec_and_test(&vllist->ref)) { 68 int i; 69 70 for (i = 0; i < vllist->nr_servers; i++) { 71 afs_put_vlserver(net, vllist->servers[i].server); 72 } 73 kfree_rcu(vllist, rcu); 74 } 75 } 76 } 77 78 static u16 afs_extract_le16(const u8 **_b) 79 { 80 u16 val; 81 82 val = (u16)*(*_b)++ << 0; 83 val |= (u16)*(*_b)++ << 8; 84 return val; 85 } 86 87 /* 88 * Build a VL server address list from a DNS queried server list. 89 */ 90 static struct afs_addr_list *afs_extract_vl_addrs(struct afs_net *net, 91 const u8 **_b, const u8 *end, 92 u8 nr_addrs, u16 port) 93 { 94 struct afs_addr_list *alist; 95 const u8 *b = *_b; 96 int ret = -EINVAL; 97 98 alist = afs_alloc_addrlist(nr_addrs); 99 if (!alist) 100 return ERR_PTR(-ENOMEM); 101 if (nr_addrs == 0) 102 return alist; 103 104 for (; nr_addrs > 0 && end - b >= nr_addrs; nr_addrs--) { 105 struct dns_server_list_v1_address hdr; 106 __be32 x[4]; 107 108 hdr.address_type = *b++; 109 110 switch (hdr.address_type) { 111 case DNS_ADDRESS_IS_IPV4: 112 if (end - b < 4) { 113 _leave(" = -EINVAL [short inet]"); 114 goto error; 115 } 116 memcpy(x, b, 4); 117 ret = afs_merge_fs_addr4(net, alist, x[0], port); 118 if (ret < 0) 119 goto error; 120 b += 4; 121 break; 122 123 case DNS_ADDRESS_IS_IPV6: 124 if (end - b < 16) { 125 _leave(" = -EINVAL [short inet6]"); 126 goto error; 127 } 128 memcpy(x, b, 16); 129 ret = afs_merge_fs_addr6(net, alist, x, port); 130 if (ret < 0) 131 goto error; 132 b += 16; 133 break; 134 135 default: 136 _leave(" = -EADDRNOTAVAIL [unknown af %u]", 137 hdr.address_type); 138 ret = -EADDRNOTAVAIL; 139 goto error; 140 } 141 } 142 143 /* Start with IPv6 if available. */ 144 if (alist->nr_ipv4 < alist->nr_addrs) 145 alist->preferred = alist->nr_ipv4; 146 147 *_b = b; 148 return alist; 149 150 error: 151 *_b = b; 152 afs_put_addrlist(alist, afs_alist_trace_put_parse_error); 153 return ERR_PTR(ret); 154 } 155 156 /* 157 * Build a VL server list from a DNS queried server list. 158 */ 159 struct afs_vlserver_list *afs_extract_vlserver_list(struct afs_cell *cell, 160 const void *buffer, 161 size_t buffer_size) 162 { 163 const struct dns_server_list_v1_header *hdr = buffer; 164 struct dns_server_list_v1_server bs; 165 struct afs_vlserver_list *vllist, *previous; 166 struct afs_addr_list *addrs; 167 struct afs_vlserver *server; 168 const u8 *b = buffer, *end = buffer + buffer_size; 169 int ret = -ENOMEM, nr_servers, i, j; 170 171 _enter(""); 172 173 /* Check that it's a server list, v1 */ 174 if (end - b < sizeof(*hdr) || 175 hdr->hdr.content != DNS_PAYLOAD_IS_SERVER_LIST || 176 hdr->hdr.version != 1) { 177 pr_notice("kAFS: Got DNS record [%u,%u] len %zu\n", 178 hdr->hdr.content, hdr->hdr.version, end - b); 179 ret = -EDESTADDRREQ; 180 goto dump; 181 } 182 183 nr_servers = hdr->nr_servers; 184 185 vllist = afs_alloc_vlserver_list(nr_servers); 186 if (!vllist) 187 return ERR_PTR(-ENOMEM); 188 189 vllist->source = (hdr->source < NR__dns_record_source) ? 190 hdr->source : NR__dns_record_source; 191 vllist->status = (hdr->status < NR__dns_lookup_status) ? 192 hdr->status : NR__dns_lookup_status; 193 194 read_lock(&cell->vl_servers_lock); 195 previous = afs_get_vlserverlist( 196 rcu_dereference_protected(cell->vl_servers, 197 lockdep_is_held(&cell->vl_servers_lock))); 198 read_unlock(&cell->vl_servers_lock); 199 200 b += sizeof(*hdr); 201 while (end - b >= sizeof(bs)) { 202 bs.name_len = afs_extract_le16(&b); 203 bs.priority = afs_extract_le16(&b); 204 bs.weight = afs_extract_le16(&b); 205 bs.port = afs_extract_le16(&b); 206 bs.source = *b++; 207 bs.status = *b++; 208 bs.protocol = *b++; 209 bs.nr_addrs = *b++; 210 211 _debug("extract %u %u %u %u %u %u %*.*s", 212 bs.name_len, bs.priority, bs.weight, 213 bs.port, bs.protocol, bs.nr_addrs, 214 bs.name_len, bs.name_len, b); 215 216 if (end - b < bs.name_len) 217 break; 218 219 ret = -EPROTONOSUPPORT; 220 if (bs.protocol == DNS_SERVER_PROTOCOL_UNSPECIFIED) { 221 bs.protocol = DNS_SERVER_PROTOCOL_UDP; 222 } else if (bs.protocol != DNS_SERVER_PROTOCOL_UDP) { 223 _leave(" = [proto %u]", bs.protocol); 224 goto error; 225 } 226 227 if (bs.port == 0) 228 bs.port = AFS_VL_PORT; 229 if (bs.source > NR__dns_record_source) 230 bs.source = NR__dns_record_source; 231 if (bs.status > NR__dns_lookup_status) 232 bs.status = NR__dns_lookup_status; 233 234 /* See if we can update an old server record */ 235 server = NULL; 236 for (i = 0; i < previous->nr_servers; i++) { 237 struct afs_vlserver *p = previous->servers[i].server; 238 239 if (p->name_len == bs.name_len && 240 p->port == bs.port && 241 strncasecmp(b, p->name, bs.name_len) == 0) { 242 server = afs_get_vlserver(p); 243 break; 244 } 245 } 246 247 if (!server) { 248 ret = -ENOMEM; 249 server = afs_alloc_vlserver(b, bs.name_len, bs.port); 250 if (!server) 251 goto error; 252 } 253 254 b += bs.name_len; 255 256 /* Extract the addresses - note that we can't skip this as we 257 * have to advance the payload pointer. 258 */ 259 addrs = afs_extract_vl_addrs(cell->net, &b, end, bs.nr_addrs, bs.port); 260 if (IS_ERR(addrs)) { 261 ret = PTR_ERR(addrs); 262 goto error_2; 263 } 264 265 if (vllist->nr_servers >= nr_servers) { 266 _debug("skip %u >= %u", vllist->nr_servers, nr_servers); 267 afs_put_addrlist(addrs, afs_alist_trace_put_parse_empty); 268 afs_put_vlserver(cell->net, server); 269 continue; 270 } 271 272 addrs->source = bs.source; 273 addrs->status = bs.status; 274 275 if (addrs->nr_addrs == 0) { 276 afs_put_addrlist(addrs, afs_alist_trace_put_parse_empty); 277 if (!rcu_access_pointer(server->addresses)) { 278 afs_put_vlserver(cell->net, server); 279 continue; 280 } 281 } else { 282 struct afs_addr_list *old = addrs; 283 284 write_lock(&server->lock); 285 old = rcu_replace_pointer(server->addresses, old, 286 lockdep_is_held(&server->lock)); 287 write_unlock(&server->lock); 288 afs_put_addrlist(old, afs_alist_trace_put_vlserver_old); 289 } 290 291 292 /* TODO: Might want to check for duplicates */ 293 294 /* Insertion-sort by priority and weight */ 295 for (j = 0; j < vllist->nr_servers; j++) { 296 if (bs.priority < vllist->servers[j].priority) 297 break; /* Lower preferable */ 298 if (bs.priority == vllist->servers[j].priority && 299 bs.weight > vllist->servers[j].weight) 300 break; /* Higher preferable */ 301 } 302 303 if (j < vllist->nr_servers) { 304 memmove(vllist->servers + j + 1, 305 vllist->servers + j, 306 (vllist->nr_servers - j) * sizeof(struct afs_vlserver_entry)); 307 } 308 309 clear_bit(AFS_VLSERVER_FL_PROBED, &server->flags); 310 311 vllist->servers[j].priority = bs.priority; 312 vllist->servers[j].weight = bs.weight; 313 vllist->servers[j].server = server; 314 vllist->nr_servers++; 315 } 316 317 if (b != end) { 318 _debug("parse error %zd", b - end); 319 goto error; 320 } 321 322 afs_put_vlserverlist(cell->net, previous); 323 _leave(" = ok [%u]", vllist->nr_servers); 324 return vllist; 325 326 error_2: 327 afs_put_vlserver(cell->net, server); 328 error: 329 afs_put_vlserverlist(cell->net, vllist); 330 afs_put_vlserverlist(cell->net, previous); 331 dump: 332 if (ret != -ENOMEM) { 333 printk(KERN_DEBUG "DNS: at %zu\n", (const void *)b - buffer); 334 print_hex_dump_bytes("DNS: ", DUMP_PREFIX_NONE, buffer, buffer_size); 335 } 336 return ERR_PTR(ret); 337 } 338