xref: /illumos-gate/usr/src/tools/smatch/src/smatch_data/db/smdb.py (revision 856f710c9dc323b39da5935194d7928ffb99b67f)
1#!/usr/bin/python
2
3# Copyright (C) 2013 Oracle.
4#
5# Licensed under the Open Software License version 1.1
6
7import sqlite3
8import sys
9import re
10
# Open the Smatch database file named smatch_db.sqlite (relative path, so it
# is looked up in the current working directory).  Every query helper below
# shares this single connection.
try:
    con = sqlite3.connect('smatch_db.sqlite')
except sqlite3.Error as e:
    # "except ... as e" is accepted by Python 2.6+ and Python 3; the older
    # "except X, e" spelling is a syntax error on Python 3.
    print("Error %s:" % e.args[0])
    sys.exit(1)
16
def usage():
    """Print the command summary, one command per line, and exit with status 1."""
    help_lines = [
        "%s" %(sys.argv[0]),
        "<function> - how a function is called",
        "return_states <function> - what a function returns",
        "call_tree <function> - show the call tree",
        "where <struct_type> <member> - where a struct member is set",
        "type_size <struct_type> <member> - how a struct member is allocated",
        "data_info <struct_type> <member> - information about a given data type",
        "function_ptr <function> - which function pointers point to this",
        "trace_param <function> <param> - trace where a parameter came from",
        "locals <file> - print the local values in a file.",
    ]
    for line in help_lines:
        print(line)
    sys.exit(1)
29
# Work lists shared between get_function_pointers() and its recursive helper.
function_ptrs = []
searched_ptrs = []
def get_function_pointers_helper(func):
    """Recursively add to function_ptrs every ptr recorded for func.

    Looks up the distinct ptr values of function_ptr rows whose function
    column equals func.  Each ptr not yet in the global function_ptrs list is
    appended and then looked up in turn, so chains of pointers are followed.
    Names are appended in depth-first discovery order.
    """
    cur = con.cursor()
    # Bind the name as a parameter so quotes in it cannot break the query.
    cur.execute("select distinct ptr from function_ptr where function = ?;", (func,))
    for row in cur:
        ptr = row[0]
        if ptr in function_ptrs:
            continue
        function_ptrs.append(ptr)
        if ptr not in searched_ptrs:
            searched_ptrs.append(ptr)
            get_function_pointers_helper(ptr)
43
def get_function_pointers(func):
    """Return a list starting with func, followed by every pointer name
    that get_function_pointers_helper() discovers for it.

    The module-level work lists are reset on every call, so the returned
    list is the global function_ptrs object itself.
    """
    global function_ptrs, searched_ptrs
    # Two separate lists, each seeded with the function being looked up.
    function_ptrs, searched_ptrs = [func], [func]
    get_function_pointers_helper(func)
    return function_ptrs
