xref: /linux/tools/perf/util/comm.c (revision 2330437da0994321020777c605a2a8cb0ecb7001)
1 // SPDX-License-Identifier: GPL-2.0
2 #include "comm.h"
3 #include <errno.h>
4 #include <string.h>
5 #include <internal/rc_check.h>
6 #include <linux/refcount.h>
7 #include <linux/zalloc.h>
8 #include <tools/libc_compat.h> // reallocarray
9 
10 #include "rwsem.h"
11 
12 DECLARE_RC_STRUCT(comm_str) {
13 	refcount_t refcnt;
14 	char str[];
15 };
16 
17 static struct comm_strs {
18 	struct rw_semaphore lock;
19 	struct comm_str **strs;
20 	int num_strs;
21 	int capacity;
22 } _comm_strs;
23 
24 static void comm_strs__remove_if_last(struct comm_str *cs);
25 
26 static void comm_strs__init(void)
27 	NO_THREAD_SAFETY_ANALYSIS /* Inherently single threaded due to pthread_once. */
28 {
29 	init_rwsem(&_comm_strs.lock);
30 	_comm_strs.capacity = 16;
31 	_comm_strs.num_strs = 0;
32 	_comm_strs.strs = calloc(16, sizeof(*_comm_strs.strs));
33 }
34 
35 static struct comm_strs *comm_strs__get(void)
36 {
37 	static pthread_once_t comm_strs_type_once = PTHREAD_ONCE_INIT;
38 
39 	pthread_once(&comm_strs_type_once, comm_strs__init);
40 
41 	return &_comm_strs;
42 }
43 
44 static refcount_t *comm_str__refcnt(struct comm_str *cs)
45 {
46 	return &RC_CHK_ACCESS(cs)->refcnt;
47 }
48 
49 static const char *comm_str__str(const struct comm_str *cs)
50 {
51 	return &RC_CHK_ACCESS(cs)->str[0];
52 }
53 
/*
 * Take a reference on @cs.
 *
 * RC_CHK_GET() produces the handle that is returned; the reference count is
 * only bumped when that handle is non-NULL.
 */
static struct comm_str *comm_str__get(struct comm_str *cs)
{
	struct comm_str *ref;

	if (!RC_CHK_GET(ref, cs))
		return ref;

	refcount_inc_not_zero(comm_str__refcnt(cs));
	return ref;
}
63 
/*
 * Drop a reference on @cs; a NULL @cs is ignored.
 *
 * When the count reaches zero the string is freed. Otherwise, if exactly one
 * reference appears to remain, comm_strs__remove_if_last() is asked to drop
 * the table's reference as well, after which this handle is released.
 */
static void comm_str__put(struct comm_str *cs)
{
	if (!cs)
		return;

	if (refcount_dec_and_test(comm_str__refcnt(cs))) {
		/* That was the final reference. */
		RC_CHK_FREE(cs);
		return;
	}

	if (refcount_read(comm_str__refcnt(cs)) == 1)
		comm_strs__remove_if_last(cs);

	RC_CHK_PUT(cs);
}
78 
79 static struct comm_str *comm_str__new(const char *str)
80 {
81 	struct comm_str *result = NULL;
82 	RC_STRUCT(comm_str) *cs;
83 
84 	cs = malloc(sizeof(*cs) + strlen(str) + 1);
85 	if (ADD_RC_CHK(result, cs)) {
86 		refcount_set(comm_str__refcnt(result), 1);
87 		strcpy(&cs->str[0], str);
88 	}
89 	return result;
90 }
91 
/*
 * bsearch() comparator: @_key is a plain C string, @_member points at an
 * element of the table's array, i.e. at a struct comm_str pointer.
 */
