xref: /linux/net/mctp/neigh.c (revision c717993dd76a1049093af5c262e751d901b8da10)
1  // SPDX-License-Identifier: GPL-2.0
2  /*
3   * Management Component Transport Protocol (MCTP) - routing
4   * implementation.
5   *
6   * This is currently based on a simple routing table, with no dst cache. The
7   * number of routes should stay fairly small, so the lookup cost is small.
8   *
9   * Copyright (c) 2021 Code Construct
10   * Copyright (c) 2021 Google
11   */
12  
13  #include <linux/idr.h>
14  #include <linux/mctp.h>
15  #include <linux/netdevice.h>
16  #include <linux/rtnetlink.h>
17  #include <linux/skbuff.h>
18  
19  #include <net/mctp.h>
20  #include <net/mctpdevice.h>
21  #include <net/netlink.h>
22  #include <net/sock.h>
23  
24  static int mctp_neigh_add(struct mctp_dev *mdev, mctp_eid_t eid,
25  			  enum mctp_neigh_source source,
26  			  size_t lladdr_len, const void *lladdr)
27  {
28  	struct net *net = dev_net(mdev->dev);
29  	struct mctp_neigh *neigh;
30  	int rc;
31  
32  	mutex_lock(&net->mctp.neigh_lock);
33  	if (mctp_neigh_lookup(mdev, eid, NULL) == 0) {
34  		rc = -EEXIST;
35  		goto out;
36  	}
37  
38  	if (lladdr_len > sizeof(neigh->ha)) {
39  		rc = -EINVAL;
40  		goto out;
41  	}
42  
43  	neigh = kzalloc(sizeof(*neigh), GFP_KERNEL);
44  	if (!neigh) {
45  		rc = -ENOMEM;
46  		goto out;
47  	}
48  	INIT_LIST_HEAD(&neigh->list);
49  	neigh->dev = mdev;
50  	mctp_dev_hold(neigh->dev);
51  	neigh->eid = eid;
52  	neigh->source = source;
53  	memcpy(neigh->ha, lladdr, lladdr_len);
54  
55  	list_add_rcu(&neigh->list, &net->mctp.neighbours);
56  	rc = 0;
57  out:
58  	mutex_unlock(&net->mctp.neigh_lock);
59  	return rc;
60  }
61  
62  static void __mctp_neigh_free(struct rcu_head *rcu)
63  {
64  	struct mctp_neigh *neigh = container_of(rcu, struct mctp_neigh, rcu);
65  
66  	mctp_dev_put(neigh->dev);
67  	kfree(neigh);
68  }
69  
70  /* Removes all neighbour entries referring to a device */
71  void mctp_neigh_remove_dev(struct mctp_dev *mdev)
72  {
73  	struct net *net = dev_net(mdev->dev);
74  	struct mctp_neigh *neigh, *tmp;
75  
76  	mutex_lock(&net->mctp.neigh_lock);
77  	list_for_each_entry_safe(neigh, tmp, &net->mctp.neighbours, list) {
78  		if (neigh->dev == mdev) {
79  			list_del_rcu(&neigh->list);
80  			/* TODO: immediate RTM_DELNEIGH */
81  			call_rcu(&neigh->rcu, __mctp_neigh_free);
82  		}
83  	}
84  
85  	mutex_unlock(&net->mctp.neigh_lock);
86  }
87  
88  static int mctp_neigh_remove(struct mctp_dev *mdev, mctp_eid_t eid,
89  			     enum mctp_neigh_source source)
90  {
91  	struct net *net = dev_net(mdev->dev);
92  	struct mctp_neigh *neigh, *tmp;
93  	bool dropped = false;
94  
95  	mutex_lock(&net->mctp.neigh_lock);
96  	list_for_each_entry_safe(neigh, tmp, &net->mctp.neighbours, list) {
97  		if (neigh->dev == mdev && neigh->eid == eid &&
98  		    neigh->source == source) {
99  			list_del_rcu(&neigh->list);
100  			/* TODO: immediate RTM_DELNEIGH */
101  			call_rcu(&neigh->rcu, __mctp_neigh_free);
102  			dropped = true;
103  		}
104  	}
105  
106  	mutex_unlock(&net->mctp.neigh_lock);
107  	return dropped ? 0 : -ENOENT;
108  }
109  
110  static const struct nla_policy nd_mctp_policy[NDA_MAX + 1] = {
111  	[NDA_DST]		= { .type = NLA_U8 },
112  	[NDA_LLADDR]		= { .type = NLA_BINARY, .len = MAX_ADDR_LEN },
113  };
114  
115  static int mctp_rtm_newneigh(struct sk_buff *skb, struct nlmsghdr *nlh,
116  			     struct netlink_ext_ack *extack)
117  {
118  	struct net *net = sock_net(skb->sk);
119  	struct net_device *dev;
120  	struct mctp_dev *mdev;
121  	struct ndmsg *ndm;
122  	struct nlattr *tb[NDA_MAX + 1];
123  	int rc;
124  	mctp_eid_t eid;
125  	void *lladdr;
126  	int lladdr_len;
127  
128  	rc = nlmsg_parse(nlh, sizeof(*ndm), tb, NDA_MAX, nd_mctp_policy,
129  			 extack);
130  	if (rc < 0) {
131  		NL_SET_ERR_MSG(extack, "lladdr too large?");
132  		return rc;
133  	}
134  
135  	if (!tb[NDA_DST]) {
136  		NL_SET_ERR_MSG(extack, "Neighbour EID must be specified");
137  		return -EINVAL;
138  	}
139  
140  	if (!tb[NDA_LLADDR]) {
141  		NL_SET_ERR_MSG(extack, "Neighbour lladdr must be specified");
142  		return -EINVAL;
143  	}
144  
145  	eid = nla_get_u8(tb[NDA_DST]);
146  	if (!mctp_address_ok(eid)) {
147  		NL_SET_ERR_MSG(extack, "Invalid neighbour EID");
148  		return -EINVAL;
149  	}
150  
151  	lladdr = nla_data(tb[NDA_LLADDR]);
152  	lladdr_len = nla_len(tb[NDA_LLADDR]);
153  
154  	ndm = nlmsg_data(nlh);
155  
156  	dev = __dev_get_by_index(net, ndm->ndm_ifindex);
157  	if (!dev)
158  		return -ENODEV;
159  
160  	mdev = mctp_dev_get_rtnl(dev);
161  	if (!mdev)
162  		return -ENODEV;
163  
164  	if (lladdr_len != dev->addr_len) {
165  		NL_SET_ERR_MSG(extack, "Wrong lladdr length");
166  		return -EINVAL;
167  	}
168  
169  	return mctp_neigh_add(mdev, eid, MCTP_NEIGH_STATIC,
170  			lladdr_len, lladdr);
171  }
172  
173  static int mctp_rtm_delneigh(struct sk_buff *skb, struct nlmsghdr *nlh,
174  			     struct netlink_ext_ack *extack)
175  {
176  	struct net *net = sock_net(skb->sk);
177  	struct nlattr *tb[NDA_MAX + 1];
178  	struct net_device *dev;
179  	struct mctp_dev *mdev;
180  	struct ndmsg *ndm;
181  	int rc;
182  	mctp_eid_t eid;
183  
184  	rc = nlmsg_parse(nlh, sizeof(*ndm), tb, NDA_MAX, nd_mctp_policy,
185  			 extack);
186  	if (rc < 0) {
187  		NL_SET_ERR_MSG(extack, "incorrect format");
188  		return rc;
189  	}
190  
191  	if (!tb[NDA_DST]) {
192  		NL_SET_ERR_MSG(extack, "Neighbour EID must be specified");
193  		return -EINVAL;
194  	}
195  	eid = nla_get_u8(tb[NDA_DST]);
196  
197  	ndm = nlmsg_data(nlh);
198  	dev = __dev_get_by_index(net, ndm->ndm_ifindex);
199  	if (!dev)
200  		return -ENODEV;
201  
202  	mdev = mctp_dev_get_rtnl(dev);
203  	if (!mdev)
204  		return -ENODEV;
205  
206  	return mctp_neigh_remove(mdev, eid, MCTP_NEIGH_STATIC);
207  }
208  
209  static int mctp_fill_neigh(struct sk_buff *skb, u32 portid, u32 seq, int event,
210  			   unsigned int flags, struct mctp_neigh *neigh)
211  {
212  	struct net_device *dev = neigh->dev->dev;
213  	struct nlmsghdr *nlh;
214  	struct ndmsg *hdr;
215  
216  	nlh = nlmsg_put(skb, portid, seq, event, sizeof(*hdr), flags);
217  	if (!nlh)
218  		return -EMSGSIZE;
219  
220  	hdr = nlmsg_data(nlh);
221  	hdr->ndm_family = AF_MCTP;
222  	hdr->ndm_ifindex = dev->ifindex;
223  	hdr->ndm_state = 0; // TODO other state bits?
224  	if (neigh->source == MCTP_NEIGH_STATIC)
225  		hdr->ndm_state |= NUD_PERMANENT;
226  	hdr->ndm_flags = 0;
227  	hdr->ndm_type = RTN_UNICAST; // TODO: is loopback RTN_LOCAL?
228  
229  	if (nla_put_u8(skb, NDA_DST, neigh->eid))
230  		goto cancel;
231  
232  	if (nla_put(skb, NDA_LLADDR, dev->addr_len, neigh->ha))
233  		goto cancel;
234  
235  	nlmsg_end(skb, nlh);
236  
237  	return 0;
238  cancel:
239  	nlmsg_cancel(skb, nlh);
240  	return -EMSGSIZE;
241  }
242  
243  static int mctp_rtm_getneigh(struct sk_buff *skb, struct netlink_callback *cb)
244  {
245  	struct net *net = sock_net(skb->sk);
246  	int rc, idx, req_ifindex;
247  	struct mctp_neigh *neigh;
248  	struct ndmsg *ndmsg;
249  	struct {
250  		int idx;
251  	} *cbctx = (void *)cb->ctx;
252  
253  	ndmsg = nlmsg_data(cb->nlh);
254  	req_ifindex = ndmsg->ndm_ifindex;
255  
256  	idx = 0;
257  	rcu_read_lock();
258  	list_for_each_entry_rcu(neigh, &net->mctp.neighbours, list) {
259  		if (idx < cbctx->idx)
260  			goto cont;
261  
262  		rc = 0;
263  		if (req_ifindex == 0 || req_ifindex == neigh->dev->dev->ifindex)
264  			rc = mctp_fill_neigh(skb, NETLINK_CB(cb->skb).portid,
265  					     cb->nlh->nlmsg_seq,
266  					     RTM_NEWNEIGH, NLM_F_MULTI, neigh);
267  
268  		if (rc)
269  			break;
270  cont:
271  		idx++;
272  	}
273  	rcu_read_unlock();
274  
275  	cbctx->idx = idx;
276  	return skb->len;
277  }
278  
279  int mctp_neigh_lookup(struct mctp_dev *mdev, mctp_eid_t eid, void *ret_hwaddr)
280  {
281  	struct net *net = dev_net(mdev->dev);
282  	struct mctp_neigh *neigh;
283  	int rc = -EHOSTUNREACH; // TODO: or ENOENT?
284  
285  	rcu_read_lock();
286  	list_for_each_entry_rcu(neigh, &net->mctp.neighbours, list) {
287  		if (mdev == neigh->dev && eid == neigh->eid) {
288  			if (ret_hwaddr)
289  				memcpy(ret_hwaddr, neigh->ha,
290  				       sizeof(neigh->ha));
291  			rc = 0;
292  			break;
293  		}
294  	}
295  	rcu_read_unlock();
296  	return rc;
297  }
298  
299  /* namespace registration */
300  static int __net_init mctp_neigh_net_init(struct net *net)
301  {
302  	struct netns_mctp *ns = &net->mctp;
303  
304  	INIT_LIST_HEAD(&ns->neighbours);
305  	mutex_init(&ns->neigh_lock);
306  	return 0;
307  }
308  
309  static void __net_exit mctp_neigh_net_exit(struct net *net)
310  {
311  	struct netns_mctp *ns = &net->mctp;
312  	struct mctp_neigh *neigh;
313  
314  	list_for_each_entry(neigh, &ns->neighbours, list)
315  		call_rcu(&neigh->rcu, __mctp_neigh_free);
316  }
317  
318  /* net namespace implementation */
319  
320  static struct pernet_operations mctp_net_ops = {
321  	.init = mctp_neigh_net_init,
322  	.exit = mctp_neigh_net_exit,
323  };
324  
325  int __init mctp_neigh_init(void)
326  {
327  	rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_NEWNEIGH,
328  			     mctp_rtm_newneigh, NULL, 0);
329  	rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_DELNEIGH,
330  			     mctp_rtm_delneigh, NULL, 0);
331  	rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_GETNEIGH,
332  			     NULL, mctp_rtm_getneigh, 0);
333  
334  	return register_pernet_subsys(&mctp_net_ops);
335  }
336  
337  void __exit mctp_neigh_exit(void)
338  {
339  	unregister_pernet_subsys(&mctp_net_ops);
340  	rtnl_unregister(PF_MCTP, RTM_GETNEIGH);
341  	rtnl_unregister(PF_MCTP, RTM_DELNEIGH);
342  	rtnl_unregister(PF_MCTP, RTM_NEWNEIGH);
343  }
344