51
# Maps the numeric "type" column used in the database tables to a readable
# name; see type_to_str() and type_to_int().
db_types = {
    0: "INTERNAL",
    101: "PARAM_CLEARED",
    103: "PARAM_LIMIT",
    104: "PARAM_FILTER",
    1001: "PARAM_VALUE",
    1002: "BUF_SIZE",
    1003: "USER_DATA",
    1004: "CAPPED_DATA",
    1005: "RETURN_VALUE",
    1006: "DEREFERENCE",
    1007: "RANGE_CAP",
    1008: "LOCK_HELD",
    1009: "LOCK_RELEASED",
    1010: "ABSOLUTE_LIMITS",
    1012: "PARAM_ADD",
    1013: "PARAM_FREED",
    1014: "DATA_SOURCE",
    1015: "FUZZY_MAX",
    1016: "STR_LEN",
    1017: "ARRAY_LEN",
    1018: "CAPABLE",
    1019: "NS_CAPABLE",
    1022: "TYPE_LINK",
    1023: "UNTRACKED_PARAM",
    1024: "CULL_PATH",
    1025: "PARAM_SET",
    1026: "PARAM_USED",
    1027: "BYTE_UNITS",
    1028: "COMPARE_LIMIT",
    1029: "PARAM_COMPARE",
    8017: "USER_DATA2",
    8018: "NO_OVERFLOW",
    8019: "NO_OVERFLOW_SIMPLE",
    8020: "LOCKED",
    8021: "UNLOCKED",
    8023: "ATOMIC_INC",
    8024: "ATOMIC_DEC",
}
90
def add_range(rl, min_val, max_val):
    """Return a new range list equal to rl with [min_val, max_val] merged in.

    rl is a list of [min, max] pairs that is walked front to back; the new
    range is inserted at its place and fused with every entry it overlaps or
    directly touches (max + 1 == next min).  rl itself is not modified, but
    entries after the merge point are reused in the result.
    """
    if len(rl) == 0:
        return [[min_val, max_val]]

    out = []
    merged = False      # new range already placed in out[-1]; later entries
                        # may still need to be folded into it
    for pos, entry in enumerate(rl):
        lo = entry[0]
        hi = entry[1]

        if merged:
            if max_val + 1 == lo:
                # touches the merged range: extend it
                out[-1][1] = hi
            elif max_val < lo:
                # separate from the merged range
                out.append([lo, hi])
            elif max_val < hi:
                # sticks out past the merged range: extend it
                out[-1][1] = hi
            else:
                # fully covered by the merged range: drop it
                continue
            break

        if max_val + 1 == lo:
            # new range ends right before this one: join them
            out.append([min_val, hi])
        elif max_val < lo:
            # new range lies entirely below this one
            out.extend([[min_val, max_val], [lo, hi]])
        elif min_val < lo and max_val <= hi:
            # new range starts below and ends inside this one
            out.append([min_val, hi])
        elif min_val < lo:
            # new range starts below and ends past this one
            out.append([min_val, max_val])
            merged = True
            continue
        elif max_val <= hi:
            # new range is already contained
            out.append([lo, hi])
        elif min_val <= hi or min_val - 1 == hi:
            # new range starts inside, or right after, this one
            out.append([lo, max_val])
            merged = True
            continue
        else:
            # new range lies entirely above this one
            out.append([lo, hi])
            continue
        break
    else:
        # walked the whole list without settling: if the new range was never
        # placed, it belongs at the end
        if not merged:
            out.append([min_val, max_val])
        return out

    # stopped early: everything after the stopping point is kept as-is
    return out + rl[pos + 1:]
172
def rl_union(rl1, rl2):
    """Return the union of two range lists.

    Every [min, max] pair of rl1 and then of rl2 is fed through add_range()
    starting from an empty list.  A diagnostic is printed if a non-empty
    input somehow produces an empty result.
    """
    merged = []
    for source in (rl1, rl2):
        for bounds in source:
            merged = add_range(merged, bounds[0], bounds[1])

    if not merged and (rl1 or rl2):
        print("bug: merging %s + %s gives empty" %(rl1, rl2))

    return merged
184
# Symbolic limit names understood by txt_to_val() and their numeric values.
_TXT_TO_VAL = {
    "s64min": -(2**63),
    "s32min": -(2**31),
    "s16min": -(2**15),
    "s64max": 2**63 - 1,
    "s32max": 2**31 - 1,
    "s16max": 2**15 - 1,
    "u64max": 2**64 - 1,
    "u32max": 2**32 - 1,
    "u16max": 2**16 - 1,
}

def txt_to_val(txt):
    """Convert one range endpoint from text to an integer.

    Symbolic names such as "s32max" map to the matching limit; anything else
    is parsed with int(), and text that int() rejects with ValueError
    yields 0.
    """
    if txt in _TXT_TO_VAL:
        return _TXT_TO_VAL[txt]
    try:
        return int(txt)
    except ValueError:
        return 0
