1 // SPDX-License-Identifier: GPL-2.0 2 #include "comm.h" 3 #include <errno.h> 4 #include <string.h> 5 #include <internal/rc_check.h> 6 #include <linux/refcount.h> 7 #include <linux/zalloc.h> 8 #include "rwsem.h" 9 10 DECLARE_RC_STRUCT(comm_str) { 11 refcount_t refcnt; 12 char str[]; 13 }; 14 15 static struct comm_strs { 16 struct rw_semaphore lock; 17 struct comm_str **strs; 18 int num_strs; 19 int capacity; 20 } _comm_strs; 21 22 static void comm_strs__remove_if_last(struct comm_str *cs); 23 24 static void comm_strs__init(void) 25 { 26 init_rwsem(&_comm_strs.lock); 27 _comm_strs.capacity = 16; 28 _comm_strs.num_strs = 0; 29 _comm_strs.strs = calloc(16, sizeof(*_comm_strs.strs)); 30 } 31 32 static struct comm_strs *comm_strs__get(void) 33 { 34 static pthread_once_t comm_strs_type_once = PTHREAD_ONCE_INIT; 35 36 pthread_once(&comm_strs_type_once, comm_strs__init); 37 38 return &_comm_strs; 39 } 40 41 static refcount_t *comm_str__refcnt(struct comm_str *cs) 42 { 43 return &RC_CHK_ACCESS(cs)->refcnt; 44 } 45 46 static const char *comm_str__str(const struct comm_str *cs) 47 { 48 return &RC_CHK_ACCESS(cs)->str[0]; 49 } 50 51 static struct comm_str *comm_str__get(struct comm_str *cs) 52 { 53 struct comm_str *result; 54 55 if (RC_CHK_GET(result, cs)) 56 refcount_inc_not_zero(comm_str__refcnt(cs)); 57 58 return result; 59 } 60 61 static void comm_str__put(struct comm_str *cs) 62 { 63 if (!cs) 64 return; 65 66 if (refcount_dec_and_test(comm_str__refcnt(cs))) { 67 RC_CHK_FREE(cs); 68 } else { 69 if (refcount_read(comm_str__refcnt(cs)) == 1) 70 comm_strs__remove_if_last(cs); 71 72 RC_CHK_PUT(cs); 73 } 74 } 75 76 static struct comm_str *comm_str__new(const char *str) 77 { 78 struct comm_str *result = NULL; 79 RC_STRUCT(comm_str) *cs; 80 81 cs = malloc(sizeof(*cs) + strlen(str) + 1); 82 if (ADD_RC_CHK(result, cs)) { 83 refcount_set(comm_str__refcnt(result), 1); 84 strcpy(&cs->str[0], str); 85 } 86 return result; 87 } 88 89 static int comm_str__search(const void *_key, const void *_member) 90 { 91 const char *key = _key; 92 const struct comm_str *member = *(const struct comm_str * const *)_member; 93 94 return strcmp(key, comm_str__str(member)); 95 } 96 97 static void comm_strs__remove_if_last(struct comm_str *cs) 98 { 99 struct comm_strs *comm_strs = comm_strs__get(); 100 101 down_write(&comm_strs->lock); 102 /* 103 * Are there only references from the array, if so remove the array 104 * reference under the write lock so that we don't race with findnew. 105 */ 106 if (refcount_read(comm_str__refcnt(cs)) == 1) { 107 struct comm_str **entry; 108 109 entry = bsearch(comm_str__str(cs), comm_strs->strs, comm_strs->num_strs, 110 sizeof(struct comm_str *), comm_str__search); 111 comm_str__put(*entry); 112 for (int i = entry - comm_strs->strs; i < comm_strs->num_strs - 1; i++) 113 comm_strs->strs[i] = comm_strs->strs[i + 1]; 114 comm_strs->num_strs--; 115 } 116 up_write(&comm_strs->lock); 117 } 118 119 static struct comm_str *__comm_strs__find(struct comm_strs *comm_strs, const char *str) 120 { 121 struct comm_str **result; 122 123 result = bsearch(str, comm_strs->strs, comm_strs->num_strs, sizeof(struct comm_str *), 124 comm_str__search); 125 126 if (!result) 127 return NULL; 128 129 return comm_str__get(*result); 130 } 131 132 static struct comm_str *comm_strs__findnew(const char *str) 133 { 134 struct comm_strs *comm_strs = comm_strs__get(); 135 struct comm_str *result; 136 137 if (!comm_strs) 138 return NULL; 139 140 down_read(&comm_strs->lock); 141 result = __comm_strs__find(comm_strs, str); 142 up_read(&comm_strs->lock); 143 if (result) 144 return result; 145 146 down_write(&comm_strs->lock); 147 result = __comm_strs__find(comm_strs, str); 148 if (!result) { 149 if (comm_strs->num_strs == comm_strs->capacity) { 150 struct comm_str **tmp; 151 152 tmp = reallocarray(comm_strs->strs, 153 comm_strs->capacity + 16, 154 sizeof(*comm_strs->strs)); 155 if (!tmp) { 156 up_write(&comm_strs->lock); 157 return NULL; 158 } 159 comm_strs->strs = tmp; 160 comm_strs->capacity += 16; 161 } 162 result = comm_str__new(str); 163 if (result) { 164 int low = 0, high = comm_strs->num_strs - 1; 165 int insert = comm_strs->num_strs; /* Default to inserting at the end. */ 166 167 while (low <= high) { 168 int mid = low + (high - low) / 2; 169 int cmp = strcmp(comm_str__str(comm_strs->strs[mid]), str); 170 171 if (cmp < 0) { 172 low = mid + 1; 173 } else { 174 high = mid - 1; 175 insert = mid; 176 } 177 } 178 memmove(&comm_strs->strs[insert + 1], &comm_strs->strs[insert], 179 (comm_strs->num_strs - insert) * sizeof(struct comm_str *)); 180 comm_strs->num_strs++; 181 comm_strs->strs[insert] = result; 182 } 183 } 184 up_write(&comm_strs->lock); 185 return comm_str__get(result); 186 } 187 188 struct comm *comm__new(const char *str, u64 timestamp, bool exec) 189 { 190 struct comm *comm = zalloc(sizeof(*comm)); 191 192 if (!comm) 193 return NULL; 194 195 comm->start = timestamp; 196 comm->exec = exec; 197 198 comm->comm_str = comm_strs__findnew(str); 199 if (!comm->comm_str) { 200 free(comm); 201 return NULL; 202 } 203 204 return comm; 205 } 206 207 int comm__override(struct comm *comm, const char *str, u64 timestamp, bool exec) 208 { 209 struct comm_str *new, *old = comm->comm_str; 210 211 new = comm_strs__findnew(str); 212 if (!new) 213 return -ENOMEM; 214 215 comm_str__put(old); 216 comm->comm_str = new; 217 comm->start = timestamp; 218 if (exec) 219 comm->exec = true; 220 221 return 0; 222 } 223 224 void comm__free(struct comm *comm) 225 { 226 comm_str__put(comm->comm_str); 227 free(comm); 228 } 229 230 const char *comm__str(const struct comm *comm) 231 { 232 return comm_str__str(comm->comm_str); 233 } 234