xref: /linux/drivers/gpu/drm/msm/registers/gen_header.py (revision 3f1c07fc21c68bd3bd2df9d2c9441f6485e934d9)
1#!/usr/bin/python3
2#
3# Copyright © 2019-2024 Google, Inc.
4#
5# SPDX-License-Identifier: MIT
6
7import xml.parsers.expat
8import sys
9import os
10import collections
11import argparse
12import time
13import datetime
14
class Error(Exception):
	"""Failure raised while parsing or validating the register XML.

	The human-readable text is kept in `message` and is also what
	str() of the exception yields.
	"""

	def __init__(self, message):
		# Hand the text to Exception explicitly so args/str() carry it,
		# and keep the attribute that callers may read directly.
		super().__init__(message)
		self.message = message
18
class Enum(object):
	"""A named enumeration: an ordered list of (name, value) pairs."""

	def __init__(self, name):
		self.name = name
		# (value_name, integer_value) tuples in declaration order.
		self.values = []

	def has_name(self, name):
		"""Return True when one of the values is called `name`."""
		return any(candidate == name for (candidate, _) in self.values)

	def names(self):
		"""Return the value names in declaration order."""
		return [candidate for (candidate, _) in self.values]

	def dump(self, is_deprecated):
		"""Print the enum as a C `enum` definition.

		`is_deprecated` is accepted so every dumpable object shares one
		call signature; it is not used here.
		"""
		# A single value above 0x1000 switches the whole enum to
		# zero-padded hex so the columns stay uniform.
		if any(number > 0x1000 for (_, number) in self.values):
			line_format = "\t%s = 0x%08x,"
		else:
			line_format = "\t%s = %d,"

		print("enum %s {" % self.name)
		for (label, number) in self.values:
			print(line_format % (label, number))
		print("};\n")

	def dump_pack_struct(self, is_deprecated):
		"""Enums have no pack struct; nothing is printed."""
		pass
49
class Field(object):
	"""A single bitfield inside a register or bitset.

	The constructor stores name, low/high (inclusive bit positions), shr
	(right shift applied to the value before packing) and type (a builtin
	type name, an enum name, or None).  Fields of type "fixed"/"ufixed"
	additionally need a `radix` attribute; it is not set here but is read
	by ctype(), so the creator has to assign it after construction.
	"""

	def __init__(self, name, low, high, shr, type, parser):
		"""Record the field description and validate it.

		`parser` provides current_bitsize (width of the enclosing
		register), enums (known enum names) and error(), which builds
		the exception that is raised on invalid input.
		"""
		self.name = name
		self.low = low
		self.high = high
		self.shr = shr
		self.type = type

		known_types = ( None, "a3xx_regid", "boolean", "uint", "hex", "int", "fixed", "ufixed", "float", "address", "waddress" )
		top_bit = parser.current_bitsize - 1
		span = high - low

		# Every failing check raises, so these are written as
		# independent guards evaluated in a fixed order.
		if not (0 <= low <= top_bit):
			raise parser.error("low attribute out of range: %d" % low)
		if not (0 <= high <= top_bit):
			raise parser.error("high attribute out of range: %d" % high)
		if span < 0:
			raise parser.error("low is greater than high: low=%d, high=%d" % (low, high))
		if self.type == "boolean" and span != 0:
			raise parser.error("booleans should be 1 bit fields")
		if self.type == "float" and span not in (31, 15):
			raise parser.error("floats should be 16 or 32 bit fields")
		if self.type not in known_types and self.type not in parser.enums:
			raise parser.error("unknown type '%s'" % self.type)

	def ctype(self, var_name):
		"""Return (C type, C expression) used to pack `var_name`.

		The expression converts the C variable into the raw integer to
		be shifted into place, including the `shr` right shift.
		"""
		kind = self.type
		span = self.high - self.low
		expr = var_name

		if kind is None or kind in ("uint", "hex", "a3xx_regid"):
			c_type = "uint32_t"
		elif kind == "boolean":
			c_type = "bool"
		elif kind == "int":
			c_type = "int32_t"
		elif kind in ("fixed", "ufixed"):
			# Fixed point: scale the float by 2^radix and truncate.
			c_type = "float"
			int_cast = "int32_t" if kind == "fixed" else "uint32_t"
			expr = "((%s)(%s * %d.0))" % (int_cast, var_name, 1 << self.radix)
		elif kind == "float" and span == 31:
			c_type = "float"
			expr = "fui(%s)" % var_name
		elif kind == "float" and span == 15:
			c_type = "float"
			expr = "_mesa_float_to_half(%s)" % var_name
		elif kind in ("address", "waddress"):
			c_type = "uint64_t"
		else:
			# Anything else is treated as the name of an enum.
			c_type = "enum %s" % kind

		if self.shr > 0:
			expr = "(%s >> %d)" % (expr, self.shr)

		return (c_type, expr)