209
# Integer limits that val_to_txt() prints by symbolic name.
_VAL_TO_TXT = {
    -(2**63): "s64min",
    -(2**31): "s32min",
    -(2**15): "s16min",
    2**63 - 1: "s64max",
    2**31 - 1: "s32max",
    2**15 - 1: "s16max",
    2**64 - 1: "u64max",
    2**32 - 1: "u32max",
    2**16 - 1: "u16max",
}

def val_to_txt(val):
    """Convert one range endpoint from an integer to text.

    Well-known limits are returned by name (e.g. "u16max"), other negative
    values are wrapped in parentheses, and the rest are plain decimal.
    """
    symbolic = _VAL_TO_TXT.get(val)
    if symbolic is not None:
        return symbolic
    if val < 0:
        # parentheses keep the sign apart from the "-" range separator
        return "(%d)" %(val)
    return "%d" %(val)
233
def get_next_str(txt):
    """Split the first value token off the front of a range string.

    Returns [parsed, val]: parsed is the number of characters consumed and
    val is the token text.  Three token shapes are recognised:

      "(...)"     parenthesised value; val is the text between the
                  parentheses and parsed also counts both parentheses
      "s..."/"u..."  symbolic limit, always taken to be 6 characters long
      otherwise   optional leading "-" followed by everything up to the
                  next "-"

    Empty input yields [0, ""] instead of raising IndexError, so a
    malformed pair such as "5-" no longer aborts the caller.
    """
    if not txt:
        return [0, ""]

    first = txt[0]
    if first == '(':
        close = txt.find(')', 1)
        if close == -1:
            # unterminated: the value runs to the end of the text
            close = len(txt)
        return [close + 1, txt[1:close]]

    if first == 's' or first == 'u':
        return [6, txt[:6]]

    # A leading "-" is a sign, not a separator, so start looking after it.
    start = 1 if first == '-' else 0
    end = txt.find('-', start)
    if end == -1:
        end = len(txt)
    return [end, txt[:end]]
258
def txt_to_rl(txt):
    """Parse a comma-separated range string into a list of [min, max] pairs.

    Each comma-separated piece is either a single value or "min-max"; the
    pieces are tokenised with get_next_str() and converted with
    txt_to_val().  An empty string gives an empty list.
    """
    if len(txt) == 0:
        return []

    ranges = []
    for chunk in txt.split(","):
        used, low_txt = get_next_str(chunk)
        if used == len(chunk):
            # a lone value: the range is a single point
            high_txt = low_txt
        else:
            # skip the separator that follows the first token
            high_txt = get_next_str(chunk[used + 1:])[1]
        low = txt_to_val(low_txt)
        high = txt_to_val(high_txt)
        ranges.append([low, high])

    # Converting back with rl_to_txt() does not always reproduce txt:
    # Smatch won't call INT_MAX s32max if the variable is unsigned.

    return ranges
280
def rl_to_txt(rl):
    """Format a list of [min, max] pairs as comma-separated range text.

    A pair whose bounds are equal is written as a single value, otherwise as
    "min-max"; each bound is rendered by val_to_txt().
    """
    pieces = []
    for bounds in rl:
        low, high = bounds[0], bounds[1]
        if low == high:
            pieces.append(val_to_txt(low))
        else:
            pieces.append(val_to_txt(low) + "-" + val_to_txt(high))
    return ",".join(pieces)
297
def type_to_str(type_int):
    """Return the db_types name for a numeric type code.

    type_int is converted with int() for the lookup.  When the code is not
    in db_types the original argument is returned unchanged.
    """
    t = int(type_int)
    # "in" replaces dict.has_key(), which no longer exists on Python 3.
    if t in db_types:
        return db_types[t]
    return type_int
304
def type_to_int(type_string):
    """Return the numeric code whose db_types name is type_string, or -1
    when no entry has that name."""
    for code, label in db_types.items():
        if label == type_string:
            return code
    return -1
