xref: /linux/net/mctp/neigh.c (revision 6331b8765cd0634a4e4cdcc1a6f1a74196616b94)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
 * Management Component Transport Protocol (MCTP) - neighbour table
 * implementation.
 *
 * Neighbours are kept in a simple per-namespace RCU-protected list, with no
 * cache. The number of neighbours should stay fairly small, so the cost of a
 * linear lookup is small.
8  *
9  * Copyright (c) 2021 Code Construct
10  * Copyright (c) 2021 Google
11  */
12 
13 #include <linux/idr.h>
14 #include <linux/mctp.h>
15 #include <linux/netdevice.h>
16 #include <linux/rtnetlink.h>
17 #include <linux/skbuff.h>
18 
19 #include <net/mctp.h>
20 #include <net/mctpdevice.h>
21 #include <net/netlink.h>
22 #include <net/sock.h>
23 
24 static int mctp_neigh_add(struct mctp_dev *mdev, mctp_eid_t eid,
25 			  enum mctp_neigh_source source,
26 			  size_t lladdr_len, const void *lladdr)
27 {
28 	struct net *net = dev_net(mdev->dev);
29 	struct mctp_neigh *neigh;
30 	int rc;
31 
32 	mutex_lock(&net->mctp.neigh_lock);
33 	if (mctp_neigh_lookup(mdev, eid, NULL) == 0) {
34 		rc = -EEXIST;
35 		goto out;
36 	}
37 
38 	if (lladdr_len > sizeof(neigh->ha)) {
39 		rc = -EINVAL;
40 		goto out;
41 	}
42 
43 	neigh = kzalloc(sizeof(*neigh), GFP_KERNEL);
44 	if (!neigh) {
45 		rc = -ENOMEM;
46 		goto out;
47 	}
48 	INIT_LIST_HEAD(&neigh->list);
49 	neigh->dev = mdev;
50 	mctp_dev_hold(neigh->dev);
51 	neigh->eid = eid;
52 	neigh->source = source;
53 	memcpy(neigh->ha, lladdr, lladdr_len);
54 
55 	list_add_rcu(&neigh->list, &net->mctp.neighbours);
56 	rc = 0;
57 out:
58 	mutex_unlock(&net->mctp.neigh_lock);
59 	return rc;
60 }
61 
62 static void __mctp_neigh_free(struct rcu_head *rcu)
63 {
64 	struct mctp_neigh *neigh = container_of(rcu, struct mctp_neigh, rcu);
65 
66 	mctp_dev_put(neigh->dev);
67 	kfree(neigh);
68 }
69 
70 /* Removes all neighbour entries referring to a device */
71 void mctp_neigh_remove_dev(struct mctp_dev *mdev)
72 {
73 	struct net *net = dev_net(mdev->dev);
74 	struct mctp_neigh *neigh, *tmp;
75 
76 	mutex_lock(&net->mctp.neigh_lock);
77 	list_for_each_entry_safe(neigh, tmp, &net->mctp.neighbours, list) {
78 		if (neigh->dev == mdev) {
79 			list_del_rcu(&neigh->list);
80 			/* TODO: immediate RTM_DELNEIGH */
81 			call_rcu(&neigh->rcu, __mctp_neigh_free);
82 		}
83 	}
84 
85 	mutex_unlock(&net->mctp.neigh_lock);
86 }
87 
88 // TODO: add a "source" flag so netlink can only delete static neighbours?
89 static int mctp_neigh_remove(struct mctp_dev *mdev, mctp_eid_t eid)
90 {
91 	struct net *net = dev_net(mdev->dev);
92 	struct mctp_neigh *neigh, *tmp;
93 	bool dropped = false;
94 
95 	mutex_lock(&net->mctp.neigh_lock);
96 	list_for_each_entry_safe(neigh, tmp, &net->mctp.neighbours, list) {
97 		if (neigh->dev == mdev && neigh->eid == eid) {
98 			list_del_rcu(&neigh->list);
99 			/* TODO: immediate RTM_DELNEIGH */
100 			call_rcu(&neigh->rcu, __mctp_neigh_free);
101 			dropped = true;
102 		}
103 	}
104 
105 	mutex_unlock(&net->mctp.neigh_lock);
106 	return dropped ? 0 : -ENOENT;
107 }
108 
109 static const struct nla_policy nd_mctp_policy[NDA_MAX + 1] = {
110 	[NDA_DST]		= { .type = NLA_U8 },
111 	[NDA_LLADDR]		= { .type = NLA_BINARY, .len = MAX_ADDR_LEN },
112 };
113 
114 static int mctp_rtm_newneigh(struct sk_buff *skb, struct nlmsghdr *nlh,
115 			     struct netlink_ext_ack *extack)
116 {
117 	struct net *net = sock_net(skb->sk);
118 	struct net_device *dev;
119 	struct mctp_dev *mdev;
120 	struct ndmsg *ndm;
121 	struct nlattr *tb[NDA_MAX + 1];
122 	int rc;
123 	mctp_eid_t eid;
124 	void *lladdr;
125 	int lladdr_len;
126 
127 	rc = nlmsg_parse(nlh, sizeof(*ndm), tb, NDA_MAX, nd_mctp_policy,
128 			 extack);
129 	if (rc < 0) {
130 		NL_SET_ERR_MSG(extack, "lladdr too large?");
131 		return rc;
132 	}
133 
134 	if (!tb[NDA_DST]) {
135 		NL_SET_ERR_MSG(extack, "Neighbour EID must be specified");
136 		return -EINVAL;
137 	}
138 
139 	if (!tb[NDA_LLADDR]) {
140 		NL_SET_ERR_MSG(extack, "Neighbour lladdr must be specified");
141 		return -EINVAL;
142 	}
143 
144 	eid = nla_get_u8(tb[NDA_DST]);
145 	if (!mctp_address_ok(eid)) {
146 		NL_SET_ERR_MSG(extack, "Invalid neighbour EID");
147 		return -EINVAL;
148 	}
149 
150 	lladdr = nla_data(tb[NDA_LLADDR]);
151 	lladdr_len = nla_len(tb[NDA_LLADDR]);
152 
153 	ndm = nlmsg_data(nlh);
154 
155 	dev = __dev_get_by_index(net, ndm->ndm_ifindex);
156 	if (!dev)
157 		return -ENODEV;
158 
159 	mdev = mctp_dev_get_rtnl(dev);
160 	if (!mdev)
161 		return -ENODEV;
162 
163 	if (lladdr_len != dev->addr_len) {
164 		NL_SET_ERR_MSG(extack, "Wrong lladdr length");
165 		return -EINVAL;
166 	}
167 
168 	return mctp_neigh_add(mdev, eid, MCTP_NEIGH_STATIC,
169 			lladdr_len, lladdr);
170 }
171 
172 static int mctp_rtm_delneigh(struct sk_buff *skb, struct nlmsghdr *nlh,
173 			     struct netlink_ext_ack *extack)
174 {
175 	struct net *net = sock_net(skb->sk);
176 	struct nlattr *tb[NDA_MAX + 1];
177 	struct net_device *dev;
178 	struct mctp_dev *mdev;
179 	struct ndmsg *ndm;
180 	int rc;
181 	mctp_eid_t eid;
182 
183 	rc = nlmsg_parse(nlh, sizeof(*ndm), tb, NDA_MAX, nd_mctp_policy,
184 			 extack);
185 	if (rc < 0) {
186 		NL_SET_ERR_MSG(extack, "incorrect format");
187 		return rc;
188 	}
189 
190 	if (!tb[NDA_DST]) {
191 		NL_SET_ERR_MSG(extack, "Neighbour EID must be specified");
192 		return -EINVAL;
193 	}
194 	eid = nla_get_u8(tb[NDA_DST]);
195 
196 	ndm = nlmsg_data(nlh);
197 	dev = __dev_get_by_index(net, ndm->ndm_ifindex);
198 	if (!dev)
199 		return -ENODEV;
200 
201 	mdev = mctp_dev_get_rtnl(dev);
202 	if (!mdev)
203 		return -ENODEV;
204 
205 	return mctp_neigh_remove(mdev, eid);
206 }
207 
208 static int mctp_fill_neigh(struct sk_buff *skb, u32 portid, u32 seq, int event,
209 			   unsigned int flags, struct mctp_neigh *neigh)
210 {
211 	struct net_device *dev = neigh->dev->dev;
212 	struct nlmsghdr *nlh;
213 	struct ndmsg *hdr;
214 
215 	nlh = nlmsg_put(skb, portid, seq, event, sizeof(*hdr), flags);
216 	if (!nlh)
217 		return -EMSGSIZE;
218 
219 	hdr = nlmsg_data(nlh);
220 	hdr->ndm_family = AF_MCTP;
221 	hdr->ndm_ifindex = dev->ifindex;
222 	hdr->ndm_state = 0; // TODO other state bits?
223 	if (neigh->source == MCTP_NEIGH_STATIC)
224 		hdr->ndm_state |= NUD_PERMANENT;
225 	hdr->ndm_flags = 0;
226 	hdr->ndm_type = RTN_UNICAST; // TODO: is loopback RTN_LOCAL?
227 
228 	if (nla_put_u8(skb, NDA_DST, neigh->eid))
229 		goto cancel;
230 
231 	if (nla_put(skb, NDA_LLADDR, dev->addr_len, neigh->ha))
232 		goto cancel;
233 
234 	nlmsg_end(skb, nlh);
235 
236 	return 0;
237 cancel:
238 	nlmsg_cancel(skb, nlh);
239 	return -EMSGSIZE;
240 }
241 
242 static int mctp_rtm_getneigh(struct sk_buff *skb, struct netlink_callback *cb)
243 {
244 	struct net *net = sock_net(skb->sk);
245 	int rc, idx, req_ifindex;
246 	struct mctp_neigh *neigh;
247 	struct ndmsg *ndmsg;
248 	struct {
249 		int idx;
250 	} *cbctx = (void *)cb->ctx;
251 
252 	ndmsg = nlmsg_data(cb->nlh);
253 	req_ifindex = ndmsg->ndm_ifindex;
254 
255 	idx = 0;
256 	rcu_read_lock();
257 	list_for_each_entry_rcu(neigh, &net->mctp.neighbours, list) {
258 		if (idx < cbctx->idx)
259 			goto cont;
260 
261 		rc = 0;
262 		if (req_ifindex == 0 || req_ifindex == neigh->dev->dev->ifindex)
263 			rc = mctp_fill_neigh(skb, NETLINK_CB(cb->skb).portid,
264 					     cb->nlh->nlmsg_seq,
265 					     RTM_NEWNEIGH, NLM_F_MULTI, neigh);
266 
267 		if (rc)
268 			break;
269 cont:
270 		idx++;
271 	}
272 	rcu_read_unlock();
273 
274 	cbctx->idx = idx;
275 	return skb->len;
276 }
277 
278 int mctp_neigh_lookup(struct mctp_dev *mdev, mctp_eid_t eid, void *ret_hwaddr)
279 {
280 	struct net *net = dev_net(mdev->dev);
281 	struct mctp_neigh *neigh;
282 	int rc = -EHOSTUNREACH; // TODO: or ENOENT?
283 
284 	rcu_read_lock();
285 	list_for_each_entry_rcu(neigh, &net->mctp.neighbours, list) {
286 		if (mdev == neigh->dev && eid == neigh->eid) {
287 			if (ret_hwaddr)
288 				memcpy(ret_hwaddr, neigh->ha,
289 				       sizeof(neigh->ha));
290 			rc = 0;
291 			break;
292 		}
293 	}
294 	rcu_read_unlock();
295 	return rc;
296 }
297 
298 /* namespace registration */
299 static int __net_init mctp_neigh_net_init(struct net *net)
300 {
301 	struct netns_mctp *ns = &net->mctp;
302 
303 	INIT_LIST_HEAD(&ns->neighbours);
304 	mutex_init(&ns->neigh_lock);
305 	return 0;
306 }
307 
308 static void __net_exit mctp_neigh_net_exit(struct net *net)
309 {
310 	struct netns_mctp *ns = &net->mctp;
311 	struct mctp_neigh *neigh;
312 
313 	list_for_each_entry(neigh, &ns->neighbours, list)
314 		call_rcu(&neigh->rcu, __mctp_neigh_free);
315 }
316 
317 /* net namespace implementation */
318 
319 static struct pernet_operations mctp_net_ops = {
320 	.init = mctp_neigh_net_init,
321 	.exit = mctp_neigh_net_exit,
322 };
323 
324 int __init mctp_neigh_init(void)
325 {
326 	rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_NEWNEIGH,
327 			     mctp_rtm_newneigh, NULL, 0);
328 	rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_DELNEIGH,
329 			     mctp_rtm_delneigh, NULL, 0);
330 	rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_GETNEIGH,
331 			     NULL, mctp_rtm_getneigh, 0);
332 
333 	return register_pernet_subsys(&mctp_net_ops);
334 }
335 
336 void __exit mctp_neigh_exit(void)
337 {
338 	unregister_pernet_subsys(&mctp_net_ops);
339 	rtnl_unregister(PF_MCTP, RTM_GETNEIGH);
340 	rtnl_unregister(PF_MCTP, RTM_DELNEIGH);
341 	rtnl_unregister(PF_MCTP, RTM_NEWNEIGH);
342 }
343