xref: /linux/drivers/gpu/drm/msm/registers/gen_header.py (revision bba2c3615bd6cfee7456d1130f2e6b01b3f4e9ba)
1#!/usr/bin/python3
2#
3# Copyright © 2019-2024 Google, Inc.
4#
5# SPDX-License-Identifier: MIT
6
7import xml.parsers.expat
8import sys
9import os
10import collections
11import argparse
12import time
13import datetime
14import json
15
16
class Error(Exception):
    """Parse/validation error raised by the header generator.

    The formatted text is kept in ``message`` and also handed to
    ``Exception.__init__`` so ``str(err)``/``err.args`` carry it.
    """

    def __init__(self, message):
        super().__init__(message)
        self.message = message
20
21
class Enum(object):
    """A named C enum: an ordered list of ``(name, value)`` pairs."""

    def __init__(self, name):
        self.name = name
        # (name, value) tuples in declaration order.
        self.values = []

    def has_name(self, name):
        """Return True if an entry called *name* exists."""
        return any(entry == name for entry, _ in self.values)

    def names(self):
        """Return the entry names in declaration order."""
        return [entry for entry, _ in self.values]

    def value(self, name):
        """Return the value of the first entry called *name*, or None."""
        return next((val for entry, val in self.values if entry == name), None)

    def dump(self, has_variants):
        """Print the C ``enum`` definition.

        Values are printed as 8-digit hex if any value exceeds 0x1000,
        otherwise in decimal. *has_variants* is unused.
        """
        wide = any(val > 0x1000 for _, val in self.values)
        line_fmt = "\t%s = 0x%08x," if wide else "\t%s = %d,"
        print("enum %s {" % self.name)
        for pair in self.values:
            print(line_fmt % pair)
        print("};\n")

    def dump_pack_struct(self, has_variants):
        """No-op: enums have no pack struct."""
        return None
57
58
class Field(object):
    """A bitfield ``low..high`` (inclusive) of the register/bitset being parsed.

    Construction validates the bit range against ``parser.current_bitsize``
    and the type against the builtin types and ``parser.enums``; failures
    are raised as ``parser.error(...)``. For "fixed"/"ufixed" fields the
    parser assigns ``radix`` after construction, and ctype() reads it.
    """

    # Field types that are not names of user-declared enums.
    BUILTIN_TYPES = (None, "a3xx_regid", "boolean", "uint", "hex",
                     "int", "fixed", "ufixed", "float", "address", "waddress")

    def __init__(self, name, low, high, shr, type, parser):
        self.name = name
        self.low = low
        self.high = high
        self.shr = shr
        self.type = type

        top_bit = parser.current_bitsize - 1
        if not 0 <= low <= top_bit:
            raise parser.error("low attribute out of range: %d" % low)
        if not 0 <= high <= top_bit:
            raise parser.error("high attribute out of range: %d" % high)
        if low > high:
            raise parser.error(
                "low is greater than high: low=%d, high=%d" % (low, high))

        width = high - low
        if type == "boolean":
            if width != 0:
                raise parser.error("booleans should be 1 bit fields")
        elif type == "float":
            if width not in (15, 31):
                raise parser.error("floats should be 16 or 32 bit fields")
        elif type not in self.BUILTIN_TYPES and type not in parser.enums:
            raise parser.error("unknown type '%s'" % type)

    def ctype(self, var_name):
        """Return ``(c_type, c_expr)`` for a value named *var_name*.

        ``c_type`` is the C type used for the struct member / helper
        parameter. ``c_expr`` converts *var_name* to the raw field bits
        (float/fixed conversions, then ``>> shr`` when shr > 0).
        """
        kind = self.type
        width = self.high - self.low
        if kind in (None, "uint", "hex", "a3xx_regid"):
            c_type, expr = "uint32_t", var_name
        elif kind == "boolean":
            c_type, expr = "bool", var_name
        elif kind == "int":
            c_type, expr = "int32_t", var_name
        elif kind == "fixed":
            c_type = "float"
            expr = "(uint32_t)((int32_t)(%s * %d.0))" % (var_name, 1 << self.radix)
        elif kind == "ufixed":
            c_type = "float"
            expr = "((uint32_t)(%s * %d.0))" % (var_name, 1 << self.radix)
        elif kind == "float" and width == 31:
            c_type, expr = "float", "fui(%s)" % var_name
        elif kind == "float" and width == 15:
            c_type, expr = "float", "_mesa_float_to_half(%s)" % var_name
        elif kind in ("address", "waddress"):
            c_type, expr = "uint64_t", var_name
        else:
            # Anything else is the name of a declared enum.
            c_type, expr = "enum %s" % kind, var_name

        if self.shr > 0:
            expr = "(%s >> %d)" % (expr, self.shr)
        return (c_type, expr)
121        return (type, val)
122
123
def tab_to(name, value):
    """Print *name*, tab-padding so *value* lines up near column 68.

    Padding assumes 8-column tab stops; at least one tab is always emitted.
    """
    tabs = max(1, (68 - (len(name) & ~7)) // 8)
    print("%s%s%s" % (name, "\t" * tabs, value))
129
def define_macro(name, value, has_variants):
    """Emit ``name`` and ``value`` aligned via tab_to().

    When *has_variants* is true the value is prefixed with
    ``__FD_DEPRECATED``.
    """
    marker = "__FD_DEPRECATED " if has_variants else ""
    tab_to(name, marker + value)
134
def mask(low, high):
    """Return an int with bits *low*..*high* (inclusive) set.

    Built from a 64-bit all-ones value, so it is meant for ranges within
    bits 0..63.
    """
    all_ones = 0xffffffffffffffff
    width = high + 1 - low
    return (all_ones >> (64 - width)) << low
137
138
def field_name(reg, f):
    """Return the C member name used for field *f* of register *reg*.

    Unnamed fields fall back to the register name. The result is
    lower-cased, and gets a leading ``_`` when it is "double", "float" or
    "int", or does not start with a letter.
    """
    # The unnamed case happens when a reg is defined with no bitset fields, ie.
    # <reg32 offset="0x88db" name="RB_RESOLVE_SYSTEM_BUFFER_ARRAY_PITCH" low="0" high="28" shr="6" type="uint"/>
    source = f.name if f.name else reg.name
    ident = source.lower()
    needs_underscore = ident in ("double", "float", "int") or not ident[0].isalpha()
    return "_" + ident if needs_underscore else ident
151
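# indices - list of (ctype, stride, offset_fn) tuples: strided indices set
# stride (offset_fn is None); fixed-offset indices set offset_fn to the name
# of the generated __offset_NAME() lookup function (stride is None).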
153
154
def indices_varlist(indices):
    """Return the index argument names ``i0, i1, ...``, one per entry."""
    return ", ".join("i%d" % position for position, _entry in enumerate(indices))
157
158
def indices_prototype(indices):
    """Return a C parameter list ``<ctype> i0, <ctype> i1, ...`` for *indices*."""
    params = ("%s i%d" % (ctype, position)
              for position, (ctype, _stride, _offset_fn) in enumerate(indices))
    return ", ".join(params)
162
163
def indices_strides(indices):
    """Return the C offset expression summing one term per index.

    Strided entries contribute ``0x<stride>*iN``; entries without a
    (truthy) stride contribute a call ``<offset_fn>(iN)``.
    """
    terms = []
    for position, (_ctype, stride, offset_fn) in enumerate(indices):
        if stride:
            terms.append("0x%x*i%d" % (stride, position))
        else:
            terms.append("%s(i%d)" % (offset_fn, position))
    return " + ".join(terms)
169
170
def is_number(str):
    """Return True if ``int(str)`` succeeds (base 10), False on ValueError."""
    try:
        int(str)
    except ValueError:
        return False
    else:
        return True
177
178
def sanitize_variant(variant):
    """Drop a range suffix from a variant: ``"A6XX-A7XX"`` -> ``"A6XX"``.

    Falsy values and names without ``-`` are returned unchanged.
    """
    if not variant:
        return variant
    head, _sep, _rest = variant.partition("-")
    return head
183
184
class Bitset(object):
    """An ordered set of Fields, either a named <bitset> or inlined in a reg."""

    def __init__(self, name, template):
        self.name = name
        self.inline = False
        # Register this bitset was last attached to (assigned by the parser).
        self.reg = None
        # Start from a shallow copy of the template's fields, if any.
        self.fields = list(template.fields) if template else []

    def get_address_field(self):
        """Return the first "address"/"waddress" field, or None."""
        return next((fld for fld in self.fields
                     if fld.type in ("address", "waddress")), None)

    def dump_regpair_builder(self, reg):
        """Print the body of a C function returning ``struct fd_reg_pair`` for *reg*.

        Under ``#ifndef NDEBUG`` it asserts each non-boolean, non-address
        field value fits its width and that ``fields.unknown`` does not
        overlap any declared field. The ``.value`` ORs every non-address
        field shifted to its low bit, plus ``unknown`` and the raw
        dword/qword. Address fields add the ``.bo*`` members.
        """
        print("#ifndef NDEBUG")
        covered = 0
        for fld in self.fields:
            covered |= mask(fld.low, fld.high)
            if fld.type not in ("boolean", "address", "waddress"):
                _ctype, expr = fld.ctype("fields.%s" % field_name(reg, fld))
                overflow_bits = 0xffffffff ^ mask(0, fld.high - fld.low)
                print("    assert((%-40s & 0x%08x) == 0);" % (expr, overflow_bits))
        print("    assert((%-40s & 0x%08x) == 0);" % ("fields.unknown", covered))
        print("#endif\n")

        print("    return (struct fd_reg_pair) {")
        print("        .reg = (uint32_t)%s," % reg.reg_offset())
        print("        .value =")
        is_wide = reg.bit_size == 64
        cast = "(uint64_t)" if is_wide else ""
        for fld in self.fields:
            if fld.type in ("address", "waddress"):
                continue
            _ctype, expr = fld.ctype("fields.%s" % field_name(reg, fld))
            print("            (%s%-40s << %2d) |" % (cast, expr, fld.low))
        raw_member = "qword" if is_wide else "dword"
        print("            fields.unknown | fields.%s," % (raw_member,))

        addr = self.get_address_field()
        if addr:
            print("#ifndef TU_CS_H")
            print("        .bo = fields.bo,")
            print("        .is_address = true,")
            print("        .bo_offset = fields.bo_offset,")
            print("        .bo_shift = %d," % addr.shr)
            print("        .bo_low = %d," % addr.low)
            print("#else")
            print("        .is_address = true,")
            print("#endif")

        print("    };")

    def dump_pack_struct(self, has_variants, reg=None):
        """Print the C pack struct for *reg*, its pack_ function and wrapper macro.

        Does nothing when *reg* is falsy. The ``pack_<reg>()`` function is
        only printed when *has_variants* is false; the ``#define`` wrapper is
        always printed.
        """
        if not reg:
            return

        prefix = reg.full_name
        # fui()/_mesa_float_to_half() are not constexpr, so any float field
        # drops the CONSTEXPR marker from the pack function.
        constexpr_mark = " CONSTEXPR"

        print("struct %s {" % prefix)
        for fld in self.fields:
            if fld.type in ("address", "waddress"):
                print("#ifndef TU_CS_H")
                tab_to("    __bo_type", "bo;")
                tab_to("    uint32_t", "bo_offset;")
                print("#endif\n")
                continue
            member = field_name(reg, fld)
            ctype, _expr = fld.ctype("var")
            tab_to("    %s" % ctype, "%s;" % member)
            if fld.type == "float":
                constexpr_mark = ""
        if reg.bit_size == 64:
            word_type, word_member = "uint64_t", "qword;"
        else:
            word_type, word_member = "uint32_t", "dword;"
        tab_to("    " + word_type, word_member)
        tab_to("    " + word_type, "unknown;")
        print("};\n")

        index_param = "uint32_t __i, " if reg.array else ""
        if not has_variants:
            print("static%s inline struct fd_reg_pair" % constexpr_mark)
            print("pack_%s(%sstruct %s fields)\n{" % (prefix, index_param, prefix))
            self.dump_regpair_builder(reg)
            print("\n}\n")

        skip = ", { .reg = 0 }" if self.get_address_field() else ""
        if reg.array:
            macro_params, call_args = "__i, ...", "__i, "
        else:
            macro_params, call_args = "...", ""
        print("#define %s(%s) pack_%s(%s__struct_cast(%s) { __VA_ARGS__ })%s\n" %
              (prefix, macro_params, prefix, call_args, prefix, skip))

    def dump(self, has_variants, prefix=None, reg=None):
        """Print per-field #defines and inline C packing helpers.

        *prefix* defaults to the bitset name. Boolean fields and untyped
        single-bit fields get one bit ``#define``; other fields get
        ``__MASK``/``__SHIFT`` defines plus a ``static inline`` helper.
        Unnamed fields at bit 0 with no shift and no float/fixed
        conversion emit nothing. When ``self.reg`` is 64-bit, identity
        ``_LO``/``_HI`` helpers are printed and constants get an ``ull``
        suffix; when *reg* is 64-bit the helpers return ``uint64_t``.
        """
        if prefix is None:
            prefix = self.name
        suffix = ""
        if self.reg and self.reg.bit_size == 64:
            for half in ("LO", "HI"):
                print("static CONSTEXPR inline uint32_t %s_%s(uint32_t val)\n{" % (prefix, half))
                print("\treturn val;\n}")
            suffix = "ull"

        wide_reg = bool(reg and reg.bit_size == 64)
        ret_type = "uint64_t" if wide_reg else "uint32_t"
        cast = "(uint64_t)" if wide_reg else ""

        for fld in self.fields:
            name = "%s_%s" % (prefix, fld.name) if fld.name else prefix

            if (not fld.name and fld.low == 0 and fld.shr == 0
                    and fld.type not in ("float", "fixed", "ufixed")):
                continue
            if fld.type == "boolean" or (fld.type is None and fld.low == fld.high):
                tab_to("#define %s" % name, "0x%08x%s" % ((1 << fld.low), suffix))
                continue

            tab_to("#define %s__MASK" % name,
                   "0x%08x%s" % (mask(fld.low, fld.high), suffix))
            tab_to("#define %s__SHIFT" % name, "%d" % fld.low)
            ctype, expr = fld.ctype("val")
            constexpr_mark = "" if ctype == "float" else " CONSTEXPR"
            print("static%s inline %s %s(%s val)\n{" % (
                constexpr_mark, ret_type, name, ctype))
            if fld.shr > 0:
                # the low shr bits are dropped by the shift; insist they are 0
                print("\tassert(!(val & 0x%x));" % mask(0, fld.shr - 1))
            print("\treturn (%s(%s) << %s__SHIFT) & %s__MASK;\n}" %
                  (cast, expr, name, name))
        print()
340
341
class Array(object):
    """An <array> of registers.

    Either strided (``offset`` + ``stride``) or with an explicit per-index
    offset list (``offsets`` numeric, or ``doffsets`` C expressions).
    Arrays may nest; a child's name is ``<parent name>_<local name>``.
    """

    def __init__(self, attrs, domain, variant, parent, index_type):
        self.local_name = attrs.get("name", "")
        self.domain = domain
        self.variant = variant
        self.parent = parent
        self.children = []
        if self.parent:
            self.name = self.parent.name + "_" + self.local_name
        else:
            self.name = self.local_name
        if "offsets" in attrs:
            # Store a list, not a map() iterator: an iterator is exhausted
            # after its first traversal, so any later dump would silently
            # emit an empty switch.
            self.offsets = ["0x%08x" % int(i, 0)
                            for i in attrs["offsets"].split(",")]
            self.fixed_offsets = True
        elif "doffsets" in attrs:
            self.offsets = ["(%s)" % s for s in attrs["doffsets"].split(",")]
            self.fixed_offsets = True
        else:
            self.offset = int(attrs["offset"], 0)
            self.stride = int(attrs["stride"], 0)
            self.fixed_offsets = False
        # The enum used for the index is only honoured with an "index" attr.
        self.index_type = index_type if "index" in attrs else None
        self.length = int(attrs["length"], 0)
        self.usages = attrs["usage"].split(',') if "usage" in attrs else None

    def index_ctype(self):
        """C type of this array's index: ``uint32_t`` or ``enum <index_type>``."""
        if not self.index_type:
            return "uint32_t"
        return "enum %s" % self.index_type.name

    def indices(self):
        """Return a fresh list of (ctype, stride, offset_fn), outermost first.

        Single-element arrays (length 1) add no index.
        """
        indices = self.parent.indices() if self.parent else []
        if self.length != 1:
            if self.fixed_offsets:
                indices.append((self.index_ctype(), None,
                                "__offset_%s" % self.local_name))
            else:
                indices.append((self.index_ctype(), self.stride, None))
        return indices

    def total_offset(self):
        """Sum of the base offsets of this array and its parents (fixed-offset
        arrays contribute 0)."""
        offset = 0 if self.fixed_offsets else self.offset
        if self.parent:
            offset += self.parent.total_offset()
        return offset

    def dump(self, has_variants):
        """Print the ``__offset_<name>()`` lookup (fixed offsets, no variants)
        and the ``REG_<domain>_<name>`` macro."""
        indices = self.indices()
        proto = indices_varlist(indices)
        strides = indices_strides(indices)
        array_offset = self.total_offset()
        if self.fixed_offsets and not has_variants:
            print("static CONSTEXPR inline uint32_t __offset_%s(%s idx)" %
                  (self.local_name, self.index_ctype()))
            print("{\n\tswitch (idx) {")
            if self.index_type:
                for val, offset in zip(self.index_type.names(), self.offsets):
                    print("\t\tcase %s: return %s;" % (val, offset))
            else:
                for idx, offset in enumerate(self.offsets):
                    print("\t\tcase %d: return %s;" % (idx, offset))
            print("\t\tdefault: return INVALID_IDX(idx);")
            print("\t}\n}")
        if proto == '':
            define_macro("#define REG_%s_%s" %
                         (self.domain, self.name), "0x%08x\n" % array_offset,
                         has_variants)
        else:
            define_macro("#define REG_%s_%s(%s)" % (self.domain, self.name,
                         proto), "(0x%08x + %s )\n" % (array_offset, strides),
                         has_variants)

    def dump_pack_struct(self, has_variants):
        """No-op: arrays have no pack struct of their own."""
        pass

    def dump_regpair_builder(self):
        """No-op: arrays have no reg-pair builder of their own."""
        pass
436
437
class Reg(object):
    """A single register (reg32/reg64), optionally inside an Array.

    ``bitset`` is assigned by the parser after construction.
    """

    def __init__(self, attrs, domain, array, bit_size):
        self.name = attrs["name"]
        self.domain = domain
        self.array = array
        self.offset = int(attrs["offset"], 0)
        self.type = None
        self.bit_size = bit_size
        if array:
            # Array regs are named <array>_<reg> and registered as children.
            self.name = array.name + "_" + self.name
            array.children.append(self)
        self.full_name = self.domain + "_" + self.name
        if "stride" in attrs:
            self.stride = int(attrs["stride"], 0)
            self.length = int(attrs["length"], 0)
        else:
            self.stride = None
            self.length = None

    def indices(self):
        """Return a fresh list of (ctype, stride, offset_fn): the array's
        indices plus one for this reg's own stride, if any."""
        indices = self.array.indices() if self.array else []
        if self.stride:
            indices.append(("uint32_t", self.stride, None))
        return indices

    def total_offset(self):
        """Base offset of this reg including all enclosing array offsets."""
        if self.array:
            return self.array.total_offset() + self.offset
        return self.offset

    def reg_offset(self):
        """C expression for the reg offset used by the pack function.

        For array regs this is ``(base + stride*__i)``.
        """
        # NOTE(review): only the immediate array's offset/stride are used;
        # parent-array offsets are not added, and fixed-offset arrays have
        # no .offset/.stride at all — confirm only simple strided arrays
        # reach here.
        if self.array:
            offset = self.array.offset + self.offset
            return "(0x%08x + 0x%x*__i)" % (offset, self.array.stride)
        return "0x%08x" % self.offset

    def dump(self, has_variants):
        """Print the REG_ offset macro or accessor, then inline bitset helpers."""
        indices = self.indices()
        proto = indices_prototype(indices)
        strides = indices_strides(indices)
        offset = self.total_offset()
        if proto == '':
            define_macro("#define REG_%s" % self.full_name, "0x%08x" % offset, has_variants)
        elif not has_variants:
            # Indexed regs with variants get no accessor here. (The former
            # "__FD_DEPRECATED" branch in this path could never run.)
            print("static CONSTEXPR inline uint32_t REG_%s(%s) { return 0x%08x + %s; }" % (
                  self.full_name, proto, offset, strides))

        if self.bitset.inline:
            self.bitset.dump(has_variants, self.full_name, self)
        print("")

    def dump_pack_struct(self, has_variants):
        """Print the pack struct/function for regs with an inline bitset."""
        if self.bitset.inline:
            self.bitset.dump_pack_struct(has_variants, self)

    def dump_regpair_builder(self):
        """Print the fd_reg_pair builder body for this reg's bitset."""
        self.bitset.dump_regpair_builder(self)

    def dump_py(self):
        """Print ``\\tREG_<full_name> = 0x........`` (array base + reg offset)."""
        offset = self.offset
        if self.array:
            offset += self.array.offset
        print("\tREG_%s = 0x%08x" % (self.full_name, offset))
508
509
510class Parser(object):
511    def __init__(self):
512        self.current_array = None
513        self.current_domain = None
514        self.current_prefix = None
515        self.current_prefix_type = None
516        self.current_stripe = None
517        self.current_bitset = None
518        self.current_bitsize = 32
519        # The varset attribute on the domain specifies the enum which
520        # specifies all possible hw variants:
521        self.current_varset = None
522        # Regs that have multiple variants.. we only generated the C++
523        # template based struct-packers for these
524        self.variant_regs = {}
525        # Information in which contexts regs are used, to be used in
526        # debug options
527        self.usage_regs = collections.defaultdict(list)
528        self.bitsets = {}
529        self.enums = {}
530        self.variants = set()
531        self.file = []
532        self.xml_files = []
533
534    def error(self, message):
535        parser, filename = self.stack[-1]
536        return Error("%s:%d:%d: %s" % (filename, parser.CurrentLineNumber, parser.CurrentColumnNumber, message))
537
538    def prefix(self, variant=None):
539        if self.current_prefix_type == "variant" and variant:
540            return sanitize_variant(variant)
541        elif self.current_stripe:
542            return self.current_stripe + "_" + self.current_domain
543        elif self.current_prefix:
544            return self.current_prefix + "_" + self.current_domain
545        else:
546            return self.current_domain
547
548    def parse_field(self, name, attrs):
549        try:
550            if "pos" in attrs:
551                high = low = int(attrs["pos"], 0)
552            elif "high" in attrs and "low" in attrs:
553                high = int(attrs["high"], 0)
554                low = int(attrs["low"], 0)
555            else:
556                low = 0
557                high = self.current_bitsize - 1
558
559            if "type" in attrs:
560                type = attrs["type"]
561            else:
562                type = None
563
564            if "shr" in attrs:
565                shr = int(attrs["shr"], 0)
566            else:
567                shr = 0
568
569            b = Field(name, low, high, shr, type, self)
570
571            if type == "fixed" or type == "ufixed":
572                b.radix = int(attrs["radix"], 0)
573
574            self.current_bitset.fields.append(b)
575        except ValueError as e:
576            raise self.error(e)
577
578    def parse_varset(self, attrs):
579        # Inherit the varset from the enclosing domain if not overriden:
580        varset = self.current_varset
581        if "varset" in attrs:
582            varset = self.enums[attrs["varset"]]
583        return varset
584
585    def parse_variants(self, attrs):
586        if "variants" not in attrs:
587            return None
588
589        variant = attrs["variants"].split(",")[0]
590        varset = self.parse_varset(attrs)
591
592        if "-" in variant:
593            # if we have a range, validate that both the start and end
594            # of the range are valid enums:
595            start = variant[:variant.index("-")]
596            end = variant[variant.index("-") + 1:]
597            assert varset.has_name(start)
598            if end != "":
599                assert varset.has_name(end)
600        else:
601            assert varset.has_name(variant)
602
603        return variant
604
605    def add_all_variants(self, reg, attrs, parent_variant):
606        # TODO this should really handle *all* variants, including dealing
607        # with open ended ranges (ie. "A2XX,A4XX-") (we have the varset
608        # enum now to make that possible)
609        variant = self.parse_variants(attrs)
610        if not variant:
611            variant = parent_variant
612
613        if reg.name not in self.variant_regs:
614            self.variant_regs[reg.name] = {}
615        else:
616            # All variants must be same size:
617            v = next(iter(self.variant_regs[reg.name]))
618            assert self.variant_regs[reg.name][v].bit_size == reg.bit_size
619
620        self.variant_regs[reg.name][variant] = reg
621
622    def add_all_usages(self, reg, usages):
623        if not usages:
624            return
625
626        for usage in usages:
627            self.usage_regs[usage].append(reg)
628
629        self.variants.add(reg.domain)
630
631    def do_validate(self, schemafile):
632        if not self.validate:
633            return
634
635        try:
636            from lxml import etree
637
638            parser, filename = self.stack[-1]
639            dirname = os.path.dirname(filename)
640
641            # we expect this to look like <namespace url> schema.xsd.. I think
642            # technically it is supposed to be just a URL, but that doesn't
643            # quite match up to what we do.. Just skip over everything up to
644            # and including the first whitespace character:
645            schemafile = schemafile[schemafile.rindex(" ")+1:]
646
647            # this is a bit cheezy, but the xml file to validate could be
648            # in a child director, ie. we don't really know where the schema
649            # file is, the way the rnn C code does.  So if it doesn't exist
650            # just look one level up
651            if not os.path.exists(dirname + "/" + schemafile):
652                schemafile = "../" + schemafile
653
654            if not os.path.exists(dirname + "/" + schemafile):
655                raise self.error("Cannot find schema for: " + filename)
656
657            xmlschema_doc = etree.parse(dirname + "/" + schemafile)
658            xmlschema = etree.XMLSchema(xmlschema_doc)
659
660            xml_doc = etree.parse(filename)
661            if not xmlschema.validate(xml_doc):
662                error_str = str(xmlschema.error_log.filter_from_errors()[0])
663                raise self.error(
664                    "Schema validation failed for: " + filename + "\n" + error_str)
665        except ImportError as e:
666            print("lxml not found, skipping validation", file=sys.stderr)
667
668    def do_parse(self, filename):
669        filepath = os.path.abspath(filename)
670        if filepath in self.xml_files:
671            return
672        self.xml_files.append(filepath)
673        file = open(filename, "rb")
674        parser = xml.parsers.expat.ParserCreate()
675        self.stack.append((parser, filename))
676        parser.StartElementHandler = self.start_element
677        parser.EndElementHandler = self.end_element
678        parser.CharacterDataHandler = self.character_data
679        parser.buffer_text = True
680        parser.ParseFile(file)
681        self.stack.pop()
682        file.close()
683
684    def parse(self, rnn_path, filename, validate):
685        self.path = rnn_path
686        self.stack = []
687        self.validate = validate
688        self.do_parse(filename)
689
690    def parse_reg(self, attrs, bit_size):
691        self.current_bitsize = bit_size
692        if "type" in attrs and attrs["type"] in self.bitsets:
693            bitset = self.bitsets[attrs["type"]]
694            if bitset.inline:
695                self.current_bitset = Bitset(attrs["name"], bitset)
696                self.current_bitset.inline = True
697            else:
698                self.current_bitset = bitset
699        else:
700            self.current_bitset = Bitset(attrs["name"], None)
701            self.current_bitset.inline = True
702            if "type" in attrs:
703                self.parse_field(None, attrs)
704
705        variant = self.parse_variants(attrs)
706        if not variant and self.current_array:
707            variant = self.current_array.variant
708
709        self.current_reg = Reg(attrs, self.prefix(
710            variant), self.current_array, bit_size)
711        self.current_reg.bitset = self.current_bitset
712        self.current_bitset.reg = self.current_reg
713
714        if len(self.stack) == 1:
715            self.file.append(self.current_reg)
716
717        if variant is not None:
718            self.add_all_variants(self.current_reg, attrs, variant)
719
720        usages = None
721        if "usage" in attrs:
722            usages = attrs["usage"].split(',')
723        elif self.current_array:
724            usages = self.current_array.usages
725
726        self.add_all_usages(self.current_reg, usages)
727
728    def start_element(self, name, attrs):
729        self.cdata = ""
730        if name == "import":
731            filename = attrs["file"]
732            self.do_parse(os.path.join(self.path, filename))
733        elif name == "domain":
734            self.current_domain = attrs["name"]
735            if "prefix" in attrs:
736                self.current_prefix = sanitize_variant(
737                    self.parse_variants(attrs))
738                self.current_prefix_type = attrs["prefix"]
739            else:
740                self.current_prefix = None
741                self.current_prefix_type = None
742            if "varset" in attrs:
743                self.current_varset = self.enums[attrs["varset"]]
744        elif name == "stripe":
745            self.current_stripe = sanitize_variant(self.parse_variants(attrs))
746        elif name == "enum":
747            self.current_enum_value = 0
748            self.current_enum = Enum(attrs["name"])
749            self.enums[attrs["name"]] = self.current_enum
750            if len(self.stack) == 1:
751                self.file.append(self.current_enum)
752        elif name == "value":
753            if "value" in attrs:
754                value = int(attrs["value"], 0)
755            else:
756                value = self.current_enum_value
757            self.current_enum.values.append((attrs["name"], value))
758        elif name == "reg32":
759            self.parse_reg(attrs, 32)
760        elif name == "reg64":
761            self.parse_reg(attrs, 64)
762        elif name == "array":
763            self.current_bitsize = 32
764            variant = self.parse_variants(attrs)
765            index_type = self.enums[attrs["index"]
766                                    ] if "index" in attrs else None
767            self.current_array = Array(attrs, self.prefix(
768                variant), variant, self.current_array, index_type)
769            if len(self.stack) == 1:
770                self.file.append(self.current_array)
771        elif name == "bitset":
772            self.current_bitset = Bitset(attrs["name"], None)
773            if "inline" in attrs and attrs["inline"] == "yes":
774                self.current_bitset.inline = True
775            self.bitsets[self.current_bitset.name] = self.current_bitset
776            if len(self.stack) == 1 and not self.current_bitset.inline:
777                self.file.append(self.current_bitset)
778        elif name == "bitfield" and self.current_bitset:
779            self.parse_field(attrs["name"], attrs)
780        elif name == "database":
781            self.do_validate(attrs["xsi:schemaLocation"])
782
783    def end_element(self, name):
784        if name == "domain":
785            self.current_domain = None
786            self.current_prefix = None
787            self.current_prefix_type = None
788        elif name == "stripe":
789            self.current_stripe = None
790        elif name == "bitset":
791            self.current_bitset = None
792        elif name == "reg32":
793            self.current_reg = None
794        elif name == "array":
795            # if the array has no Reg children, push an implicit reg32:
796            if len(self.current_array.children) == 0:
797                attrs = {
798                    "name": "REG",
799                    "offset": "0",
800                }
801                self.parse_reg(attrs, 32)
802            self.current_array = self.current_array.parent
803        elif name == "enum":
804            self.current_enum = None
805
806    def character_data(self, data):
807        self.cdata += data
808
809    def dump_reg_usages(self):
810        d = collections.defaultdict(list)
811        for usage, regs in self.usage_regs.items():
812            for reg in regs:
813                variants = self.variant_regs.get(reg.name)
814                if variants:
815                    for variant, vreg in variants.items():
816                        if reg == vreg:
817                            d[(usage, sanitize_variant(variant))].append(reg)
818                else:
819                    for variant in self.variants:
820                        d[(usage, sanitize_variant(variant))].append(reg)
821
822        print("#ifdef __cplusplus")
823
824        for usage, regs in self.usage_regs.items():
825            print("template<chip CHIP> constexpr inline uint16_t %s_REGS[] = {};" % (
826                usage.upper()))
827
828        for (usage, variant), regs in d.items():
829            offsets = []
830
831            for reg in regs:
832                if reg.array:
833                    for i in range(reg.array.length):
834                        offsets.append(reg.array.offset +
835                                       reg.offset + i * reg.array.stride)
836                        if reg.bit_size == 64:
837                            offsets.append(offsets[-1] + 1)
838                else:
839                    offsets.append(reg.offset)
840                    if reg.bit_size == 64:
841                        offsets.append(offsets[-1] + 1)
842
843            offsets.sort()
844
845            print("template<> constexpr inline uint16_t %s_REGS<%s>[] = {" % (
846                usage.upper(), variant))
847            for offset in offsets:
848                print("\t%s," % hex(offset))
849            print("};")
850
851        print("#endif")
852
853    def has_variants(self, reg):
854        return reg.name in self.variant_regs and not is_number(reg.name) and not is_number(reg.name[1:])
855
856    def dump(self):
857        enums = []
858        bitsets = []
859        regs = []
860        for e in self.file:
861            if isinstance(e, Enum):
862                enums.append(e)
863            elif isinstance(e, Bitset):
864                bitsets.append(e)
865            else:
866                regs.append(e)
867
868        for e in enums + bitsets + regs:
869            e.dump(self.has_variants(e))
870
871        self.dump_reg_usages()
872
873    def dump_regs_py(self):
874        regs = []
875        for e in self.file:
876            if isinstance(e, Reg):
877                regs.append(e)
878
879        for e in regs:
880            e.dump_py()
881
    def dump_reg_variants(self, regname, variants):
        """Print the C++ builder for a register that has per-chip variants.

        Emits, inside ``#ifdef __cplusplus``:
          * ``struct __<regname>``: the fields of all variants, each field
            name appearing once, followed by raw fallback members;
          * a ``__<regname><CHIP>(...)`` template that dispatches on the chip
            to each variant's ``dump_regpair_builder()`` output and asserts
            if no variant matches;
          * a ``<regname>(VARIANT, ...)`` convenience macro.

        Args:
            regname: register name. Numeric names, including names that are
                numeric after a one-character prefix, are skipped entirely.
            variants: mapping from a variant string to the register object
                for that variant. The variant string is a single chip name,
                "START-END" (inclusive range) or "START-" (open-ended).
        """
        if is_number(regname) or is_number(regname[1:]):
            return
        print("#ifdef __cplusplus")
        print("struct __%s {" % regname)
        # TODO be more clever.. we should probably figure out which
        # fields have the same type in all variants (in which they
        # appear) and stuff everything else in a variant specific
        # sub-structure.
        seen_fields = []
        bit_size = 32
        array = False
        address = None
        # Cleared below if any emitted field is a float. Presumably the float
        # packing is not usable in a constexpr context -- TODO confirm.
        constexpr_mark = " CONSTEXPR"
        for variant in variants.keys():
            print("    /* %s fields: */" % variant)
            reg = variants[variant]
            # bit_size and array end up taken from the LAST variant visited.
            bit_size = reg.bit_size
            array = reg.array
            for f in reg.bitset.fields:
                fld_name = field_name(reg, f)
                # Each field name is emitted once across all variants. The
                # first variant that defines it decides its C type.
                if fld_name in seen_fields:
                    continue
                seen_fields.append(fld_name)
                name = fld_name.lower()
                if f.type in ["address", "waddress"]:
                    # Only one address field is represented, as a bo +
                    # bo_offset pair. The pair is omitted when TU_CS_H is
                    # defined.
                    if address:
                        continue
                    address = f
                    print("#ifndef TU_CS_H")
                    tab_to("    __bo_type", "bo;")
                    tab_to("    uint32_t", "bo_offset;")
                    print("#endif")
                    continue
                type, val = f.ctype("var")
                tab_to("    %s" % type, "%s;" % name)
                if f.type == "float":
                    constexpr_mark = ""
        # Raw members sized by the (last) variant's register width.
        print("    /* fallback fields: */")
        if bit_size == 64:
            tab_to("    uint64_t", "unknown;")
            tab_to("    uint64_t", "qword;")
        else:
            tab_to("    uint32_t", "unknown;")
            tab_to("    uint32_t", "dword;")
        print("};")
        # TODO don't hardcode the varset enum name
        varenum = "chip"
        print("template <%s %s>" % (varenum, varenum.upper()))
        print("static%s inline struct fd_reg_pair" % (constexpr_mark))
        # Array registers take a leading element index (__i) argument in
        # both the builder function and the macro.
        xtra = ""
        xtravar = ""
        if array:
            xtra = "int __i, "
            xtravar = "__i, "
        print("__%s(%sstruct __%s fields) {" % (regname, xtra, regname))
        # Emit an if / else-if chain, one branch per variant:
        #   "START-END" -> inclusive range check,
        #   "START-"    -> open-ended lower bound,
        #   otherwise   -> exact chip match.
        for variant in variants.keys():
            if "-" in variant:
                start = variant[:variant.index("-")]
                end = variant[variant.index("-") + 1:]
                if end != "":
                    print("  if ((%s >= %s) && (%s <= %s)) {" % (
                        varenum.upper(), start, varenum.upper(), end))
                else:
                    print("  if (%s >= %s) {" % (varenum.upper(), start))
            else:
                print("  if (%s == %s) {" % (varenum.upper(), variant))
            reg = variants[variant]
            reg.dump_regpair_builder()
            print("  } else")
        # Final "else" of the chain: no variant matched the chip.
        print("    assert(!\"invalid variant\");")
        print("  return (struct fd_reg_pair){};")
        print("}")

        # For 64-bit registers the macro appends an extra "{ .reg = 0 }"
        # initializer after the builder call. Presumably this reserves a
        # second slot in the caller's register-pair list -- TODO confirm
        # against fd_reg_pair users.
        if bit_size == 64:
            skip = ", { .reg = 0 }"
        else:
            skip = ""

        print("#define %s(VARIANT, %s...) __%s<VARIANT>(%s{__VA_ARGS__})%s" % (
            regname, xtravar, regname, xtravar, skip))
        print("#endif /* __cplusplus */")
964
965    def dump_structs(self):
966        for e in self.file:
967            e.dump_pack_struct(self.has_variants(e))
968
969        for regname in self.variant_regs:
970            self.dump_reg_variants(regname, self.variant_regs[regname])
971
972
def dump_c(args, guard, func):
    """Parse the register XML and print a complete C header to stdout.

    The header is wrapped in an include guard named ``guard``. It starts
    with a common preamble of compatibility macros for kernel/userspace and
    C/C++: ``assert()``, ``__struct_cast``, ``CONSTEXPR`` and
    ``__FD_DEPRECATED``. The body is produced by calling ``func`` with the
    populated Parser.

    Args:
        args: parsed command line; ``args.rnn``, ``args.xml`` and
            ``args.validate`` are passed to ``Parser.parse()``.
        guard: include-guard macro name.
        func: callback that prints the header body, given the Parser.

    A parse ``Error`` is reported on stderr and the process exits with
    status 1.
    """
    p = Parser()

    try:
        p.parse(args.rnn, args.xml, args.validate)
    except Error as e:
        print(e, file=sys.stderr)
        # Use sys.exit(), not the site-provided exit() helper, which does
        # not exist when the interpreter runs with -S.
        sys.exit(1)

    print("#ifndef %s\n#define %s\n" % (guard, guard))

    print("/* Autogenerated file, DO NOT EDIT manually! */")

    print()
    print("#ifdef __KERNEL__")
    print("#include <linux/bug.h>")
    print("#define assert(x) BUG_ON(!(x))")
    print("#else")
    print("#include <assert.h>")
    print("#endif")
    print()

    print("#ifdef __cplusplus")
    print("#define __struct_cast(X)")
    print("#define CONSTEXPR constexpr")
    print("#else")
    print("#define __struct_cast(X) (struct X)")
    print("#define CONSTEXPR")
    print("#endif")
    print()

    # TODO figure out what to do about fd_reg_stomp_allowed()
    # vs gcc.. for now only enable the warnings with clang:
    print("#if defined(__clang__) && !defined(FD_NO_DEPRECATED_PACK) && !defined(__KERNEL__)")
    print("#define __FD_DEPRECATED _Pragma (\"GCC warning \\\"Deprecated reg builder\\\"\")")
    print("#else")
    print("#define __FD_DEPRECATED")
    print("#endif")
    print()

    func(p)

    print("#endif /* %s */" % guard)
1016
1017
def dump_c_defines(args):
    """Print the C definitions header for the register XML in ``args.xml``.

    The include guard is the XML file's base name, upper-cased, with dots
    replaced by underscores.
    """
    base = os.path.basename(args.xml)
    dump_c(args, base.replace('.', '_').upper(), lambda parser: parser.dump())
1021
1022
def dump_c_pack_structs(args):
    """Print the C pack-struct / register-builder header for ``args.xml``.

    The include guard is the XML file's base name, upper-cased, with dots
    replaced by underscores and a ``_STRUCTS`` suffix appended.
    """
    stem = os.path.basename(args.xml).replace('.', '_').upper()
    dump_c(args, stem + '_STRUCTS', lambda parser: parser.dump_structs())
1027
1028
def dump_perfcntrs(args):
    """Print the C performance-counter tables for a single chip.

    The register XML (``args.rnn`` / ``args.xml``) is parsed to resolve
    register and array offsets. ``args.json`` describes the counter groups.

    JSON layout:
      * ``"chip"``: a value name of the ``chip`` enum.
      * ``"groups"``: a list of groups, each with ``"name"``, ``"num"``,
        ``"countable_type"`` (an enum name) and ``"select"``.
      * Counter registers: either ``"counter"`` (a 64-bit counter) or
        ``"counter_lo"`` / ``"counter_hi"``.
      * Optional group keys: ``"reserved"``, ``"select_offset"``,
        ``"slice_select"`` and ``"pipe"``.

    Register names may contain a ``{}`` placeholder, which is replaced by
    the counter index.

    Raises:
        Error: if the chip, a countable type or a register name is unknown.
    """
    p = Parser()

    try:
        p.parse(args.rnn, args.xml, args.validate)
    except Error as e:
        print(e, file=sys.stderr)
        # Use sys.exit(), not the site-provided exit() helper, which does
        # not exist when the interpreter runs with -S.
        sys.exit(1)

    # Read the JSON inside a context manager so the file handle is closed.
    with open(args.json, "r", encoding="utf-8") as json_file:
        perfcntrs = json.load(json_file)

    chip_type = p.enums['chip']
    chip = perfcntrs['chip']
    if not chip_type.has_name(chip):
        raise Error("Invalid chip: " + chip)

    groups = perfcntrs['groups']

    guard = "__" + chip + "_PERFCNTRS_"
    print("#ifndef %s\n#define %s\n" % (guard, guard))
    print("/* Autogenerated file, DO NOT EDIT manually! */")
    print()
    print("#ifdef __KERNEL__")
    print("#include \"msm_perfcntr.h\"")
    print("#endif")
    print()

    def has_variant(variant):
        # Does a variant string apply to the chip being generated?
        #   None        -> applies to all chips.
        #   "START-END" -> inclusive range of chip enum values; a bound that
        #                  is empty or unknown (Enum.value() returns None)
        #                  is treated as open.
        #   otherwise   -> must name the chip exactly.
        if variant is None:
            return True
        if "-" in variant:
            start = chip_type.value(variant[:variant.index("-")])
            end = chip_type.value(variant[variant.index("-") + 1:])
            chipn = chip_type.value(chip)

            return (start is None or chipn >= start) and (end is None or chipn <= end)
        return chip == variant

    # Split out arrays and regs for later access:
    arrays = {}
    regs = {}
    for e in p.file:
        if isinstance(e, Array) and has_variant(e.variant):
            arrays[e.local_name] = e
        if isinstance(e, Reg):
            regs[e.name] = e

    # For variant regs, overwrite 'regs' entries with correct variant:
    for regname in p.variant_regs:
        for (variant, reg) in p.variant_regs[regname].items():
            if has_variant(variant):
                regs[regname] = reg
                break

    def get_reg(reg_name, i):
        # Offset of the register or array named reg_name for counter index i.
        # If the name has a {} pattern, expand that first. An array resolves
        # to the offset of its element i.
        reg_name = reg_name.format(i)

        if reg_name in arrays:
            arr = arrays[reg_name]
            return arr.offset + (i * arr.stride)

        if reg_name not in regs:
            raise Error("Invalid reg: " + reg_name)

        return regs[reg_name].offset

    def get_counter(group, i):
        # if the counter is <reg64> just a single "counter" value
        # should be specified in the json, but for legacy separate
        # hi/lo <reg32> pairs "counter_lo" and "counter_hi" should
        # be specified
        if "counter" in group:
            counter = get_reg(group["counter"], i)
            return [counter, counter + 1]
        return [get_reg(group["counter_lo"], i), get_reg(group["counter_hi"], i)]

    for group in groups:
        name = group['name']
        name_low = name.lower()
        num = group['num']
        countable_type_name = group['countable_type']

        if countable_type_name not in p.enums:
            raise Error("Invalid type: " + countable_type_name)

        countable_type = p.enums[countable_type_name]

        print("#ifndef __KERNEL__")
        print("static const struct fd_perfcntr_countable " + name_low + "_countables[] = {")
        for (countable, value) in countable_type.values:
            # if the countable is prefixed with the chip, strip that:
            # (note: avoid py3.9 str.removeprefix() dependency for kernel)
            if countable.startswith(chip + "_"):
                countable = countable[len(chip) + 1:]
            print("   { \"" + countable + "\", " + str(value) + " },")
        print("};")
        print("#endif")

        print("static const struct fd_perfcntr_counter " + name_low + "_counters[] = {")
        reserved = group.get("reserved", ())
        for i in range(num):
            if i in reserved:
                continue

            (counter_lo, counter_hi) = get_counter(group, i)
            select = get_reg(group['select'], i)

            select_offset = 0
            if "select_offset" in group:
                select_offset = int(group["select_offset"])
                select = select + select_offset

            # Per-slice select registers take the same select_offset.
            slice_select_str = "".join(
                "0x%04x, " % (get_reg(reg, i) + select_offset)
                for reg in group.get("slice_select", ()))

            # TODO add support for things that need enable/clear regs

            print("   { 0x%04x, {%s}, 0x%04x, 0x%04x }," % (select, slice_select_str, counter_lo, counter_hi))
        print("};")

        print()

    print("const struct fd_perfcntr_group " + chip.lower() + "_perfcntr_groups[] = {")
    for group in groups:
        name = group['name']
        name_low = name.lower()
        pipe = group.get('pipe', 'NONE')

        print("   GROUP(\"%s\", PIPE_%s, %s_counters, %s_countables)," % (name, pipe, name_low, name_low))

    print("};")
    print("const unsigned " + chip.lower() + "_num_perfcntr_groups = ARRAY_SIZE(" + chip.lower() + "_perfcntr_groups);")

    print()
    print("#endif /* %s */" % guard)
1172
def dump_py_defines(args):
    """Print a Python module declaring the registers of ``args.xml``.

    Output layout:
      * an ``IntEnum`` class named ``<FILE>Regs``, where FILE is the XML
        base name without its extension, upper-cased;
      * the class body, produced by ``Parser.dump_regs_py()`` (one
        ``Reg.dump_py()`` call per top-level register).

    A parse ``Error`` is reported on stderr and the process exits with
    status 1.
    """
    p = Parser()

    try:
        p.parse(args.rnn, args.xml, args.validate)
    except Error as e:
        print(e, file=sys.stderr)
        # Use sys.exit(), not the site-provided exit() helper, which does
        # not exist when the interpreter runs with -S.
        sys.exit(1)

    file_name = os.path.splitext(os.path.basename(args.xml))[0]

    print("from enum import IntEnum")
    print("class %sRegs(IntEnum):" % file_name.upper())

    p.dump_regs_py()
1190
1191
def main():
    """Parse the command line and run the selected generator.

    Global options:
      * ``--rnn`` and ``--xml``: required; the generators pass them to
        ``Parser.parse()``.
      * ``--validate`` / ``--no-validate``: toggle validation (off by
        default).

    Exactly one sub-command must follow: ``c-defines``, ``c-pack-structs``,
    ``perfcntrs`` (which also requires ``--json``) or ``py-defines``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--rnn', type=str, required=True)
    parser.add_argument('--xml', type=str, required=True)
    parser.add_argument('--validate', default=False, action='store_true')
    parser.add_argument('--no-validate', dest='validate', action='store_false')

    # Give the required sub-command group a dest. Without one, older
    # argparse versions crash with a TypeError while formatting the
    # "required arguments" error, instead of printing a usage message.
    subparsers = parser.add_subparsers(dest='command')
    subparsers.required = True

    parser_c_defines = subparsers.add_parser('c-defines')
    parser_c_defines.set_defaults(func=dump_c_defines)

    parser_c_pack_structs = subparsers.add_parser('c-pack-structs')
    parser_c_pack_structs.set_defaults(func=dump_c_pack_structs)

    parser_perfcntrs = subparsers.add_parser('perfcntrs')
    parser_perfcntrs.add_argument('--json', type=str, required=True)
    parser_perfcntrs.set_defaults(func=dump_perfcntrs)

    parser_py_defines = subparsers.add_parser('py-defines')
    parser_py_defines.set_defaults(func=dump_py_defines)

    args = parser.parse_args()
    args.func(args)


if __name__ == '__main__':
    main()
1221