310
def display_caller_info(printed, cur, param_names):
    """Print every caller_info row available from cur as one table line.

    printed says whether the column header has already been written; the
    header is emitted before the first row if not.  In the key column, "$"
    is replaced by the parameter's name when param_names has one.  Returns
    the updated printed flag.
    """
    for row in cur:
        if not printed:
            print("file | caller | function | type | parameter | key | value |")
        printed = 1

        parameter = int(row[6])
        key = row[7]
        if parameter in param_names:
            key = key.replace("$", param_names[parameter])

        # The three pieces are separated by a single space, which is what
        # consecutive "print x," statements produce.
        pieces = [
            "%20s | %20s | %20s |" %(row[0], row[1], row[2]),
            " %10s |" %(type_to_str(row[5])),
            " %d | %s | %s" %(parameter, key, row[8]),
        ]
        print(" ".join(pieces))
    return printed
326
def get_caller_info(filename, ptrs, my_type):
    """Print the caller_info rows recorded for every name in ptrs.

    filename  passed to get_param_names() to pick parameter names
    ptrs      function/pointer names to look up; the first entry is taken
              as the function itself (get_function_pointers() builds its
              list that way)
    my_type   db_types name to restrict the rows to, or "" for all types
    """
    cur = con.cursor()
    # Use ptrs[0] rather than relying on a module-level variable named
    # "func" happening to exist when this is called.
    param_names = get_param_names(filename, ptrs[0]) if ptrs else {}
    printed = 0

    # Values are bound as parameters so quotes in a name cannot break the
    # statement.
    query = "select * from caller_info where function = ?"
    extra_args = ()
    if my_type != "":
        query += " and type = ?"
        extra_args = (type_to_int(my_type),)
    query += ";"

    for ptr in ptrs:
        cur.execute(query, (ptr,) + extra_args)
        printed = display_caller_info(printed, cur, param_names)
337
def print_caller_info(filename, func, my_type = ""):
    """Print caller_info rows for func and for every pointer name found for
    it, optionally restricted to the db_types name my_type."""
    get_caller_info(filename, get_function_pointers(func), my_type)
341
def merge_values(param_names, vals, cur):
    """Fold (parameter, key, value) rows from cur into vals.

    vals maps parameter number -> key name -> [row_count, range_list].
    row_count is how many rows contributed to the entry and range_list is
    the union of their value ranges.  "$" in a key is replaced by the
    parameter's name when param_names has one.  vals is updated in place.
    """
    for row in cur:
        parameter = int(row[0])
        name = row[1]
        rl = txt_to_rl(row[2])
        if parameter in param_names:
            name = name.replace("$", param_names[parameter])

        per_param = vals.setdefault(parameter, {})

        # the first item of each entry counts the rows merged into it
        if name in per_param:
            seen, so_far = per_param[name][0], per_param[name][1]
            per_param[name] = [seen + 1, rl_union(so_far, rl)]
        else:
            per_param[name] = [1, rl]
359
def get_param_names(filename, func):
    """Return a dict mapping parameter number to parameter name for func.

    The parameter_name table is first queried for rows matching both
    filename and func; if that finds nothing, the lookup is repeated on the
    function name alone.  The dict is empty when neither query matches.
    """
    cur = con.cursor()
    # Values are bound as parameters so quotes in a name cannot break the
    # statement.
    lookups = (
        ("select parameter, value from parameter_name where file = ? and function = ?;",
         (filename, func)),
        ("select parameter, value from parameter_name where function = ?;",
         (func,)),
    )

    param_names = {}
    for sql, args in lookups:
        cur.execute(sql, args)
        for row in cur:
            param_names[int(row[0])] = row[1]
        if param_names:
            break
    return param_names
377
def get_caller_count(ptrs):
    """Return the number of distinct call_id values in caller_info, summed
    over every name in ptrs."""
    cur = con.cursor()
    count = 0
    for ptr in ptrs:
        # Bound parameter: quotes in the name cannot break the statement.
        cur.execute("select count(distinct(call_id)) from caller_info where function = ?;", (ptr,))
        for row in cur:
            count += int(row[0])
    return count