111
def tab_to(name, value):
	"""Print `name`, then tabs, then `value`, with at least one tab.

	The tab count shrinks by one for every full 8 characters of `name`,
	so values of short names line up in one column.
	"""
	tabs = max(1, (68 - (len(name) & ~7)) // 8)
	print("%s%s%s" % (name, '\t' * tabs, value))
117
def mask(low, high):
	"""Return an integer with bits low..high (inclusive) set."""
	width = high + 1 - low
	all_ones_64 = 0xffffffffffffffff
	# Trim the 64-bit all-ones pattern down to `width` bits, then move it
	# up so that its lowest set bit sits at position `low`.
	return (all_ones_64 >> (64 - width)) << low
120
def field_name(reg, f):
	"""Return the lower-case C struct member name for field `f` of `reg`.

	An unnamed field takes the register's name; this happens when a reg
	is defined with no bitset fields, ie.
		<reg32 offset="0x88db" name="RB_RESOLVE_SYSTEM_BUFFER_ARRAY_PITCH" low="0" high="28" shr="6" type="uint"/>
	"""
	member = (f.name or reg.name).lower()

	# Names that collide with a C type keyword, or that do not start with
	# a letter, get a leading underscore.
	if member in ("double", "float", "int") or not member[0].isalpha():
		return "_" + member
	return member
133
134# indices - array of (ctype, stride, __offsets_NAME)
def indices_varlist(indices):
	"""Return "i0, i1, ..." with one variable per entry of `indices`."""
	names = ["i%d" % position for position, _ in enumerate(indices)]
	return ", ".join(names)
137
def indices_prototype(indices):
	"""Return a C parameter list ("<ctype> i0, <ctype> i1, ...") for `indices`."""
	params = []
	for position, (ctype, stride, offset) in enumerate(indices):
		params.append("%s i%d" % (ctype, position))
	return ", ".join(params)
141
def indices_strides(indices):
	"""Return the C expression summing each index's offset contribution.

	An entry with a truthy stride contributes "0x<stride>*iN"; any other
	entry calls its offset lookup function as "<offset>(iN)".
	"""
	terms = []
	for position, (ctype, stride, offset) in enumerate(indices):
		if stride:
			terms.append("0x%x*i%d" % (stride, position))
		else:
			terms.append("%s(i%d)" % (offset, position))
	return " + ".join(terms)
147
def is_number(str):
	"""Return True if int() accepts `str` (base 10), False on ValueError."""
	try:
		int(str)
	except ValueError:
		return False
	else:
		return True
154
def sanitize_variant(variant):
	"""Reduce a variant range such as "A4XX-A5XX" or "A6XX-" to its start.

	Falsy input (None or "") and names without a "-" come back unchanged.
	"""
	if not variant:
		return variant
	# partition() yields the whole string as the head when "-" is absent.
	return variant.partition("-")[0]
159
class Bitset(object):
	"""A named collection of Fields describing a register layout.

	A bitset is either shared between registers or, when `inline` is
	True, owned by the single register stored in `reg`.
	"""

	def __init__(self, name, template):
		"""Create a bitset, starting from a copy of `template`'s fields if given."""
		self.name = name
		self.inline = False
		self.reg = None
		if template:
			# Shallow copy: appending here must not alter the template.
			self.fields = template.fields[:]
		else:
			self.fields = []

	# Get address field if there is one in the bitset, else return None:
	def get_address_field(self):
		"""Return the first field of type "address"/"waddress", or None."""
		for f in self.fields:
			if f.type in [ "address", "waddress" ]:
				return f
		return None

	def dump_regpair_builder(self, reg):
		"""Print the C statements that build a `struct fd_reg_pair` for `reg`.

		Output is a debug-only block of range asserts followed by a
		compound-literal `return` combining all non-address fields.
		"""
		print("#ifndef NDEBUG")
		known_mask = 0
		for f in self.fields:
			known_mask |= mask(f.low, f.high)
			# These types get no range assert.
			if f.type in [ "boolean", "address", "waddress" ]:
				continue
			ctype, val = f.ctype("fields.%s" % field_name(reg, f))
			print("    assert((%-40s & 0x%08x) == 0);" % (val, 0xffffffff ^ mask(0 , f.high - f.low)))
		# fields.unknown must not overlap any bit covered by a named field.
		print("    assert((%-40s & 0x%08x) == 0);" % ("fields.unknown", known_mask))
		print("#endif\n")

		print("    return (struct fd_reg_pair) {")
		print("        .reg = (uint32_t)%s," % reg.reg_offset())
		print("        .value =")
		cast = "(uint64_t)" if reg.bit_size == 64 else ""
		for f in self.fields:
			# The address is passed through .bo/.bo_offset, not .value.
			if f.type in [ "address", "waddress" ]:
				continue
			ctype, val = f.ctype("fields.%s" % field_name(reg, f))
			print("            (%s%-40s << %2d) |" % (cast, val, f.low))
		value_name = "dword"
		if reg.bit_size == 64:
			value_name = "qword"
		print("            fields.unknown | fields.%s," % (value_name,))

		address = self.get_address_field()
		if address:
			print("        .bo = fields.bo,")
			print("        .is_address = true,")
			# Test the address field itself.  The variable left over from
			# the loop above is merely the last field of the bitset, which
			# need not be the address field.
			if address.type == "waddress":
				print("        .bo_write = true,")
			print("        .bo_offset = fields.bo_offset,")
			print("        .bo_shift = %d," % address.shr)
			print("        .bo_low = %d," % address.low)

		print("    };")

	def dump_pack_struct(self, is_deprecated, reg=None):
		"""Print the field struct, pack_<reg>() function and wrapper macro.

		Nothing is printed without a `reg`.  `is_deprecated` adds the
		FD_DEPRECATED marker to the generated function.
		"""
		if not reg:
			return

		prefix = reg.full_name

		print("struct %s {" % prefix)
		for f in self.fields:
			if f.type in [ "address", "waddress" ]:
				tab_to("    __bo_type", "bo;")
				tab_to("    uint32_t", "bo_offset;")
				continue
			name = field_name(reg, f)

			ctype, val = f.ctype("var")

			tab_to("    %s" % ctype, "%s;" % name)
		if reg.bit_size == 64:
			tab_to("    uint64_t", "unknown;")
			tab_to("    uint64_t", "qword;")
		else:
			tab_to("    uint32_t", "unknown;")
			tab_to("    uint32_t", "dword;")
		print("};\n")

		depcrstr = ""
		if is_deprecated:
			depcrstr = " FD_DEPRECATED"
		if reg.array:
			print("static inline%s struct fd_reg_pair\npack_%s(uint32_t __i, struct %s fields)\n{" %
				  (depcrstr, prefix, prefix))
		else:
			print("static inline%s struct fd_reg_pair\npack_%s(struct %s fields)\n{" %
				  (depcrstr, prefix, prefix))

		self.dump_regpair_builder(reg)

		print("\n}\n")

		# Registers holding an address get an extra empty reg pair
		# appended to the macro expansion.
		if self.get_address_field():
			skip = ", { .reg = 0 }"
		else:
			skip = ""

		if reg.array:
			print("#define %s(__i, ...) pack_%s(__i, __struct_cast(%s) { __VA_ARGS__ })%s\n" %
				  (prefix, prefix, prefix, skip))
		else:
			print("#define %s(...) pack_%s(__struct_cast(%s) { __VA_ARGS__ })%s\n" %
				  (prefix, prefix, prefix, skip))


	def dump(self, is_deprecated, prefix=None, reg=None):
		"""Print per-field __MASK/__SHIFT defines and inline packer functions.

		`prefix` defaults to the bitset name.  `is_deprecated` is
		accepted for interface uniformity and not used here.
		"""
		if prefix is None:
			prefix = self.name
		reg64 = reg and self.reg and self.reg.bit_size == 64
		if reg64:
			print("static inline uint32_t %s_LO(uint32_t val)\n{" % prefix)
			print("\treturn val;\n}")
			print("static inline uint32_t %s_HI(uint32_t val)\n{" % prefix)
			print("\treturn val;\n}")
		for f in self.fields:
			if f.name:
				name = prefix + "_" + f.name
			else:
				name = prefix

			if not f.name and f.low == 0 and f.shr == 0 and f.type not in ["float", "fixed", "ufixed"]:
				# Anonymous, unshifted field starting at bit 0 that needs
				# no value conversion: nothing is emitted.
				pass
			elif f.type == "boolean" or (f.type is None and f.low == f.high):
				# Single-bit flag: one plain define with the bit value.
				tab_to("#define %s" % name, "0x%08x" % (1 << f.low))
			else:
				typespec = "ull" if reg64 else "u"
				tab_to("#define %s__MASK" % name, "0x%08x%s" % (mask(f.low, f.high), typespec))
				tab_to("#define %s__SHIFT" % name, "%d" % f.low)
				ctype, val = f.ctype("val")
				ret_type = "uint64_t" if reg64 else "uint32_t"
				cast = "(uint64_t)" if reg64 else ""

				print("static inline %s %s(%s val)\n{" % (ret_type, name, ctype))
				if f.shr > 0:
					# Bits discarded by the right shift must be zero.
					print("\tassert(!(val & 0x%x));" % mask(0, f.shr - 1))
				print("\treturn (%s(%s) << %s__SHIFT) & %s__MASK;\n}" % (cast, val, name, name))
		print()
299		print()
300
class Array(object):
	"""A register array, optionally nested inside a parent array.

	The element position comes either from a base `offset` plus `stride`
	or, when `fixed_offsets` is True, from an explicit list of per-index
	offset expressions kept in `offsets`.
	"""

	def __init__(self, attrs, domain, variant, parent, index_type):
		"""Build the array from the XML attribute dict `attrs`.

		`index_type` is only retained when `attrs` has an "index" key.
		Raises KeyError/ValueError on missing or malformed numeric
		attributes.
		"""
		if "name" in attrs:
			self.local_name = attrs["name"]
		else:
			self.local_name = ""
		self.domain = domain
		self.variant = variant
		self.parent = parent
		self.children = []
		if self.parent:
			self.name = self.parent.name + "_" + self.local_name
		else:
			self.name = self.local_name
		# The offset tables are real lists rather than map() iterators:
		# an iterator would be exhausted after its first traversal, so a
		# second dump() would silently emit an empty switch.
		if "offsets" in attrs:
			self.offsets = ["0x%08x" % int(i, 0) for i in attrs["offsets"].split(",")]
			self.fixed_offsets = True
		elif "doffsets" in attrs:
			self.offsets = ["(%s)" % s for s in attrs["doffsets"].split(",")]
			self.fixed_offsets = True
		else:
			self.offset = int(attrs["offset"], 0)
			self.stride = int(attrs["stride"], 0)
			self.fixed_offsets = False
		if "index" in attrs:
			self.index_type = index_type
		else:
			self.index_type = None
		self.length = int(attrs["length"], 0)
		if "usage" in attrs:
			self.usages = attrs["usage"].split(',')
		else:
			self.usages = None

	def index_ctype(self):
		"""Return the C type of this array's index argument."""
		if not self.index_type:
			return "uint32_t"
		else:
			return "enum %s" % self.index_type.name

	# Generate array of (ctype, stride, __offsets_NAME)
	def indices(self):
		"""Return index descriptors for this array and all its parents.

		Outermost parent first; an array of length 1 adds no entry.
		"""
		if self.parent:
			indices = self.parent.indices()
		else:
			indices = []
		if self.length != 1:
			if self.fixed_offsets:
				indices.append((self.index_ctype(), None, "__offset_%s" % self.local_name))
			else:
				indices.append((self.index_ctype(), self.stride, None))
		return indices

	def total_offset(self):
		"""Return the constant base offset summed over this array and its parents.

		Fixed-offset arrays contribute nothing; their offset comes from
		the generated __offset_<name>() lookup instead.
		"""
		offset = 0
		if not self.fixed_offsets:
			offset += self.offset
		if self.parent:
			offset += self.parent.total_offset()
		return offset

	def dump(self, is_deprecated):
		"""Print the REG_<domain>_<name> define, plus the offset lookup for fixed-offset arrays."""
		depcrstr = ""
		if is_deprecated:
			depcrstr = " FD_DEPRECATED"
		proto = indices_varlist(self.indices())
		strides = indices_strides(self.indices())
		array_offset = self.total_offset()
		if self.fixed_offsets:
			print("static inline%s uint32_t __offset_%s(%s idx)" % (depcrstr, self.local_name, self.index_ctype()))
			print("{\n\tswitch (idx) {")
			if self.index_type:
				# Enum-indexed: pair enum names with offsets by position.
				for val, offset in zip(self.index_type.names(), self.offsets):
					print("\t\tcase %s: return %s;" % (val, offset))
			else:
				for idx, offset in enumerate(self.offsets):
					print("\t\tcase %d: return %s;" % (idx, offset))
			print("\t\tdefault: return INVALID_IDX(idx);")
			print("\t}\n}")
		if proto == '':
			tab_to("#define REG_%s_%s" % (self.domain, self.name), "0x%08x\n" % array_offset)
		else:
			tab_to("#define REG_%s_%s(%s)" % (self.domain, self.name, proto), "(0x%08x + %s )\n" % (array_offset, strides))

	def dump_pack_struct(self, is_deprecated):
		"""Arrays have no pack struct; nothing is printed."""
		pass

	def dump_regpair_builder(self):
		"""Arrays have no reg-pair builder; nothing is printed."""
		pass
390
class Reg(object):
	"""A single 32 or 64 bit register, optionally a member of an Array.

	The `bitset` attribute read by the dump methods is not set by the
	constructor; it has to be assigned by whoever creates the Reg.
	"""

	def __init__(self, attrs, domain, array, bit_size):
		"""Build the register from XML attributes and register it with `array`."""
		local_name = attrs["name"]
		self.domain = domain
		self.array = array
		self.offset = int(attrs["offset"], 0)
		self.type = None
		self.bit_size = bit_size

		if array:
			# Array members are prefixed with the array name and are
			# recorded as children of that array.
			local_name = "%s_%s" % (array.name, local_name)
		self.name = local_name
		if array:
			array.children.append(self)
		self.full_name = "%s_%s" % (self.domain, self.name)

		# stride/length only exist together; without a stride both are None.
		self.stride = None
		self.length = None
		if "stride" in attrs:
			self.stride = int(attrs["stride"], 0)
			self.length = int(attrs["length"], 0)

	# Generate array of (ctype, stride, __offsets_NAME)
	def indices(self):
		"""Return the enclosing array's index descriptors plus this register's own, if strided."""
		result = self.array.indices() if self.array else []
		if self.stride:
			result.append(("uint32_t", self.stride, None))
		return result

	def total_offset(self):
		"""Return this register's offset plus the enclosing array's base offset."""
		if not self.array:
			return self.offset
		return self.array.total_offset() + self.offset

	def reg_offset(self):
		"""Return the C expression for the offset used by the pack builders.

		Array members are indexed with the `__i` argument of the
		generated pack function.
		"""
		if not self.array:
			return "0x%08x" % self.offset
		base = self.array.offset + self.offset
		return "(0x%08x + 0x%x*__i)" % (base, self.array.stride)

	def dump(self, is_deprecated):
		"""Print the REG_<name> define or accessor, then any inline bitset."""
		marker = " FD_DEPRECATED " if is_deprecated else ""
		proto = indices_prototype(self.indices())
		strides = indices_strides(self.indices())
		base = self.total_offset()

		if proto != '':
			# Indexed register: emit a function computing the offset.
			print("static inline%s uint32_t REG_%s(%s) { return 0x%08x + %s; }" % (marker, self.full_name, proto, base, strides))
		else:
			tab_to("#define REG_%s" % self.full_name, "0x%08x" % base)

		if self.bitset.inline:
			self.bitset.dump(is_deprecated, self.full_name, self)
		print("")

	def dump_pack_struct(self, is_deprecated):
		"""Print the pack struct, but only for registers with an inline bitset."""
		if not self.bitset.inline:
			return
		self.bitset.dump_pack_struct(is_deprecated, self)

	def dump_regpair_builder(self):
		"""Delegate to the bitset's reg-pair builder for this register."""
		self.bitset.dump_regpair_builder(self)

	def dump_py(self):
		"""Print one "REG_<name> = 0x..." class-body line using the register's own offset."""
		print("\tREG_%s = 0x%08x" % (self.full_name, self.offset))
457
458
class Parser(object):
	"""Expat-driven parser for the register database XML.

	Parsing fills in enums, bitsets, registers and arrays; the dump_*
	methods then print C headers (or Python constants) to stdout.  Only
	objects defined in the top-level file (not in imported ones) are
	collected in `file` and dumped.
	"""

	def __init__(self):
		self.current_array = None
		self.current_domain = None
		self.current_prefix = None
		self.current_prefix_type = None
		self.current_stripe = None
		self.current_bitset = None
		self.current_bitsize = 32
		self.current_reg = None
		self.current_enum = None
		self.current_enum_value = 0
		self.cdata = ""
		# The varset attribute on the domain specifies the enum which
		# specifies all possible hw variants:
		self.current_varset = None
		# Regs that have multiple variants.. we only generated the C++
		# template based struct-packers for these
		self.variant_regs = {}
		# Information in which contexts regs are used, to be used in
		# debug options
		self.usage_regs = collections.defaultdict(list)
		self.bitsets = {}
		self.enums = {}
		self.variants = set()
		self.file = []
		self.xml_files = []
		# (expat parser, filename) for every file currently being parsed,
		# innermost last.  parse() resets these three.
		self.stack = []
		self.path = None
		self.validate = False

	def error(self, message):
		"""Return (not raise) an Error prefixed with the current file:line:column."""
		parser, filename = self.stack[-1]
		return Error("%s:%d:%d: %s" % (filename, parser.CurrentLineNumber, parser.CurrentColumnNumber, message))

	def prefix(self, variant=None):
		"""Return the name prefix for objects declared at the current position."""
		if self.current_prefix_type == "variant" and variant:
			return sanitize_variant(variant)
		elif self.current_stripe:
			return self.current_stripe + "_" + self.current_domain
		elif self.current_prefix:
			return self.current_prefix + "_" + self.current_domain
		else:
			return self.current_domain

	def parse_field(self, name, attrs):
		"""Create a Field from `attrs` and append it to the current bitset.

		Bit positions come from "pos", or "low"/"high", or default to
		the whole register.  Raises Error on malformed numbers.
		"""
		try:
			if "pos" in attrs:
				high = low = int(attrs["pos"], 0)
			elif "high" in attrs and "low" in attrs:
				high = int(attrs["high"], 0)
				low = int(attrs["low"], 0)
			else:
				low = 0
				high = self.current_bitsize - 1

			if "type" in attrs:
				type = attrs["type"]
			else:
				type = None

			if "shr" in attrs:
				shr = int(attrs["shr"], 0)
			else:
				shr = 0

			b = Field(name, low, high, shr, type, self)

			if type == "fixed" or type == "ufixed":
				b.radix = int(attrs["radix"], 0)

			self.current_bitset.fields.append(b)
		except ValueError as e:
			raise self.error(e) from e

	def parse_varset(self, attrs):
		"""Return the varset enum named in `attrs`, else the enclosing domain's."""
		# Inherit the varset from the enclosing domain if not overriden:
		varset = self.current_varset
		if "varset" in attrs:
			varset = self.enums[attrs["varset"]]
		return varset

	def parse_variants(self, attrs):
		"""Return the first entry of the "variants" attribute, or None if absent.

		The entry (a single name, or a "START-END" / "START-" range) is
		checked against the varset enum.  Raises Error when no varset is
		in scope or a name is not part of it.  Explicit raises are used
		rather than `assert`, which is stripped under `python -O` and
		carries no file position.
		"""
		if "variants" not in attrs:
			return None

		variant = attrs["variants"].split(",")[0]
		varset = self.parse_varset(attrs)
		if varset is None:
			raise self.error("'variants' attribute used without a varset")

		if "-" in variant:
			# if we have a range, validate that both the start and end
			# of the range are valid enums:
			start, _, end = variant.partition("-")
			to_check = [start]
			if end != "":
				to_check.append(end)
		else:
			to_check = [variant]

		for candidate in to_check:
			if not varset.has_name(candidate):
				raise self.error("unknown variant '%s'" % candidate)

		return variant

	def add_all_variants(self, reg, attrs, parent_variant):
		"""Record `reg` in variant_regs under its own or the inherited variant."""
		# TODO this should really handle *all* variants, including dealing
		# with open ended ranges (ie. "A2XX,A4XX-") (we have the varset
		# enum now to make that possible)
		variant = self.parse_variants(attrs)
		if not variant:
			variant = parent_variant

		if reg.name not in self.variant_regs:
			self.variant_regs[reg.name] = {}
		else:
			# All variants must be same size:
			v = next(iter(self.variant_regs[reg.name]))
			assert self.variant_regs[reg.name][v].bit_size == reg.bit_size

		self.variant_regs[reg.name][variant] = reg

	def add_all_usages(self, reg, usages):
		"""Record `reg` under each usage and remember its domain in `variants`."""
		if not usages:
			return

		for usage in usages:
			self.usage_regs[usage].append(reg)

		self.variants.add(reg.domain)

	def do_validate(self, schemafile):
		"""Validate the current file against its XML schema using lxml.

		Does nothing when validation is disabled, and only prints a
		notice to stderr when lxml is not installed.  Raises Error when
		the schema cannot be found or validation fails.
		"""
		if not self.validate:
			return

		try:
			from lxml import etree
		except ImportError:
			print("lxml not found, skipping validation", file=sys.stderr)
			return

		parser, filename = self.stack[-1]
		dirname = os.path.dirname(filename)

		# we expect this to look like <namespace url> schema.xsd.. I think
		# technically it is supposed to be just a URL, but that doesn't
		# quite match up to what we do.. Just keep what follows the last
		# space (or the whole string if it contains none):
		schemafile = schemafile.rsplit(" ", 1)[-1]

		# this is a bit cheezy, but the xml file to validate could be
		# in a child directory, ie. we don't really know where the schema
		# file is, the way the rnn C code does.  So if it doesn't exist
		# just look one level up
		if not os.path.exists(dirname + "/" + schemafile):
			schemafile = "../" + schemafile

		if not os.path.exists(dirname + "/" + schemafile):
			raise self.error("Cannot find schema for: " + filename)

		xmlschema_doc = etree.parse(dirname + "/" + schemafile)
		xmlschema = etree.XMLSchema(xmlschema_doc)

		xml_doc = etree.parse(filename)
		if not xmlschema.validate(xml_doc):
			error_str = str(xmlschema.error_log.filter_from_errors()[0])
			raise self.error("Schema validation failed for: " + filename + "\n" + error_str)

	def do_parse(self, filename):
		"""Parse one XML file; a file that was already parsed is skipped.

		The file is closed and the parser stack entry removed even when
		parsing raises.
		"""
		filepath = os.path.abspath(filename)
		if filepath in self.xml_files:
			return
		self.xml_files.append(filepath)
		with open(filename, "rb") as file:
			parser = xml.parsers.expat.ParserCreate()
			parser.StartElementHandler = self.start_element
			parser.EndElementHandler = self.end_element
			parser.CharacterDataHandler = self.character_data
			parser.buffer_text = True
			self.stack.append((parser, filename))
			try:
				parser.ParseFile(file)
			finally:
				self.stack.pop()

	def parse(self, rnn_path, filename, validate):
		"""Parse `filename`; <import> files are resolved relative to `rnn_path`."""
		self.path = rnn_path
		self.stack = []
		self.validate = validate
		self.do_parse(filename)

	def parse_reg(self, attrs, bit_size):
		"""Handle a <reg32>/<reg64> start tag: create the Reg and its bitset."""
		self.current_bitsize = bit_size
		if "type" in attrs and attrs["type"] in self.bitsets:
			bitset = self.bitsets[attrs["type"]]
			if bitset.inline:
				# Inline bitsets are copied so each reg owns its fields.
				self.current_bitset = Bitset(attrs["name"], bitset)
				self.current_bitset.inline = True
			else:
				self.current_bitset = bitset
		else:
			self.current_bitset = Bitset(attrs["name"], None)
			self.current_bitset.inline = True
			if "type" in attrs:
				# The reg itself describes one anonymous field.
				self.parse_field(None, attrs)

		variant = self.parse_variants(attrs)
		if not variant and self.current_array:
			variant = self.current_array.variant

		self.current_reg = Reg(attrs, self.prefix(variant), self.current_array, bit_size)
		self.current_reg.bitset = self.current_bitset
		self.current_bitset.reg = self.current_reg

		# Only the top-level file (stack depth 1) contributes output.
		if len(self.stack) == 1:
			self.file.append(self.current_reg)

		if variant is not None:
			self.add_all_variants(self.current_reg, attrs, variant)

		usages = None
		if "usage" in attrs:
			usages = attrs["usage"].split(',')
		elif self.current_array:
			usages = self.current_array.usages

		self.add_all_usages(self.current_reg, usages)

	def start_element(self, name, attrs):
		"""Expat start-tag handler; dispatches on the element name."""
		self.cdata = ""
		if name == "import":
			filename = attrs["file"]
			self.do_parse(os.path.join(self.path, filename))
		elif name == "domain":
			self.current_domain = attrs["name"]
			if "prefix" in attrs:
				self.current_prefix = sanitize_variant(self.parse_variants(attrs))
				self.current_prefix_type = attrs["prefix"]
			else:
				self.current_prefix = None
				self.current_prefix_type = None
			if "varset" in attrs:
				self.current_varset = self.enums[attrs["varset"]]
		elif name == "stripe":
			self.current_stripe = sanitize_variant(self.parse_variants(attrs))
		elif name == "enum":
			self.current_enum_value = 0
			self.current_enum = Enum(attrs["name"])
			self.enums[attrs["name"]] = self.current_enum
			if len(self.stack) == 1:
				self.file.append(self.current_enum)
		elif name == "value":
			if "value" in attrs:
				value = int(attrs["value"], 0)
			else:
				# NOTE(review): current_enum_value is never advanced, so
				# every value without an explicit "value" attribute gets
				# 0 - confirm this is intended.
				value = self.current_enum_value
			self.current_enum.values.append((attrs["name"], value))
		elif name == "reg32":
			self.parse_reg(attrs, 32)
		elif name == "reg64":
			self.parse_reg(attrs, 64)
		elif name == "array":
			self.current_bitsize = 32
			variant = self.parse_variants(attrs)
			index_type = self.enums[attrs["index"]] if "index" in attrs else None
			self.current_array = Array(attrs, self.prefix(variant), variant, self.current_array, index_type)
			if len(self.stack) == 1:
				self.file.append(self.current_array)
		elif name == "bitset":
			self.current_bitset = Bitset(attrs["name"], None)
			if "inline" in attrs and attrs["inline"] == "yes":
				self.current_bitset.inline = True
			self.bitsets[self.current_bitset.name] = self.current_bitset
			if len(self.stack) == 1 and not self.current_bitset.inline:
				self.file.append(self.current_bitset)
		elif name == "bitfield" and self.current_bitset:
			self.parse_field(attrs["name"], attrs)
		elif name == "database":
			self.do_validate(attrs["xsi:schemaLocation"])

	def end_element(self, name):
		"""Expat end-tag handler; clears the state owned by the closing element."""
		if name == "domain":
			self.current_domain = None
			self.current_prefix = None
			self.current_prefix_type = None
		elif name == "stripe":
			self.current_stripe = None
		elif name == "bitset":
			self.current_bitset = None
		elif name in ("reg32", "reg64"):
			# Both register widths own current_reg.
			self.current_reg = None
		elif name == "array":
			# if the array has no Reg children, push an implicit reg32:
			if len(self.current_array.children) == 0:
				attrs = {
					"name": "REG",
					"offset": "0",
				}
				self.parse_reg(attrs, 32)
			self.current_array = self.current_array.parent
		elif name == "enum":
			self.current_enum = None

	def character_data(self, data):
		"""Expat character-data handler; accumulates text into `cdata`."""
		self.cdata += data

	def dump_reg_usages(self):
		"""Print the C++ per-usage, per-variant tables of register offsets."""
		d = collections.defaultdict(list)
		for usage, regs in self.usage_regs.items():
			for reg in regs:
				variants = self.variant_regs.get(reg.name)
				if variants:
					# Variant-specific reg: list it only under the
					# variant(s) that map to this very object.
					for variant, vreg in variants.items():
						if reg == vreg:
							d[(usage, sanitize_variant(variant))].append(reg)
				else:
					for variant in self.variants:
						d[(usage, sanitize_variant(variant))].append(reg)

		print("#ifdef __cplusplus")

		# Empty primary template per usage; specialisations follow.
		for usage in self.usage_regs:
			print("template<chip CHIP> constexpr inline uint16_t %s_REGS[] = {};" % (usage.upper()))

		for (usage, variant), regs in d.items():
			offsets = []

			for reg in regs:
				if reg.array:
					for i in range(reg.array.length):
						offsets.append(reg.array.offset + reg.offset + i * reg.array.stride)
						# A 64 bit reg also occupies the next offset.
						if reg.bit_size == 64:
							offsets.append(offsets[-1] + 1)
				else:
					offsets.append(reg.offset)
					if reg.bit_size == 64:
						offsets.append(offsets[-1] + 1)

			offsets.sort()

			print("template<> constexpr inline uint16_t %s_REGS<%s>[] = {" % (usage.upper(), variant))
			for offset in offsets:
				print("\t%s," % hex(offset))
			print("};")

		print("#endif")

	def has_variants(self, reg):
		"""Return True if `reg.name` has recorded variants and is not numeric.

		A name counts as numeric if it, or it minus its first character,
		parses as an integer.
		"""
		return reg.name in self.variant_regs and not is_number(reg.name) and not is_number(reg.name[1:])

	def dump(self):
		"""Print all enums, then bitsets, then regs/arrays, then the usage tables."""
		enums = []
		bitsets = []
		regs = []
		for e in self.file:
			if isinstance(e, Enum):
				enums.append(e)
			elif isinstance(e, Bitset):
				bitsets.append(e)
			else:
				regs.append(e)

		for e in enums + bitsets + regs:
			e.dump(self.has_variants(e))

		self.dump_reg_usages()


	def dump_regs_py(self):
		"""Print one Python constant line per top-level register."""
		for e in self.file:
			if isinstance(e, Reg):
				e.dump_py()


	def dump_reg_variants(self, regname, variants):
		"""Print the C++ template packer covering all variants of `regname`.

		`variants` maps a variant name or range to its Reg.  Nothing is
		printed for numeric register names.
		"""
		if is_number(regname) or is_number(regname[1:]):
			return
		print("#ifdef __cplusplus")
		print("struct __%s {" % regname)
		# TODO be more clever.. we should probably figure out which
		# fields have the same type in all variants (in which they
		# appear) and stuff everything else in a variant specific
		# sub-structure.
		seen_fields = set()
		bit_size = 32
		array = False
		address = None
		for variant, reg in variants.items():
			print("    /* %s fields: */" % variant)
			bit_size = reg.bit_size
			array = reg.array
			for f in reg.bitset.fields:
				# field_name() already returns a lower-case name.
				name = field_name(reg, f)
				if name in seen_fields:
					continue
				seen_fields.add(name)
				if f.type in [ "address", "waddress" ]:
					# Only the first address field gets bo/bo_offset.
					if address:
						continue
					address = f
					tab_to("    __bo_type", "bo;")
					tab_to("    uint32_t", "bo_offset;")
					continue
				ctype, val = f.ctype("var")
				tab_to("    %s" % ctype, "%s;" % name)
		print("    /* fallback fields: */")
		if bit_size == 64:
			tab_to("    uint64_t", "unknown;")
			tab_to("    uint64_t", "qword;")
		else:
			tab_to("    uint32_t", "unknown;")
			tab_to("    uint32_t", "dword;")
		print("};")
		# TODO don't hardcode the varset enum name
		varenum = "chip"
		print("template <%s %s>" % (varenum, varenum.upper()))
		print("static inline struct fd_reg_pair")
		xtra = ""
		xtravar = ""
		if array:
			xtra = "int __i, "
			xtravar = "__i, "
		print("__%s(%sstruct __%s fields) {" % (regname, xtra, regname))
		for variant, reg in variants.items():
			if "-" in variant:
				start, _, end = variant.partition("-")
				if end != "":
					print("  if ((%s >= %s) && (%s <= %s)) {" % (varenum.upper(), start, varenum.upper(), end))
				else:
					print("  if (%s >= %s) {" % (varenum.upper(), start))
			else:
				print("  if (%s == %s) {" % (varenum.upper(), variant))
			reg.dump_regpair_builder()
			print("  } else")
		print("    assert(!\"invalid variant\");")
		print("  return (struct fd_reg_pair){};")
		print("}")

		if bit_size == 64:
			skip = ", { .reg = 0 }"
		else:
			skip = ""

		print("#define %s(VARIANT, %s...) __%s<VARIANT>(%s{__VA_ARGS__})%s" % (regname, xtravar, regname, xtravar, skip))
		print("#endif /* __cplusplus */")

	def dump_structs(self):
		"""Print the pack structs of all top-level objects, then the variant packers."""
		for e in self.file:
			e.dump_pack_struct(self.has_variants(e))

		for regname in self.variant_regs:
			self.dump_reg_variants(regname, self.variant_regs[regname])
907
908
def dump_c(args, guard, func):
	"""Parse the XML named by `args` and print a complete C header to stdout.

	`guard` is the include-guard macro name and `func` is called with the
	populated Parser to print the header body between the common prologue
	and epilogue.  A parse Error is printed to stderr and the process
	exits with status 1.
	"""
	p = Parser()

	try:
		p.parse(args.rnn, args.xml, args.validate)
	except Error as e:
		print(e, file=sys.stderr)
		# sys.exit() rather than the exit() helper, which is only
		# injected by the `site` module and is not always available.
		sys.exit(1)

	print("#ifndef %s\n#define %s\n" % (guard, guard))

	print("/* Autogenerated file, DO NOT EDIT manually! */")

	# assert() maps to BUG_ON() in kernel builds.
	print()
	print("#ifdef __KERNEL__")
	print("#include <linux/bug.h>")
	print("#define assert(x) BUG_ON(!(x))")
	print("#else")
	print("#include <assert.h>")
	print("#endif")
	print()

	# __struct_cast expands to nothing under C++ and to a compound-literal
	# cast under C.
	print("#ifdef __cplusplus")
	print("#define __struct_cast(X)")
	print("#else")
	print("#define __struct_cast(X) (struct X)")
	print("#endif")
	print()

	print("#ifndef FD_NO_DEPRECATED_PACK")
	print("#define FD_DEPRECATED __attribute__((deprecated))")
	print("#else")
	print("#define FD_DEPRECATED")
	print("#endif")
	print()

	func(p)

	print()
	print("#undef FD_DEPRECATED")
	print()

	print("#endif /* %s */" % guard)
952
953
def dump_c_defines(args):
	"""Print the C header with register defines for `args.xml`.

	The include guard is the XML file's base name, upper-cased, with
	every "." replaced by "_".
	"""
	xml_basename = os.path.basename(args.xml)
	guard = xml_basename.replace('.', '_').upper()
	dump_c(args, guard, lambda parser: parser.dump())
957
958
def dump_c_pack_structs(args):
	"""Print the C header with pack structs for `args.xml`.

	The include guard is the XML file's base name, upper-cased, with
	every "." replaced by "_", followed by "_STRUCTS".
	"""
	xml_basename = os.path.basename(args.xml)
	guard = "%s_STRUCTS" % xml_basename.replace('.', '_').upper()
	dump_c(args, guard, lambda parser: parser.dump_structs())
962
963
def dump_py_defines(args):
	"""Parse the XML named by `args` and print a Python IntEnum of register offsets.

	The class is named after the XML file's base name without extension,
	upper-cased, with "Regs" appended.  A parse Error is printed to
	stderr and the process exits with status 1.
	"""
	p = Parser()

	try:
		p.parse(args.rnn, args.xml, args.validate)
	except Error as e:
		print(e, file=sys.stderr)
		# sys.exit() rather than the exit() helper, which is only
		# injected by the `site` module and is not always available.
		sys.exit(1)

	file_name = os.path.splitext(os.path.basename(args.xml))[0]

	print("from enum import IntEnum")
	print("class %sRegs(IntEnum):" % file_name.upper())

	p.dump_regs_py()
981
982
def main():
	"""Parse the command line and run the selected generator subcommand."""
	cli = argparse.ArgumentParser()
	cli.add_argument('--rnn', type=str, required=True)
	cli.add_argument('--xml', type=str, required=True)
	# --validate / --no-validate toggle the same flag; default is off.
	cli.add_argument('--validate', default=False, action='store_true')
	cli.add_argument('--no-validate', dest='validate', action='store_false')

	commands = cli.add_subparsers()
	commands.required = True

	# Subcommand name -> function that produces the output.
	handlers = (
		('c-defines', dump_c_defines),
		('c-pack-structs', dump_c_pack_structs),
		('py-defines', dump_py_defines),
	)
	for command, handler in handlers:
		commands.add_parser(command).set_defaults(func=handler)

	options = cli.parse_args()
	options.func(options)


if __name__ == '__main__':
	main()
1008