1 // SPDX-License-Identifier: GPL-2.0 2 #include "comm.h" 3 #include <errno.h> 4 #include <string.h> 5 #include <internal/rc_check.h> 6 #include <linux/refcount.h> 7 #include <linux/zalloc.h> 8 #include "rwsem.h" 9 10 DECLARE_RC_STRUCT(comm_str) { 11 refcount_t refcnt; 12 char str[]; 13 }; 14 15 static struct comm_strs { 16 struct rw_semaphore lock; 17 struct comm_str **strs; 18 int num_strs; 19 int capacity; 20 } _comm_strs; 21 22 static void comm_strs__remove_if_last(struct comm_str *cs); 23 24 static void comm_strs__init(void) 25 { 26 init_rwsem(&_comm_strs.lock); 27 _comm_strs.capacity = 16; 28 _comm_strs.num_strs = 0; 29 _comm_strs.strs = calloc(16, sizeof(*_comm_strs.strs)); 30 } 31 32 static struct comm_strs *comm_strs__get(void) 33 { 34 static pthread_once_t comm_strs_type_once = PTHREAD_ONCE_INIT; 35 36 pthread_once(&comm_strs_type_once, comm_strs__init); 37 38 return &_comm_strs; 39 } 40 41 static refcount_t *comm_str__refcnt(struct comm_str *cs) 42 { 43 return &RC_CHK_ACCESS(cs)->refcnt; 44 } 45 46 static const char *comm_str__str(const struct comm_str *cs) 47 { 48 return &RC_CHK_ACCESS(cs)->str[0]; 49 } 50 51 static struct comm_str *comm_str__get(struct comm_str *cs) 52 { 53 struct comm_str *result; 54 55 if (RC_CHK_GET(result, cs)) 56 refcount_inc_not_zero(comm_str__refcnt(cs)); 57 58 return result; 59 } 60 61 static void comm_str__put(struct comm_str *cs) 62 { 63 if (!cs) 64 return; 65 66 if (refcount_dec_and_test(comm_str__refcnt(cs))) { 67 RC_CHK_FREE(cs); 68 } else { 69 if (refcount_read(comm_str__refcnt(cs)) == 1) 70 comm_strs__remove_if_last(cs); 71 72 RC_CHK_PUT(cs); 73 } 74 } 75 76 static struct comm_str *comm_str__new(const char *str) 77 { 78 struct comm_str *result = NULL; 79 RC_STRUCT(comm_str) *cs; 80 81 cs = malloc(sizeof(*cs) + strlen(str) + 1); 82 if (ADD_RC_CHK(result, cs)) { 83 refcount_set(comm_str__refcnt(result), 1); 84 strcpy(&cs->str[0], str); 85 } 86 return result; 87 } 88 89 static int comm_str__cmp(const void *_lhs, const void *_rhs) 90 { 91 const struct comm_str *lhs = *(const struct comm_str * const *)_lhs; 92 const struct comm_str *rhs = *(const struct comm_str * const *)_rhs; 93 94 return strcmp(comm_str__str(lhs), comm_str__str(rhs)); 95 } 96 97 static int comm_str__search(const void *_key, const void *_member) 98 { 99 const char *key = _key; 100 const struct comm_str *member = *(const struct comm_str * const *)_member; 101 102 return strcmp(key, comm_str__str(member)); 103 } 104 105 static void comm_strs__remove_if_last(struct comm_str *cs) 106 { 107 struct comm_strs *comm_strs = comm_strs__get(); 108 109 down_write(&comm_strs->lock); 110 /* 111 * Are there only references from the array, if so remove the array 112 * reference under the write lock so that we don't race with findnew. 113 */ 114 if (refcount_read(comm_str__refcnt(cs)) == 1) { 115 struct comm_str **entry; 116 117 entry = bsearch(comm_str__str(cs), comm_strs->strs, comm_strs->num_strs, 118 sizeof(struct comm_str *), comm_str__search); 119 comm_str__put(*entry); 120 for (int i = entry - comm_strs->strs; i < comm_strs->num_strs - 1; i++) 121 comm_strs->strs[i] = comm_strs->strs[i + 1]; 122 comm_strs->num_strs--; 123 } 124 up_write(&comm_strs->lock); 125 } 126 127 static struct comm_str *__comm_strs__find(struct comm_strs *comm_strs, const char *str) 128 { 129 struct comm_str **result; 130 131 result = bsearch(str, comm_strs->strs, comm_strs->num_strs, sizeof(struct comm_str *), 132 comm_str__search); 133 134 if (!result) 135 return NULL; 136 137 return comm_str__get(*result); 138 } 139 140 static struct comm_str *comm_strs__findnew(const char *str) 141 { 142 struct comm_strs *comm_strs = comm_strs__get(); 143 struct comm_str *result; 144 145 if (!comm_strs) 146 return NULL; 147 148 down_read(&comm_strs->lock); 149 result = __comm_strs__find(comm_strs, str); 150 up_read(&comm_strs->lock); 151 if (result) 152 return result; 153 154 down_write(&comm_strs->lock); 155 result = __comm_strs__find(comm_strs, str); 156 if (!result) { 157 if (comm_strs->num_strs == comm_strs->capacity) { 158 struct comm_str **tmp; 159 160 tmp = reallocarray(comm_strs->strs, 161 comm_strs->capacity + 16, 162 sizeof(*comm_strs->strs)); 163 if (!tmp) { 164 up_write(&comm_strs->lock); 165 return NULL; 166 } 167 comm_strs->strs = tmp; 168 comm_strs->capacity += 16; 169 } 170 result = comm_str__new(str); 171 if (result) { 172 comm_strs->strs[comm_strs->num_strs++] = result; 173 qsort(comm_strs->strs, comm_strs->num_strs, sizeof(struct comm_str *), 174 comm_str__cmp); 175 } 176 } 177 up_write(&comm_strs->lock); 178 return comm_str__get(result); 179 } 180 181 struct comm *comm__new(const char *str, u64 timestamp, bool exec) 182 { 183 struct comm *comm = zalloc(sizeof(*comm)); 184 185 if (!comm) 186 return NULL; 187 188 comm->start = timestamp; 189 comm->exec = exec; 190 191 comm->comm_str = comm_strs__findnew(str); 192 if (!comm->comm_str) { 193 free(comm); 194 return NULL; 195 } 196 197 return comm; 198 } 199 200 int comm__override(struct comm *comm, const char *str, u64 timestamp, bool exec) 201 { 202 struct comm_str *new, *old = comm->comm_str; 203 204 new = comm_strs__findnew(str); 205 if (!new) 206 return -ENOMEM; 207 208 comm_str__put(old); 209 comm->comm_str = new; 210 comm->start = timestamp; 211 if (exec) 212 comm->exec = true; 213 214 return 0; 215 } 216 217 void comm__free(struct comm *comm) 218 { 219 comm_str__put(comm->comm_str); 220 free(comm); 221 } 222 223 const char *comm__str(const struct comm *comm) 224 { 225 return comm_str__str(comm->comm_str); 226 } 227