386
def print_merged_caller_values(filename, func, ptrs, param_names, call_cnt):
    """Print, per parameter and key, the union of the PARAM_VALUE ranges
    passed in by callers.

    Only entries whose row count equals call_cnt are shown; entries with any
    other count are skipped.  filename and func are accepted but not used
    here.
    """
    cur = con.cursor()
    vals = {}
    param_value = type_to_int("PARAM_VALUE")
    for ptr in ptrs:
        # Bound parameters: quotes in the name cannot break the statement.
        cur.execute("select parameter, key, value from caller_info where function = ? and type = ?;", (ptr, param_value))
        merge_values(param_names, vals, cur)

    for param in sorted(vals):
        for name in sorted(vals[param]):
            if vals[param][name][0] != call_cnt:
                continue
            print("%d %s -> %s" %(param, name, rl_to_txt(vals[param][name][1])))
399
400
def print_unmerged_caller_values(filename, func, ptrs, param_names):
    """Print every PARAM_VALUE row of caller_info for each name in ptrs.

    One "file | caller | key | value" line is written per row and a
    separator line after each name.  "$" in the key is replaced by the
    parameter's name when param_names has one, otherwise by "$<number>".
    filename and func are accepted but not used here.
    """
    cur = con.cursor()
    param_value = type_to_int("PARAM_VALUE")
    for ptr in ptrs:
        # Bound parameters: quotes in the name cannot break the statement.
        cur.execute("select file, caller, call_id, parameter, key, value from caller_info where function = ? and type = ?;", (ptr, param_value))
        for caller_file, caller, call_id, parameter, name, value in cur:
            parameter = int(parameter)
            # param_names is a dict keyed by parameter number, so test for
            # the key itself; comparing against len() raised KeyError for
            # parameter -1 and for gaps in the numbering.
            if parameter in param_names:
                name = name.replace("$", param_names[parameter])
            else:
                name = name.replace("$", "$%d" %(parameter))

            print("%s | %s | %s | %s" %(caller_file, caller, name, value))
        print("==========================")
418
def print_caller_values(filename, func, ptrs):
    """Print the merged caller values, a separator line, then the unmerged
    per-row values for func and the names in ptrs."""
    names = get_param_names(filename, func)
    total_calls = get_caller_count(ptrs)

    print_merged_caller_values(filename, func, ptrs, names, total_calls)
    print("==========================")
    print_unmerged_caller_values(filename, func, ptrs, names)
426
def caller_info_values(filename, func):
    """Print the caller value report for func and every pointer name found
    for it."""
    print_caller_values(filename, func, get_function_pointers(func))
430
def print_return_states(func):
    """Print the return_states rows recorded for func as a table.

    The column header is written before the first row only, so nothing is
    printed when there are no rows.
    """
    cur = con.cursor()
    # Bound parameter: quotes in the name cannot break the statement.
    cur.execute("select * from return_states where function = ?;", (func,))
    count = 0
    for txt in cur:
        if count == 0:
            print("file | function | return_id | return_value | type | param | key | value |")
        count += 1
        # Joined with single spaces, matching the output of consecutive
        # "print x," statements.
        pieces = [
            "%s | %s | %2s | %13s" %(txt[0], txt[1], txt[3], txt[4]),
            "| %13s |" %(type_to_str(txt[6])),
            " %2d | %20s | %20s |" %(txt[7], txt[8], txt[9]),
        ]
        print(" ".join(pieces))
443
def print_return_implies(func):
    """Print the return_implies rows recorded for func as a table.

    The column header is written before the first row only, so nothing is
    printed when there are no rows.
    """
    cur = con.cursor()
    # Bound parameter: quotes in the name cannot break the statement.
    cur.execute("select * from return_implies where function = ?;", (func,))
    count = 0
    for txt in cur:
        if not count:
            print("file | function | type | param | key | value |")
        count += 1
        # Joined with single spaces, matching the output of consecutive
        # "print x," statements.
        pieces = [
            "%15s | %15s" %(txt[0], txt[1]),
            "| %15s" %(type_to_str(txt[4])),
            "| %3d | %s | %15s |" %(txt[5], txt[6], txt[7]),
        ]
        print(" ".join(pieces))
