xref: /linux/fs/smb/server/mgmt/share_config.c (revision 173b0b5b0e865348684c02bd9cb1d22b5d46e458)
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  *   Copyright (C) 2018 Samsung Electronics Co., Ltd.
4  */
5 
6 #include <linux/list.h>
7 #include <linux/jhash.h>
8 #include <linux/slab.h>
9 #include <linux/rwsem.h>
10 #include <linux/parser.h>
11 #include <linux/namei.h>
12 #include <linux/sched.h>
13 #include <linux/mm.h>
14 
15 #include "share_config.h"
16 #include "user_config.h"
17 #include "user_session.h"
18 #include "../transport_ipc.h"
19 #include "../misc.h"
20 
/* Number of hash bits used to size shares_table. */
#define SHARE_HASH_BITS		3
/* Cache of known share configurations, keyed by share_name_hash(name). */
static DEFINE_HASHTABLE(shares_table, SHARE_HASH_BITS);
/*
 * Protects shares_table: held for read around lookups and for write
 * around hash_add()/hash_del() in this file.
 */
static DECLARE_RWSEM(shares_table_lock);
24 
/* One veto pattern of a share, linked into ksmbd_share_config::veto_list. */
struct ksmbd_veto_pattern {
	/* Heap-allocated pattern string; tested with match_wildcard(). */
	char			*pattern;
	/* Link in the owning share's veto_list. */
	struct list_head	list;
};
29 
/*
 * share_name_hash() - hash key for a share name
 * @name:	NUL-terminated share name
 *
 * Return:	jhash() of the name's bytes (terminator excluded), initval 0.
 */
static unsigned int share_name_hash(const char *name)
{
	size_t len = strlen(name);

	return jhash(name, len, 0);
}
34 
35 static void kill_share(struct ksmbd_share_config *share)
36 {
37 	while (!list_empty(&share->veto_list)) {
38 		struct ksmbd_veto_pattern *p;
39 
40 		p = list_entry(share->veto_list.next,
41 			       struct ksmbd_veto_pattern,
42 			       list);
43 		list_del(&p->list);
44 		kfree(p->pattern);
45 		kfree(p);
46 	}
47 
48 	if (share->path)
49 		path_put(&share->vfs_path);
50 	kfree(share->name);
51 	kfree(share->path);
52 	kfree(share);
53 }
54 
/**
 * ksmbd_share_config_del() - remove a share from the shares hash table
 * @share:	share to unhash
 *
 * Takes shares_table_lock for write around the removal. The share itself
 * is not freed.
 */
void ksmbd_share_config_del(struct ksmbd_share_config *share)
{
	down_write(&shares_table_lock);
	hash_del(&share->hlist);
	up_write(&shares_table_lock);
}
61 
/**
 * __ksmbd_share_config_put() - unhash and destroy a share config
 * @share:	share to release
 *
 * Removes the share from the hash table and then frees it via
 * kill_share().
 * NOTE(review): presumably only called once the last reference has been
 * dropped - confirm against the put helper in share_config.h.
 */
