1 // SPDX-License-Identifier: GPL-2.0 2 #include "comm.h" 3 #include <errno.h> 4 #include <string.h> 5 #include <internal/rc_check.h> 6 #include <linux/refcount.h> 7 #include <linux/zalloc.h> 8 #include <tools/libc_compat.h> // reallocarray 9 10 #include "rwsem.h" 11 12 DECLARE_RC_STRUCT(comm_str) { 13 refcount_t refcnt; 14 char str[]; 15 }; 16 17 static struct comm_strs { 18 struct rw_semaphore lock; 19 struct comm_str **strs; 20 int num_strs; 21 int capacity; 22 } _comm_strs; 23 24 static void comm_strs__remove_if_last(struct comm_str *cs); 25 26 static void comm_strs__init(void) 27 NO_THREAD_SAFETY_ANALYSIS /* Inherently single threaded due to pthread_once. */ 28 { 29 init_rwsem(&_comm_strs.lock); 30 _comm_strs.capacity = 16; 31 _comm_strs.num_strs = 0; 32 _comm_strs.strs = calloc(16, sizeof(*_comm_strs.strs)); 33 } 34 35 static struct comm_strs *comm_strs__get(void) 36 { 37 static pthread_once_t comm_strs_type_once = PTHREAD_ONCE_INIT; 38 39 pthread_once(&comm_strs_type_once, comm_strs__init); 40 41 return &_comm_strs; 42 } 43 44 static refcount_t *comm_str__refcnt(struct comm_str *cs) 45 { 46 return &RC_CHK_ACCESS(cs)->refcnt; 47 } 48 49 static const char *comm_str__str(const struct comm_str *cs) 50 { 51 return &RC_CHK_ACCESS(cs)->str[0]; 52 } 53 54 static struct comm_str *comm_str__get(struct comm_str *cs) 55 { 56 struct comm_str *result; 57 58 if (RC_CHK_GET(result, cs)) 59 refcount_inc_not_zero(comm_str__refcnt(cs)); 60 61 return result; 62 } 63 64 static void comm_str__put(struct comm_str *cs) 65 { 66 if (!cs) 67 return; 68 69 if (refcount_dec_and_test(comm_str__refcnt(cs))) { 70 RC_CHK_FREE(cs); 71 } else { 72 if (refcount_read(comm_str__refcnt(cs)) == 1) 73 comm_strs__remove_if_last(cs); 74 75 RC_CHK_PUT(cs); 76 } 77 } 78 79 static struct comm_str *comm_str__new(const char *str) 80 { 81 struct comm_str *result = NULL; 82 RC_STRUCT(comm_str) *cs; 83 84 cs = malloc(sizeof(*cs) + strlen(str) + 1); 85 if (ADD_RC_CHK(result, cs)) { 86 refcount_set(comm_str__refcnt(result), 1); 87 strcpy(&cs->str[0], str); 88 } 89 return result; 90 } 91 92 static int comm_str__search(const void *_key, const void *_member) 93 { 94 const char *key = _key; 95 const struct comm_str *member = *(const struct comm_str * const *)_member; 96 97 return strcmp(key, comm_str__str(member)); 98 } 99 100 static void comm_strs__remove_if_last(struct comm_str *cs) 101 { 102 struct comm_strs *comm_strs = comm_strs__get(); 103 104 down_write(&comm_strs->lock); 105 /* 106 * Are there only references from the array, if so remove the array 107 * reference under the write lock so that we don't race with findnew. 108 */ 109 if (refcount_read(comm_str__refcnt(cs)) == 1) { 110 struct comm_str **entry; 111 112 entry = bsearch(comm_str__str(cs), comm_strs->strs, comm_strs->num_strs, 113 sizeof(struct comm_str *), comm_str__search); 114 comm_str__put(*entry); 115 for (int i = entry - comm_strs->strs; i < comm_strs->num_strs - 1; i++) 116 comm_strs->strs[i] = comm_strs->strs[i + 1]; 117 comm_strs->num_strs--; 118 } 119 up_write(&comm_strs->lock); 120 } 121 122 static struct comm_str *__comm_strs__find(struct comm_strs *comm_strs, const char *str) 123 SHARED_LOCKS_REQUIRED(comm_strs->lock) 124 { 125 struct comm_str **result; 126 127 result = bsearch(str, comm_strs->strs, comm_strs->num_strs, sizeof(struct comm_str *), 128 comm_str__search); 129 130 if (!result) 131 return NULL; 132 133 return comm_str__get(*result); 134 } 135 136 static struct comm_str *comm_strs__findnew(const char *str) 137 { 138 struct comm_strs *comm_strs = comm_strs__get(); 139 struct comm_str *result; 140 141 if (!comm_strs) 142 return NULL; 143 144 down_read(&comm_strs->lock); 145 result = __comm_strs__find(comm_strs, str); 146 up_read(&comm_strs->lock); 147 if (result) 148 return result; 149 150 down_write(&comm_strs->lock); 151 result = __comm_strs__find(comm_strs, str); 152 if (!result) { 153 if (comm_strs->num_strs == comm_strs->capacity) { 154 struct comm_str **tmp; 155 156 tmp = reallocarray(comm_strs->strs, 157 comm_strs->capacity + 16, 158 sizeof(*comm_strs->strs)); 159 if (!tmp) { 160 up_write(&comm_strs->lock); 161 return NULL; 162 } 163 comm_strs->strs = tmp; 164 comm_strs->capacity += 16; 165 } 166 result = comm_str__new(str); 167 if (result) { 168 int low = 0, high = comm_strs->num_strs - 1; 169 int insert = comm_strs->num_strs; /* Default to inserting at the end. */ 170 171 while (low <= high) { 172 int mid = low + (high - low) / 2; 173 int cmp = strcmp(comm_str__str(comm_strs->strs[mid]), str); 174 175 if (cmp < 0) { 176 low = mid + 1; 177 } else { 178 high = mid - 1; 179 insert = mid; 180 } 181 } 182 memmove(&comm_strs->strs[insert + 1], &comm_strs->strs[insert], 183 (comm_strs->num_strs - insert) * sizeof(struct comm_str *)); 184 comm_strs->num_strs++; 185 comm_strs->strs[insert] = result; 186 } 187 } 188 up_write(&comm_strs->lock); 189 return comm_str__get(result); 190 } 191 192 struct comm *comm__new(const char *str, u64 timestamp, bool exec) 193 { 194 struct comm *comm = zalloc(sizeof(*comm)); 195 196 if (!comm) 197 return NULL; 198 199 comm->start = timestamp; 200 comm->exec = exec; 201 202 comm->comm_str = comm_strs__findnew(str); 203 if (!comm->comm_str) { 204 free(comm); 205 return NULL; 206 } 207 208 return comm; 209 } 210 211 int comm__override(struct comm *comm, const char *str, u64 timestamp, bool exec) 212 { 213 struct comm_str *new, *old = comm->comm_str; 214 215 new = comm_strs__findnew(str); 216 if (!new) 217 return -ENOMEM; 218 219 comm_str__put(old); 220 comm->comm_str = new; 221 comm->start = timestamp; 222 if (exec) 223 comm->exec = true; 224 225 return 0; 226 } 227 228 void comm__free(struct comm *comm) 229 { 230 comm_str__put(comm->comm_str); 231 free(comm); 232 } 233 234 const char *comm__str(const struct comm *comm) 235 { 236 return comm_str__str(comm->comm_str); 237 } 238