455
def print_type_size(struct_type, member):
    """Print type_size and function_type_size rows for a struct member.

    Rows are matched with LIKE against "(struct <struct_type>)-><member>",
    so SQL wildcards in either argument are honoured.
    """
    cur = con.cursor()
    # The pattern is bound as a parameter so quotes in the arguments cannot
    # break the statement.
    pattern = "(struct %s)->%s" %(struct_type, member)

    cur.execute("select * from type_size where type like ?;", (pattern,))
    print("type | size")
    for txt in cur:
        print("%-15s | %s" %(txt[0], txt[1]))

    cur.execute("select * from function_type_size where type like ?;", (pattern,))
    print("file | function | type | size")
    for txt in cur:
        print("%-15s | %-15s | %-15s | %s" %(txt[0], txt[1], txt[2], txt[3]))
467
def print_data_info(struct_type, member):
    """Print the data_info rows for a struct member.

    Rows are matched with LIKE against "(struct <struct_type>)-><member>";
    the type column is shown by its db_types name when it has one.
    """
    cur = con.cursor()
    # The pattern is bound as a parameter so quotes in the arguments cannot
    # break the statement.
    pattern = "(struct %s)->%s" %(struct_type, member)
    cur.execute("select * from data_info where data like ?;", (pattern,))
    print("file | data | type | value")
    for txt in cur:
        print("%-15s | %-15s | %-15s | %s" %(txt[0], txt[1], type_to_str(txt[2]), txt[3]))
474
def print_fn_ptrs(func):
    """Print "<func> = " followed by the list returned by
    get_function_pointers(func); print nothing if that list is empty."""
    ptrs = get_function_pointers(func)
    if ptrs:
        # Joining with a space reproduces the separator that a trailing-comma
        # print statement inserts before the next print.
        print(" ".join(["%s = " %(func), str(ptrs)]))
481
def print_functions(member):
    """Print the function_ptr rows whose ptr column ends in "-><member>"."""
    cur = con.cursor()
    # The leading "%" is the LIKE wildcard.  Binding the pattern as a
    # parameter keeps quotes in member from breaking the statement.
    cur.execute("select * from function_ptr where ptr like ?;", ("%->" + member,))
    print("File | Pointer | Function | Static")
    for txt in cur:
        print("%-15s | %-15s | %-15s | %s" %(txt[0], txt[2], txt[1], txt[3]))
488
def get_callers(func):
    """Return the distinct caller names recorded in caller_info for func and
    for every pointer name found for it.

    Names are distinct per looked-up name only, so the same caller can
    appear more than once in the result.
    """
    ret = []
    cur = con.cursor()
    for ptr in get_function_pointers(func):
        # Bound parameter: quotes in the name cannot break the statement.
        cur.execute("select distinct caller from caller_info where function = ?;", (ptr,))
        ret.extend(row[0] for row in cur)
    return ret
498
# Functions already expanded by the current call-tree or parameter trace.
printed_funcs = []
def call_tree_helper(func, indent = 0):
    """Print func and, recursively, its callers as an indented tree.

    A function is expanded only once per tree.  Expansion stops at the
    placeholder name "too common", below an indent of more than 6 columns,
    and for functions with 20 or more callers.
    """
    global printed_funcs
    if func in printed_funcs:
        return
    print("%s%s()" %(" " * indent, func))
    if func == "too common" or indent > 6:
        return

    printed_funcs.append(func)
    callers = get_callers(func)
    if len(callers) < 20:
        for caller in callers:
            call_tree_helper(caller, indent + 2)
    else:
        print("Over 20 callers for %s()" %(func))
516
def print_call_tree(func):
    """Print the tree of callers of func, starting from a clean
    already-printed list."""
    global printed_funcs
    printed_funcs = []      # forget functions printed by an earlier tree
    call_tree_helper(func)
