1 /* 2 * linux/net/sunrpc/auth.c 3 * 4 * Generic RPC client authentication API. 5 * 6 * Copyright (C) 1996, Olaf Kirch <okir@monad.swb.de> 7 */ 8 9 #include <linux/types.h> 10 #include <linux/sched.h> 11 #include <linux/module.h> 12 #include <linux/slab.h> 13 #include <linux/errno.h> 14 #include <linux/sunrpc/clnt.h> 15 #include <linux/spinlock.h> 16 17 #ifdef RPC_DEBUG 18 # define RPCDBG_FACILITY RPCDBG_AUTH 19 #endif 20 21 static struct rpc_authops * auth_flavors[RPC_AUTH_MAXFLAVOR] = { 22 &authnull_ops, /* AUTH_NULL */ 23 &authunix_ops, /* AUTH_UNIX */ 24 NULL, /* others can be loadable modules */ 25 }; 26 27 static u32 28 pseudoflavor_to_flavor(u32 flavor) { 29 if (flavor >= RPC_AUTH_MAXFLAVOR) 30 return RPC_AUTH_GSS; 31 return flavor; 32 } 33 34 int 35 rpcauth_register(struct rpc_authops *ops) 36 { 37 rpc_authflavor_t flavor; 38 39 if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR) 40 return -EINVAL; 41 if (auth_flavors[flavor] != NULL) 42 return -EPERM; /* what else? */ 43 auth_flavors[flavor] = ops; 44 return 0; 45 } 46 47 int 48 rpcauth_unregister(struct rpc_authops *ops) 49 { 50 rpc_authflavor_t flavor; 51 52 if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR) 53 return -EINVAL; 54 if (auth_flavors[flavor] != ops) 55 return -EPERM; /* what else? */ 56 auth_flavors[flavor] = NULL; 57 return 0; 58 } 59 60 struct rpc_auth * 61 rpcauth_create(rpc_authflavor_t pseudoflavor, struct rpc_clnt *clnt) 62 { 63 struct rpc_auth *auth; 64 struct rpc_authops *ops; 65 u32 flavor = pseudoflavor_to_flavor(pseudoflavor); 66 67 auth = ERR_PTR(-EINVAL); 68 if (flavor >= RPC_AUTH_MAXFLAVOR) 69 goto out; 70 71 /* FIXME - auth_flavors[] really needs an rw lock, 72 * and module refcounting. */ 73 #ifdef CONFIG_KMOD 74 if ((ops = auth_flavors[flavor]) == NULL) 75 request_module("rpc-auth-%u", flavor); 76 #endif 77 if ((ops = auth_flavors[flavor]) == NULL) 78 goto out; 79 auth = ops->create(clnt, pseudoflavor); 80 if (IS_ERR(auth)) 81 return auth; 82 if (clnt->cl_auth) 83 rpcauth_destroy(clnt->cl_auth); 84 clnt->cl_auth = auth; 85 86 out: 87 return auth; 88 } 89 90 void 91 rpcauth_destroy(struct rpc_auth *auth) 92 { 93 if (!atomic_dec_and_test(&auth->au_count)) 94 return; 95 auth->au_ops->destroy(auth); 96 } 97 98 static DEFINE_SPINLOCK(rpc_credcache_lock); 99 100 /* 101 * Initialize RPC credential cache 102 */ 103 int 104 rpcauth_init_credcache(struct rpc_auth *auth, unsigned long expire) 105 { 106 struct rpc_cred_cache *new; 107 int i; 108 109 new = kmalloc(sizeof(*new), GFP_KERNEL); 110 if (!new) 111 return -ENOMEM; 112 for (i = 0; i < RPC_CREDCACHE_NR; i++) 113 INIT_HLIST_HEAD(&new->hashtable[i]); 114 new->expire = expire; 115 new->nextgc = jiffies + (expire >> 1); 116 auth->au_credcache = new; 117 return 0; 118 } 119 120 /* 121 * Destroy a list of credentials 122 */ 123 static inline 124 void rpcauth_destroy_credlist(struct hlist_head *head) 125 { 126 struct rpc_cred *cred; 127 128 while (!hlist_empty(head)) { 129 cred = hlist_entry(head->first, struct rpc_cred, cr_hash); 130 hlist_del_init(&cred->cr_hash); 131 put_rpccred(cred); 132 } 133 } 134 135 /* 136 * Clear the RPC credential cache, and delete those credentials 137 * that are not referenced. 138 */ 139 void 140 rpcauth_free_credcache(struct rpc_auth *auth) 141 { 142 struct rpc_cred_cache *cache = auth->au_credcache; 143 HLIST_HEAD(free); 144 struct hlist_node *pos, *next; 145 struct rpc_cred *cred; 146 int i; 147 148 spin_lock(&rpc_credcache_lock); 149 for (i = 0; i < RPC_CREDCACHE_NR; i++) { 150 hlist_for_each_safe(pos, next, &cache->hashtable[i]) { 151 cred = hlist_entry(pos, struct rpc_cred, cr_hash); 152 __hlist_del(&cred->cr_hash); 153 hlist_add_head(&cred->cr_hash, &free); 154 } 155 } 156 spin_unlock(&rpc_credcache_lock); 157 rpcauth_destroy_credlist(&free); 158 } 159 160 static void 161 rpcauth_prune_expired(struct rpc_auth *auth, struct rpc_cred *cred, struct hlist_head *free) 162 { 163 if (atomic_read(&cred->cr_count) != 1) 164 return; 165 if (time_after(jiffies, cred->cr_expire + auth->au_credcache->expire)) 166 cred->cr_flags &= ~RPCAUTH_CRED_UPTODATE; 167 if (!(cred->cr_flags & RPCAUTH_CRED_UPTODATE)) { 168 __hlist_del(&cred->cr_hash); 169 hlist_add_head(&cred->cr_hash, free); 170 } 171 } 172 173 /* 174 * Remove stale credentials. Avoid sleeping inside the loop. 175 */ 176 static void 177 rpcauth_gc_credcache(struct rpc_auth *auth, struct hlist_head *free) 178 { 179 struct rpc_cred_cache *cache = auth->au_credcache; 180 struct hlist_node *pos, *next; 181 struct rpc_cred *cred; 182 int i; 183 184 dprintk("RPC: gc'ing RPC credentials for auth %p\n", auth); 185 for (i = 0; i < RPC_CREDCACHE_NR; i++) { 186 hlist_for_each_safe(pos, next, &cache->hashtable[i]) { 187 cred = hlist_entry(pos, struct rpc_cred, cr_hash); 188 rpcauth_prune_expired(auth, cred, free); 189 } 190 } 191 cache->nextgc = jiffies + cache->expire; 192 } 193 194 /* 195 * Look up a process' credentials in the authentication cache 196 */ 197 struct rpc_cred * 198 rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred, 199 int flags) 200 { 201 struct rpc_cred_cache *cache = auth->au_credcache; 202 HLIST_HEAD(free); 203 struct hlist_node *pos, *next; 204 struct rpc_cred *new = NULL, 205 *cred = NULL; 206 int nr = 0; 207 208 if (!(flags & RPCAUTH_LOOKUP_ROOTCREDS)) 209 nr = acred->uid & RPC_CREDCACHE_MASK; 210 retry: 211 spin_lock(&rpc_credcache_lock); 212 if (time_before(cache->nextgc, jiffies)) 213 rpcauth_gc_credcache(auth, &free); 214 hlist_for_each_safe(pos, next, &cache->hashtable[nr]) { 215 struct rpc_cred *entry; 216 entry = hlist_entry(pos, struct rpc_cred, cr_hash); 217 if (entry->cr_ops->crmatch(acred, entry, flags)) { 218 hlist_del(&entry->cr_hash); 219 cred = entry; 220 break; 221 } 222 rpcauth_prune_expired(auth, entry, &free); 223 } 224 if (new) { 225 if (cred) 226 hlist_add_head(&new->cr_hash, &free); 227 else 228 cred = new; 229 } 230 if (cred) { 231 hlist_add_head(&cred->cr_hash, &cache->hashtable[nr]); 232 get_rpccred(cred); 233 } 234 spin_unlock(&rpc_credcache_lock); 235 236 rpcauth_destroy_credlist(&free); 237 238 if (!cred) { 239 new = auth->au_ops->crcreate(auth, acred, flags); 240 if (!IS_ERR(new)) { 241 #ifdef RPC_DEBUG 242 new->cr_magic = RPCAUTH_CRED_MAGIC; 243 #endif 244 goto retry; 245 } else 246 cred = new; 247 } else if ((cred->cr_flags & RPCAUTH_CRED_NEW) 248 && cred->cr_ops->cr_init != NULL 249 && !(flags & RPCAUTH_LOOKUP_NEW)) { 250 int res = cred->cr_ops->cr_init(auth, cred); 251 if (res < 0) { 252 put_rpccred(cred); 253 cred = ERR_PTR(res); 254 } 255 } 256 257 return (struct rpc_cred *) cred; 258 } 259 260 struct rpc_cred * 261 rpcauth_lookupcred(struct rpc_auth *auth, int flags) 262 { 263 struct auth_cred acred = { 264 .uid = current->fsuid, 265 .gid = current->fsgid, 266 .group_info = current->group_info, 267 }; 268 struct rpc_cred *ret; 269 270 dprintk("RPC: looking up %s cred\n", 271 auth->au_ops->au_name); 272 get_group_info(acred.group_info); 273 ret = auth->au_ops->lookup_cred(auth, &acred, flags); 274 put_group_info(acred.group_info); 275 return ret; 276 } 277 278 struct rpc_cred * 279 rpcauth_bindcred(struct rpc_task *task) 280 { 281 struct rpc_auth *auth = task->tk_auth; 282 struct auth_cred acred = { 283 .uid = current->fsuid, 284 .gid = current->fsgid, 285 .group_info = current->group_info, 286 }; 287 struct rpc_cred *ret; 288 int flags = 0; 289 290 dprintk("RPC: %4d looking up %s cred\n", 291 task->tk_pid, task->tk_auth->au_ops->au_name); 292 get_group_info(acred.group_info); 293 if (task->tk_flags & RPC_TASK_ROOTCREDS) 294 flags |= RPCAUTH_LOOKUP_ROOTCREDS; 295 ret = auth->au_ops->lookup_cred(auth, &acred, flags); 296 if (!IS_ERR(ret)) 297 task->tk_msg.rpc_cred = ret; 298 else 299 task->tk_status = PTR_ERR(ret); 300 put_group_info(acred.group_info); 301 return ret; 302 } 303 304 void 305 rpcauth_holdcred(struct rpc_task *task) 306 { 307 dprintk("RPC: %4d holding %s cred %p\n", 308 task->tk_pid, task->tk_auth->au_ops->au_name, task->tk_msg.rpc_cred); 309 if (task->tk_msg.rpc_cred) 310 get_rpccred(task->tk_msg.rpc_cred); 311 } 312 313 void 314 put_rpccred(struct rpc_cred *cred) 315 { 316 cred->cr_expire = jiffies; 317 if (!atomic_dec_and_test(&cred->cr_count)) 318 return; 319 cred->cr_ops->crdestroy(cred); 320 } 321 322 void 323 rpcauth_unbindcred(struct rpc_task *task) 324 { 325 struct rpc_cred *cred = task->tk_msg.rpc_cred; 326 327 dprintk("RPC: %4d releasing %s cred %p\n", 328 task->tk_pid, task->tk_auth->au_ops->au_name, cred); 329 330 put_rpccred(cred); 331 task->tk_msg.rpc_cred = NULL; 332 } 333 334 u32 * 335 rpcauth_marshcred(struct rpc_task *task, u32 *p) 336 { 337 struct rpc_cred *cred = task->tk_msg.rpc_cred; 338 339 dprintk("RPC: %4d marshaling %s cred %p\n", 340 task->tk_pid, task->tk_auth->au_ops->au_name, cred); 341 342 return cred->cr_ops->crmarshal(task, p); 343 } 344 345 u32 * 346 rpcauth_checkverf(struct rpc_task *task, u32 *p) 347 { 348 struct rpc_cred *cred = task->tk_msg.rpc_cred; 349 350 dprintk("RPC: %4d validating %s cred %p\n", 351 task->tk_pid, task->tk_auth->au_ops->au_name, cred); 352 353 return cred->cr_ops->crvalidate(task, p); 354 } 355 356 int 357 rpcauth_wrap_req(struct rpc_task *task, kxdrproc_t encode, void *rqstp, 358 u32 *data, void *obj) 359 { 360 struct rpc_cred *cred = task->tk_msg.rpc_cred; 361 362 dprintk("RPC: %4d using %s cred %p to wrap rpc data\n", 363 task->tk_pid, cred->cr_ops->cr_name, cred); 364 if (cred->cr_ops->crwrap_req) 365 return cred->cr_ops->crwrap_req(task, encode, rqstp, data, obj); 366 /* By default, we encode the arguments normally. */ 367 return encode(rqstp, data, obj); 368 } 369 370 int 371 rpcauth_unwrap_resp(struct rpc_task *task, kxdrproc_t decode, void *rqstp, 372 u32 *data, void *obj) 373 { 374 struct rpc_cred *cred = task->tk_msg.rpc_cred; 375 376 dprintk("RPC: %4d using %s cred %p to unwrap rpc data\n", 377 task->tk_pid, cred->cr_ops->cr_name, cred); 378 if (cred->cr_ops->crunwrap_resp) 379 return cred->cr_ops->crunwrap_resp(task, decode, rqstp, 380 data, obj); 381 /* By default, we decode the arguments normally. */ 382 return decode(rqstp, data, obj); 383 } 384 385 int 386 rpcauth_refreshcred(struct rpc_task *task) 387 { 388 struct rpc_cred *cred = task->tk_msg.rpc_cred; 389 int err; 390 391 dprintk("RPC: %4d refreshing %s cred %p\n", 392 task->tk_pid, task->tk_auth->au_ops->au_name, cred); 393 394 err = cred->cr_ops->crrefresh(task); 395 if (err < 0) 396 task->tk_status = err; 397 return err; 398 } 399 400 void 401 rpcauth_invalcred(struct rpc_task *task) 402 { 403 dprintk("RPC: %4d invalidating %s cred %p\n", 404 task->tk_pid, task->tk_auth->au_ops->au_name, task->tk_msg.rpc_cred); 405 spin_lock(&rpc_credcache_lock); 406 if (task->tk_msg.rpc_cred) 407 task->tk_msg.rpc_cred->cr_flags &= ~RPCAUTH_CRED_UPTODATE; 408 spin_unlock(&rpc_credcache_lock); 409 } 410 411 int 412 rpcauth_uptodatecred(struct rpc_task *task) 413 { 414 return !(task->tk_msg.rpc_cred) || 415 (task->tk_msg.rpc_cred->cr_flags & RPCAUTH_CRED_UPTODATE); 416 } 417