void __ksmbd_share_config_put(struct ksmbd_share_config *share)
{
	ksmbd_share_config_del(share);
	kill_share(share);
}
67 
68 static struct ksmbd_share_config *
69 __get_share_config(struct ksmbd_share_config *share)
70 {
71 	if (!atomic_inc_not_zero(&share->refcount))
72 		return NULL;
73 	return share;
74 }
75 
76 static struct ksmbd_share_config *__share_lookup(const char *name)
77 {
78 	struct ksmbd_share_config *share;
79 	unsigned int key = share_name_hash(name);
80 
81 	hash_for_each_possible(shares_table, share, hlist, key) {
82 		if (!strcmp(name, share->name))
83 			return share;
84 	}
85 	return NULL;
86 }
87 
88 static int parse_veto_list(struct ksmbd_share_config *share,
89 			   char *veto_list,
90 			   int veto_list_sz)
91 {
92 	int sz = 0;
93 
94 	if (!veto_list_sz)
95 		return 0;
96 
97 	while (veto_list_sz > 0) {
98 		struct ksmbd_veto_pattern *p;
99 
100 		sz = strlen(veto_list);
101 		if (!sz)
102 			break;
103 
104 		p = kzalloc(sizeof(struct ksmbd_veto_pattern), GFP_KERNEL);
105 		if (!p)
106 			return -ENOMEM;
107 
108 		p->pattern = kstrdup(veto_list, GFP_KERNEL);
109 		if (!p->pattern) {
110 			kfree(p);
111 			return -ENOMEM;
112 		}
113 
114 		list_add(&p->list, &share->veto_list);
115 
116 		veto_list += sz + 1;
117 		veto_list_sz -= (sz + 1);
118 	}
119 
120 	return 0;
121 }
122 
123 static struct ksmbd_share_config *share_config_request(struct unicode_map *um,
124 						       const char *name)
125 {
126 	struct ksmbd_share_config_response *resp;
127 	struct ksmbd_share_config *share = NULL;
128 	struct ksmbd_share_config *lookup;
129 	int ret;
130 
131 	resp = ksmbd_ipc_share_config_request(name);
132 	if (!resp)
133 		return NULL;
134 
135 	if (resp->flags == KSMBD_SHARE_FLAG_INVALID)
136 		goto out;
137 
138 	if (*resp->share_name) {
139 		char *cf_resp_name;
140 		bool equal;
141 
142 		cf_resp_name = ksmbd_casefold_sharename(um, resp->share_name);
143 		if (IS_ERR(cf_resp_name))
144 			goto out;
145 		equal = !strcmp(cf_resp_name, name);
146 		kfree(cf_resp_name);
147 		if (!equal)
148 			goto out;
149 	}
150 
151 	share = kzalloc(sizeof(struct ksmbd_share_config), GFP_KERNEL);
152 	if (!share)
153 		goto out;
154 
155 	share->flags = resp->flags;
156 	atomic_set(&share->refcount, 1);
157 	INIT_LIST_HEAD(&share->veto_list);
158 	share->name = kstrdup(name, GFP_KERNEL);
159 
160 	if (!test_share_config_flag(share, KSMBD_SHARE_FLAG_PIPE)) {
161 		int path_len = PATH_MAX;
162 
163 		if (resp->payload_sz)
164 			path_len = resp->payload_sz - resp->veto_list_sz;
165 
166 		share->path = kstrndup(ksmbd_share_config_path(resp), path_len,
167 				      GFP_KERNEL);
168 		if (share->path)
169 			share->path_sz = strlen(share->path);
170 		share->create_mask = resp->create_mask;
171 		share->directory_mask = resp->directory_mask;
172 		share->force_create_mode = resp->force_create_mode;
173 		share->force_directory_mode = resp->force_directory_mode;
174 		share->force_uid = resp->force_uid;
175 		share->force_gid = resp->force_gid;
176 		ret = parse_veto_list(share,
177 				      KSMBD_SHARE_CONFIG_VETO_LIST(resp),
178 				      resp->veto_list_sz);
179 		if (!ret && share->path) {
180 			ret = kern_path(share->path, 0, &share->vfs_path);
181 			if (ret) {
182 				ksmbd_debug(SMB, "failed to access '%s'\n",
183 					    share->path);
184 				/* Avoid put_path() */
185 				kfree(share->path);
186 				share->path = NULL;
187 			}
188 		}
189 		if (ret || !share->name) {
190 			kill_share(share);
191 			share = NULL;
192 			goto out;
193 		}
194 	}
195 
196 	down_write(&shares_table_lock);
197 	lookup = __share_lookup(name);
198 	if (lookup)
199 		lookup = __get_share_config(lookup);
200 	if (!lookup) {
201 		hash_add(shares_table, &share->hlist, share_name_hash(name));
202 	} else {
203 		kill_share(share);
204 		share = lookup;
205 	}
206 	up_write(&shares_table_lock);
207 
208 out:
209 	kvfree(resp);
210 	return share;
211 }
212 
213 struct ksmbd_share_config *ksmbd_share_config_get(struct unicode_map *um,
214 						  const char *name)
215 {
216 	struct ksmbd_share_config *share;
217 
218 	down_read(&shares_table_lock);
219 	share = __share_lookup(name);
220 	if (share)
221 		share = __get_share_config(share);
222 	up_read(&shares_table_lock);
223 
224 	if (share)
225 		return share;
226 	return share_config_request(um, name);
227 }
228 
229 bool ksmbd_share_veto_filename(struct ksmbd_share_config *share,
230 			       const char *filename)
231 {
232 	struct ksmbd_veto_pattern *p;
233 
234 	list_for_each_entry(p, &share->veto_list, list) {
235 		if (match_wildcard(p->pattern, filename))
236 			return true;
237 	}
238 	return false;
239 }
240