521
def function_type_value(struct_type, member):
    """Print the function_type_value rows for a struct member.

    Rows are matched with LIKE against "(struct <struct_type>)-><member>",
    so SQL wildcards in either argument are honoured.
    """
    cur = con.cursor()
    # The pattern is bound as a parameter so quotes in the arguments cannot
    # break the statement.
    pattern = "(struct %s)->%s" %(struct_type, member)
    cur.execute("select * from function_type_value where type like ?;", (pattern,))
    for txt in cur:
        print("%-30s | %-30s | %s | %s" %(txt[0], txt[1], txt[2], txt[3]))
527
def trace_callers(func, param):
    """Return (caller, value) pairs describing where parameter number param
    of func comes from.

    caller_info rows of type 0, 1014 and 1028 (INTERNAL, DATA_SOURCE and
    COMPARE_LIMIT in db_types) are read for func and its pointer names,
    restricted to parameter -1 or param:

      type 1014 -> (caller, value)
      type 1028 -> ("%", value); the "%" marks the entry for
                   trace_param_helper()
      type 0    -> (caller, ""), but only when the previous row was also
                   type 0
    """
    sources = []
    prev_type = 0

    cur = con.cursor()
    ptrs = get_function_pointers(func)
    for ptr in ptrs:
        # Bound parameters: quotes in the name cannot break the statement.
        cur.execute("select type, caller, value from caller_info where function = ? and (type = 0 or type = 1014 or type = 1028) and (parameter = -1 or parameter = ?);", (ptr, param))
        for row in cur:
            data_type = int(row[0])
            if data_type == 1014:
                sources.append((row[1], row[2]))
            elif data_type == 1028:
                sources.append(("%", row[2])) # hack...
            elif data_type == 0 and prev_type == 0:
                sources.append((row[1], ""))
            prev_type = data_type
    return sources
546
def trace_param_helper(func, param, indent = 0):
    """Print func and recursively trace where its parameter param comes from.

    Each source from trace_callers() is handled by its shape: a value of
    the form "p <n>" continues the trace at parameter n of the caller, an
    entry whose caller starts with "%" prints the value alone, and anything
    else is printed as a "*" line.  A function is expanded only once;
    expansion stops at "too common" and beyond an indent of 20.
    """
    global printed_funcs
    if func in printed_funcs:
        return
    print("%s%s(param %d)" %(" " * indent, func, param))
    if func == "too common":
        return
    if indent > 20:
        return
    printed_funcs.append(func)
    sources = trace_callers(func, param)
    for path in sources:
        # startswith() handles empty and one-character values; indexing
        # path[1][1] raised IndexError when the value was exactly "p".
        if path[1].startswith("p "):
            p = int(path[1][2:])
            trace_param_helper(path[0], p, indent + 2)
        elif path[0].startswith('%'):
            print("  %s%s" %(" " * indent, path[1]))
        else:
            print("* %s%s %s" %(" " * (indent - 1), path[0], path[1]))
567
def trace_param(func, param):
    """Print a heading, then trace where parameter number param of func
    comes from, starting with a clean already-printed list."""
    global printed_funcs
    printed_funcs = []      # forget functions printed by an earlier trace
    print("tracing %s %d" %(func, param))
    trace_param_helper(func, param)
573
def print_locals(filename):
    """Print "file | data | value" for the data_info rows of filename that
    have type 8029 and a value other than 0."""
    cur = con.cursor()
    # Bound parameter: quotes in the file name cannot break the statement.
    cur.execute("select file,data,value from data_info where file = ? and type = 8029 and value != 0;", (filename,))
    for txt in cur:
        print("%s | %s | %s" %(txt[0], txt[1], txt[2]))
579
def constraint(struct_type, member):
    """Print the constraints_required rows that mention a struct member.

    A row matches when either its data or its bound column is LIKE
    "(struct <struct_type>)-><member>".
    """
    cur = con.cursor()
    # The pattern is bound as a parameter so quotes in the arguments cannot
    # break the statement.
    pattern = "(struct %s)->%s" %(struct_type, member)
    cur.execute("select * from constraints_required where data like ? or bound like ?;", (pattern, pattern))
    for txt in cur:
        print("%-30s | %-30s | %s | %s" %(txt[0], txt[1], txt[2], txt[3]))