static int comm_str__search(const void *_key, const void *_member)
{
	const struct comm_str * const *slot = _member;

	return strcmp((const char *)_key, comm_str__str(*slot));
}
99 
100 static void comm_strs__remove_if_last(struct comm_str *cs)
101 {
102 	struct comm_strs *comm_strs = comm_strs__get();
103 
104 	down_write(&comm_strs->lock);
105 	/*
106 	 * Are there only references from the array, if so remove the array
107 	 * reference under the write lock so that we don't race with findnew.
108 	 */
109 	if (refcount_read(comm_str__refcnt(cs)) == 1) {
110 		struct comm_str **entry;
111 
112 		entry = bsearch(comm_str__str(cs), comm_strs->strs, comm_strs->num_strs,
113 				sizeof(struct comm_str *), comm_str__search);
114 		comm_str__put(*entry);
115 		for (int i = entry - comm_strs->strs; i < comm_strs->num_strs - 1; i++)
116 			comm_strs->strs[i] = comm_strs->strs[i + 1];
117 		comm_strs->num_strs--;
118 	}
119 	up_write(&comm_strs->lock);
120 }
121 
122 static struct comm_str *__comm_strs__find(struct comm_strs *comm_strs, const char *str)
123 	SHARED_LOCKS_REQUIRED(comm_strs->lock)
124 {
125 	struct comm_str **result;
126 
127 	result = bsearch(str, comm_strs->strs, comm_strs->num_strs, sizeof(struct comm_str *),
128 			 comm_str__search);
129 
130 	if (!result)
131 		return NULL;
132 
133 	return comm_str__get(*result);
134 }
135 
136 static struct comm_str *comm_strs__findnew(const char *str)
137 {
138 	struct comm_strs *comm_strs = comm_strs__get();
139 	struct comm_str *result;
140 
141 	if (!comm_strs)
142 		return NULL;
143 
144 	down_read(&comm_strs->lock);
145 	result = __comm_strs__find(comm_strs, str);
146 	up_read(&comm_strs->lock);
147 	if (result)
148 		return result;
149 
150 	down_write(&comm_strs->lock);
151 	result = __comm_strs__find(comm_strs, str);
152 	if (!result) {
153 		if (comm_strs->num_strs == comm_strs->capacity) {
154 			struct comm_str **tmp;
155 
156 			tmp = reallocarray(comm_strs->strs,
157 					   comm_strs->capacity + 16,
158 					   sizeof(*comm_strs->strs));
159 			if (!tmp) {
160 				up_write(&comm_strs->lock);
161 				return NULL;
162 			}
163 			comm_strs->strs = tmp;
164 			comm_strs->capacity += 16;
165 		}
166 		result = comm_str__new(str);
167 		if (result) {
168 			int low = 0, high = comm_strs->num_strs - 1;
169 			int insert = comm_strs->num_strs; /* Default to inserting at the end. */
170 
171 			while (low <= high) {
172 				int mid = low + (high - low) / 2;
173 				int cmp = strcmp(comm_str__str(comm_strs->strs[mid]), str);
174 
175 				if (cmp < 0) {
176 					low = mid + 1;
177 				} else {
178 					high = mid - 1;
179 					insert = mid;
180 				}
181 			}
182 			memmove(&comm_strs->strs[insert + 1], &comm_strs->strs[insert],
183 				(comm_strs->num_strs - insert) * sizeof(struct comm_str *));
184 			comm_strs->num_strs++;
185 			comm_strs->strs[insert] = result;
186 		}
187 	}
188 	up_write(&comm_strs->lock);
189 	return comm_str__get(result);
190 }
191 
192 struct comm *comm__new(const char *str, u64 timestamp, bool exec)
193 {
194 	struct comm *comm = zalloc(sizeof(*comm));
195 
196 	if (!comm)
197 		return NULL;
198 
199 	comm->start = timestamp;
200 	comm->exec = exec;
201 
202 	comm->comm_str = comm_strs__findnew(str);
203 	if (!comm->comm_str) {
204 		free(comm);
205 		return NULL;
206 	}
207 
208 	return comm;
209 }
210 
211 int comm__override(struct comm *comm, const char *str, u64 timestamp, bool exec)
212 {
213 	struct comm_str *new, *old = comm->comm_str;
214 
215 	new = comm_strs__findnew(str);
216 	if (!new)
217 		return -ENOMEM;
218 
219 	comm_str__put(old);
220 	comm->comm_str = new;
221 	comm->start = timestamp;
222 	if (exec)
223 		comm->exec = true;
224 
225 	return 0;
226 }
227 
228 void comm__free(struct comm *comm)
229 {
230 	comm_str__put(comm->comm_str);
231 	free(comm);
232 }
233 
234 const char *comm__str(const struct comm *comm)
235 {
236 	return comm_str__str(comm->comm_str);
237 }
238