585
# Command-line dispatch.  sys.argv[1] selects the query and the remaining
# arguments are its operands.  A lone argument is taken as a function name.
# usage() exits, so nothing after a usage() call runs.
if len(sys.argv) < 2:
    usage()

# Past the first branch there are at least three arguments, so sys.argv[2]
# always exists; only sys.argv[3] needs checking.
if len(sys.argv) == 2:
    func = sys.argv[1]
    print_caller_info("", func)
elif sys.argv[1] == "call_info":
    if len(sys.argv) != 4:
        usage()
    filename = sys.argv[2]
    func = sys.argv[3]
    caller_info_values(filename, func)
    print_caller_info(filename, func)
elif sys.argv[1] == "user_data":
    # No file is given on the command line; pass "" as the bare-function
    # form above does ("filename" used to be undefined here).
    filename = ""
    func = sys.argv[2]
    print_caller_info(filename, func, "USER_DATA")
elif sys.argv[1] == "param_value":
    filename = ""
    func = sys.argv[2]
    print_caller_info(filename, func, "PARAM_VALUE")
elif sys.argv[1] == "function_ptr" or sys.argv[1] == "fn_ptr":
    func = sys.argv[2]
    print_fn_ptrs(func)
elif sys.argv[1] == "return_states":
    func = sys.argv[2]
    print_return_states(func)
    print("================================================")
    print_return_implies(func)
elif sys.argv[1] == "return_implies":
    func = sys.argv[2]
    print_return_implies(func)
elif sys.argv[1] == "type_size" or sys.argv[1] == "buf_size":
    if len(sys.argv) < 4:
        usage()
    struct_type = sys.argv[2]
    member = sys.argv[3]
    print_type_size(struct_type, member)
elif sys.argv[1] == "data_info":
    if len(sys.argv) < 4:
        usage()
    struct_type = sys.argv[2]
    member = sys.argv[3]
    print_data_info(struct_type, member)
elif sys.argv[1] == "call_tree":
    func = sys.argv[2]
    print_call_tree(func)
elif sys.argv[1] == "where":
    if len(sys.argv) == 3:
        # no struct given: match the member in any struct
        struct_type = "%"
        member = sys.argv[2]
    elif len(sys.argv) == 4:
        struct_type = sys.argv[2]
        member = sys.argv[3]
    else:
        usage()
    function_type_value(struct_type, member)
elif sys.argv[1] == "local":
    filename = sys.argv[2]
    variable = ""
    if len(sys.argv) == 4:
        variable = sys.argv[3]
    # NOTE(review): local_values() is not defined anywhere in this file, so
    # this branch raises NameError unless it is provided elsewhere.  The
    # "locals" command below uses print_locals(filename) instead.
    local_values(filename, variable)
elif sys.argv[1] == "functions":
    member = sys.argv[2]
    print_functions(member)
elif sys.argv[1] == "trace_param":
    if len(sys.argv) != 4:
        usage()
    func = sys.argv[2]
    try:
        param = int(sys.argv[3])
    except ValueError:
        # the parameter must be a number
        usage()
    trace_param(func, param)
elif sys.argv[1] == "locals":
    if len(sys.argv) != 3:
        usage()
    filename = sys.argv[2]
    print_locals(filename)
elif sys.argv[1] == "constraint":
    if len(sys.argv) == 3:
        # no struct given: match the member in any struct
        struct_type = "%"
        member = sys.argv[2]
    elif len(sys.argv) == 4:
        struct_type = sys.argv[2]
        member = sys.argv[3]
    else:
        usage()
    constraint(struct_type, member)
elif sys.argv[1] == "test":
    if len(sys.argv) < 4:
        usage()
    filename = sys.argv[2]
    func = sys.argv[3]
    caller_info_values(filename, func)
else:
    usage()
669