xref: /linux/tools/net/sunrpc/xdrgen/xdr_ast.py (revision cdd30ebb1b9f36159d66f088b61aee264e649d7a)
1#!/usr/bin/env python3
2# ex: set filetype=python:
3
4"""Define and implement the Abstract Syntax Tree for the XDR language."""
5
6import sys
7from typing import List
8from dataclasses import dataclass
9
10from lark import ast_utils, Transformer
11from lark.tree import Meta
12
# Handle on this module; handed to lark's create_transformer() below
this_module = sys.modules[__name__]

# Settings gathered from pragma directives in the input specification
header_name = "none"
big_endian = []
excluded_apis = []
public_apis = []

# Type names recorded as C structs, and those passed by reference
structs = set()
pass_by_reference = set()

# Constant and enumerator names mapped to their integer values
constants = {}
23
24
25def xdr_quadlen(val: str) -> int:
26    """Return integer XDR width of an XDR type"""
27    if val in constants:
28        octets = constants[val]
29    else:
30        octets = int(val)
31    return int((octets + 3) / 4)
32
33
# Symbolic (C macro) XDR widths, keyed by type_name. Each value is a
# list of width terms; the built-in types each have a single
# "XDR_<type>" term.
symbolic_widths = {
    builtin: ["XDR_" + builtin]
    for builtin in (
        "void",
        "bool",
        "int",
        "unsigned_int",
        "long",
        "unsigned_long",
        "hyper",
        "unsigned_hyper",
    )
}

# Numeric XDR widths (in XDR units) live in a dictionary keyed by
# type_name because sometimes a caller has only the type_name with
# which to look up the numeric width.
max_widths = dict(
    void=0,
    bool=1,
    int=1,
    unsigned_int=1,
    long=1,
    unsigned_long=1,
    hyper=2,
    unsigned_hyper=2,
)
58
59
@dataclass
class _XdrAst(ast_utils.Ast):
    """Common ancestor of every XDR abstract syntax tree node

    Derives from lark's ast_utils.Ast so the nodes integrate with the
    lark AST utilities used by this module.
    """
63
64
@dataclass
class _XdrIdentifier(_XdrAst):
    """AST node for the grammar's 'identifier' production"""

    symbol: str  # the identifier's text as it appears in the source
70
71
@dataclass
class _XdrValue(_XdrAst):
    """AST node for the grammar's 'value' production

    Holds either an identifier's name or the source text of a numeric
    literal; ParseToAst.value() supplies both forms as strings.
    """

    value: str
77
78
@dataclass
class _XdrConstantValue(_XdrAst):
    """AST node for the grammar's 'constant' production"""

    value: int  # literal already converted using its radix
84
85
@dataclass
class _XdrTypeSpecifier(_XdrAst):
    """AST node for the grammar's 'type_specifier' production"""

    type_name: str
    # Prefix for the C spelling of the type; _XdrDefinedType sets it to
    # "struct " for types recorded as structs
    c_classifier: str = ""
92
93
@dataclass
class _XdrDefinedType(_XdrTypeSpecifier):
    """A type specifier naming a type declared in the input specification"""

    def symbolic_width(self) -> List:
        """Return a one-element list naming this type's "_sz" width macro"""
        macro = get_header_name().upper() + "_" + self.type_name + "_sz"
        return [macro]

    def __post_init__(self):
        # A reference to a known struct type needs the "struct " keyword
        if self.type_name in structs:
            self.c_classifier = "struct "
        # Later width lookups for this type resolve to the macro name
        symbolic_widths[self.type_name] = self.symbolic_width()
106
107
@dataclass
class _XdrBuiltInType(_XdrTypeSpecifier):
    """A type specifier naming one of XDR's built-in types"""

    def symbolic_width(self) -> List:
        """Look up the width terms registered for this built-in type"""
        terms = symbolic_widths[self.type_name]
        return terms
115
116
@dataclass
class _XdrDeclaration(_XdrAst):
    """Common ancestor of the XDR type declaration nodes"""
120
121
@dataclass
class _XdrFixedLengthOpaque(_XdrDeclaration):
    """Declaration of an opaque field holding a fixed number of octets"""

    name: str
    size: str  # octet count: an integer literal or a constant's name
    template: str = "fixed_length_opaque"

    def max_width(self) -> int:
        """Width in XDR_UNITS: the octet count rounded up to whole units"""
        return xdr_quadlen(self.size)

    def symbolic_width(self) -> List:
        """Single width term expressed with the XDR_QUADLEN macro"""
        term = "XDR_QUADLEN(" + self.size + ")"
        return [term]

    def __post_init__(self):
        # Record both widths under this declaration's name
        max_widths[self.name] = self.max_width()
        symbolic_widths[self.name] = self.symbolic_width()
141
142
@dataclass
class _XdrVariableLengthOpaque(_XdrDeclaration):
    """Declaration of an opaque field with a variable octet count"""

    name: str
    # Maximum octet count; the transformer uses "0" when none was given
    maxsize: str
    template: str = "variable_length_opaque"

    def max_width(self) -> int:
        """Width in XDR_UNITS: length word plus the maximum payload"""
        return 1 + xdr_quadlen(self.maxsize)

    def symbolic_width(self) -> List:
        """Length word term, plus a payload term when a maximum is known"""
        if self.maxsize == "0":
            return ["XDR_unsigned_int"]
        return ["XDR_unsigned_int", "XDR_QUADLEN(" + self.maxsize + ")"]

    def __post_init__(self):
        # Record both widths under this declaration's name
        max_widths[self.name] = self.max_width()
        symbolic_widths[self.name] = self.symbolic_width()
165
166
@dataclass
class _XdrString(_XdrDeclaration):
    """Declaration of a (NUL-terminated) variable-length string"""

    name: str
    # Maximum octet count; the transformer uses "0" when none was given
    maxsize: str
    template: str = "string"

    def max_width(self) -> int:
        """Width in XDR_UNITS: length word plus the maximum payload"""
        return 1 + xdr_quadlen(self.maxsize)

    def symbolic_width(self) -> List:
        """Length word term, plus a payload term when a maximum is known"""
        if self.maxsize == "0":
            return ["XDR_unsigned_int"]
        return ["XDR_unsigned_int", "XDR_QUADLEN(" + self.maxsize + ")"]

    def __post_init__(self):
        # Record both widths under this declaration's name
        max_widths[self.name] = self.max_width()
        symbolic_widths[self.name] = self.symbolic_width()
189
190
@dataclass
class _XdrFixedLengthArray(_XdrDeclaration):
    """A fixed-length array declaration"""

    name: str
    spec: _XdrTypeSpecifier
    size: str  # element count: an integer literal or a constant's name
    template: str = "fixed_length_array"

    def max_width(self) -> int:
        """Return width of type in XDR_UNITS

        The array holds exactly @size elements, each as wide as the
        element type.
        """
        # @size counts elements, not octets, so it must not be rounded
        # through xdr_quadlen(); this keeps max_width() in agreement
        # with symbolic_width() below.
        size = self.size
        if size in constants:
            count = constants[size]
        else:
            # Decimal, "0x" hexadecimal, or leading-zero (RFC 4506) octal
            octal = len(size) > 1 and size[0] == "0" and size.isdigit()
            count = int(size, 8 if octal else 0)
        return count * max_widths[self.spec.type_name]

    def symbolic_width(self) -> List:
        """Return list containing XDR width of type's components"""
        item_width = " + ".join(symbolic_widths[self.spec.type_name])
        return ["(" + self.size + " * (" + item_width + "))"]

    def __post_init__(self):
        max_widths[self.name] = self.max_width()
        symbolic_widths[self.name] = self.symbolic_width()
212
213
@dataclass
class _XdrVariableLengthArray(_XdrDeclaration):
    """A variable-length array declaration"""

    name: str
    spec: _XdrTypeSpecifier
    # Maximum element count; the transformer uses "0" when none was given
    maxsize: str
    template: str = "variable_length_array"

    def max_width(self) -> int:
        """Return width of type in XDR_UNITS

        One unit for the element count, plus @maxsize elements, each as
        wide as the element type.
        """
        # @maxsize counts elements, not octets, so it must not be rounded
        # through xdr_quadlen(); this keeps max_width() in agreement
        # with symbolic_width() below.
        maxsize = self.maxsize
        if maxsize in constants:
            count = constants[maxsize]
        else:
            # Decimal, "0x" hexadecimal, or leading-zero (RFC 4506) octal
            octal = len(maxsize) > 1 and maxsize[0] == "0" and maxsize.isdigit()
            count = int(maxsize, 8 if octal else 0)
        return 1 + (count * max_widths[self.spec.type_name])

    def symbolic_width(self) -> List:
        """Return list containing XDR width of type's components"""
        widths = ["XDR_unsigned_int"]
        if self.maxsize != "0":
            item_width = " + ".join(symbolic_widths[self.spec.type_name])
            widths.append("(" + self.maxsize + " * (" + item_width + "))")
        return widths

    def __post_init__(self):
        max_widths[self.name] = self.max_width()
        symbolic_widths[self.name] = self.symbolic_width()
238
239
@dataclass
class _XdrOptionalData(_XdrDeclaration):
    """Declaration of optional data ('type *name')"""

    name: str
    spec: _XdrTypeSpecifier
    template: str = "optional_data"

    def max_width(self) -> int:
        """Width in XDR_UNITS of the presence flag"""
        # NOTE(review): only the flag is counted, not the optional
        # payload -- confirm this is intended for non-recursive types
        return 1

    def symbolic_width(self) -> List:
        """Single width term for the presence flag"""
        return ["XDR_bool"]

    def __post_init__(self):
        # The declared name is recorded as a struct passed by reference
        for registry in (structs, pass_by_reference):
            registry.add(self.name)
        max_widths[self.name] = self.max_width()
        symbolic_widths[self.name] = self.symbolic_width()
261
262
@dataclass
class _XdrBasic(_XdrDeclaration):
    """Declaration of a single value of some type ('type name')"""

    name: str
    spec: _XdrTypeSpecifier
    template: str = "basic"

    def max_width(self) -> int:
        """Width in XDR_UNITS, taken from the declared type"""
        type_name = self.spec.type_name
        return max_widths[type_name]

    def symbolic_width(self) -> List:
        """Width terms, taken from the declared type"""
        type_name = self.spec.type_name
        return symbolic_widths[type_name]

    def __post_init__(self):
        # The declared name inherits the type's widths
        max_widths[self.name] = self.max_width()
        symbolic_widths[self.name] = self.symbolic_width()
282
283
@dataclass
class _XdrVoid(_XdrDeclaration):
    """Declaration that carries no data ('void')"""

    name: str = "void"
    template: str = "void"

    def max_width(self) -> int:
        """A void declaration occupies no XDR_UNITS"""
        return 0

    def symbolic_width(self) -> List:
        """A void declaration contributes no width terms"""
        return []
298
299
@dataclass
class _XdrConstant(_XdrAst):
    """Corresponds to 'constant_def' in the grammar"""

    name: str
    value: str  # integer literal text, or the name of another constant

    def __post_init__(self):
        # Record the integer value so later declarations can use this
        # constant's name as a size.
        if self.value in constants:
            # "const B = A;" -- B takes the value already recorded for A.
            # (Previously B was silently left out of the table.)
            constants[self.name] = constants[self.value]
            return
        literal = self.value
        # RFC 4506 octal literals carry a leading zero, which int(x, 0)
        # rejects; decimal and "0x" hexadecimal go through int(x, 0)
        octal = len(literal) > 1 and literal[0] == "0" and literal.isdigit()
        constants[self.name] = int(literal, 8 if octal else 0)
310
311
@dataclass
class _XdrEnumerator(_XdrAst):
    """An 'identifier = value' enumerator"""

    name: str
    value: str  # integer literal text, or the name of a known constant

    def __post_init__(self):
        # Enumerators share the constants table so their names can be
        # used wherever a constant is expected.
        if self.value in constants:
            # "FOO = BAR" -- FOO takes the value already recorded for BAR
            constants[self.name] = constants[self.value]
            return
        literal = self.value
        # RFC 4506 octal literals carry a leading zero, which int(x, 0)
        # rejects; decimal and "0x" hexadecimal go through int(x, 0)
        octal = len(literal) > 1 and literal[0] == "0" and literal.isdigit()
        constants[self.name] = int(literal, 8 if octal else 0)
322
323
@dataclass
class _XdrEnum(_XdrAst):
    """Definition of an XDR enumerated type"""

    name: str
    minimum: int  # ParseToAst.enum() currently always passes 0
    maximum: int  # ParseToAst.enum() currently always passes 0
    enumerators: List[_XdrEnumerator]

    def max_width(self) -> int:
        """An enum value occupies a single XDR unit"""
        return 1

    def symbolic_width(self) -> List:
        """An enum value is encoded like an int"""
        return ["XDR_int"]

    def __post_init__(self):
        # Register the enum's widths under its type name
        max_widths[self.name] = self.max_width()
        symbolic_widths[self.name] = self.symbolic_width()
344
345
@dataclass
class _XdrStruct(_XdrAst):
    """Definition of an XDR struct"""

    name: str
    fields: List[_XdrDeclaration]

    def max_width(self) -> int:
        """Width in XDR_UNITS: the sum of every field's width"""
        return sum(field.max_width() for field in self.fields)

    def symbolic_width(self) -> List:
        """All fields' width terms, concatenated in declaration order"""
        return [term for field in self.fields for term in field.symbolic_width()]

    def __post_init__(self):
        # Structs are referenced as "struct <name>" and passed by reference
        structs.add(self.name)
        pass_by_reference.add(self.name)
        max_widths[self.name] = self.max_width()
        symbolic_widths[self.name] = self.symbolic_width()
372
373
@dataclass
class _XdrPointer(_XdrAst):
    """Definition of a self-referencing struct (a linked-list node)

    ParseToAst.struct() builds one of these when the struct's last field
    is optional data of the struct's own type.
    """

    name: str
    fields: List[_XdrDeclaration]

    def max_width(self) -> int:
        """Width in XDR_UNITS of one node

        One unit for the presence flag, plus every field except the
        trailing self-reference.
        """
        payload = self.fields[:-1]
        return 1 + sum(field.max_width() for field in payload)

    def symbolic_width(self) -> List:
        """Presence-flag term followed by the non-link fields' terms"""
        payload = self.fields[:-1]
        return ["XDR_bool"] + [
            term for field in payload for term in field.symbolic_width()
        ]

    def __post_init__(self):
        structs.add(self.name)
        pass_by_reference.add(self.name)
        max_widths[self.name] = self.max_width()
        symbolic_widths[self.name] = self.symbolic_width()
401
402
@dataclass
class _XdrTypedef(_XdrAst):
    """An XDR typedef wrapping a single declaration"""

    declaration: _XdrDeclaration

    def max_width(self) -> int:
        """Width in XDR_UNITS of the wrapped declaration"""
        return self.declaration.max_width()

    def symbolic_width(self) -> List:
        """Width terms of the wrapped declaration"""
        return self.declaration.symbolic_width()

    def __post_init__(self):
        # Only "typedef <defined type> <new name>;" needs extra handling
        decl = self.declaration
        if not isinstance(decl, _XdrBasic):
            return
        if not isinstance(decl.spec, _XdrDefinedType):
            return
        # An alias of a by-reference type is itself passed by reference
        if decl.spec.type_name in pass_by_reference:
            pass_by_reference.add(decl.name)
        max_widths[decl.name] = self.max_width()
        symbolic_widths[decl.name] = self.symbolic_width()
425
426
@dataclass
class _XdrCaseSpec(_XdrAst):
    """A single 'case' arm of an XDR union"""

    values: List[str]  # every case label that selects this arm
    arm: _XdrDeclaration  # what is encoded when one of the labels matches
    template: str = "case_spec"
434
435
@dataclass
class _XdrDefaultSpec(_XdrAst):
    """The 'default' arm of an XDR union"""

    arm: _XdrDeclaration  # encoded when no case label matches
    template: str = "default_spec"
442
443
@dataclass
class _XdrUnion(_XdrAst):
    """An XDR union"""

    name: str
    discriminant: _XdrDeclaration
    cases: List[_XdrCaseSpec]
    default: _XdrDeclaration

    def max_width(self) -> int:
        """Return width of type in XDR_UNITS

        One unit for the discriminant plus the width of the widest arm.
        """
        widest = 0
        for case in self.cases:
            widest = max(widest, case.arm.max_width())
        if self.default:
            widest = max(widest, self.default.arm.max_width())
        return 1 + widest

    def symbolic_width(self) -> List:
        """Return list containing XDR width of type's components

        The discriminant's terms followed by the terms of the widest arm
        (the first one encountered, on ties).
        """
        widest = 0
        # Stays empty when no arm is wider than zero (e.g. all arms are
        # void); previously that case raised UnboundLocalError.
        width = []
        for case in self.cases:
            if case.arm.max_width() > widest:
                widest = case.arm.max_width()
                width = case.arm.symbolic_width()
        if self.default:
            if self.default.arm.max_width() > widest:
                widest = self.default.arm.max_width()
                width = self.default.arm.symbolic_width()
        return symbolic_widths[self.discriminant.name] + width

    def __post_init__(self):
        structs.add(self.name)
        pass_by_reference.add(self.name)
        max_widths[self.name] = self.max_width()
        symbolic_widths[self.name] = self.symbolic_width()
482
483
@dataclass
class _RpcProcedure(_XdrAst):
    """One procedure within an RPC program version"""

    name: str
    number: str  # procedure number as written in the source
    argument: _XdrTypeSpecifier  # type of the procedure's argument
    result: _XdrTypeSpecifier  # type of the procedure's result
492
493
@dataclass
class _RpcVersion(_XdrAst):
    """One version of an RPC program and the procedures it offers"""

    name: str
    number: str  # version number as written in the source
    procedures: List[_RpcProcedure]
501
502
@dataclass
class _RpcProgram(_XdrAst):
    """An RPC program and the versions it defines"""

    name: str
    number: str  # program number as written in the source
    versions: List[_RpcVersion]
510
511
@dataclass
class _Pragma(_XdrAst):
    """Placeholder node returned for a pragma directive

    ParseToAst.pragma_def() applies each directive's effect to module
    state; this node itself carries no data.
    """
515
516
@dataclass
class Definition(_XdrAst, ast_utils.WithMeta):
    """AST node for the grammar's 'definition' production"""

    meta: Meta  # source position information supplied by lark
    value: _XdrAst  # the definition's transformed contents
523
524
@dataclass
class Specification(_XdrAst, ast_utils.AsList):
    """AST node for the grammar's 'specification' production (the root)"""

    definitions: List[Definition]  # every top-level definition, in order
530
531
532class ParseToAst(Transformer):
533    """Functions that transform productions into AST nodes"""
534
535    def identifier(self, children):
536        """Instantiate one _XdrIdentifier object"""
537        return _XdrIdentifier(children[0].value)
538
539    def value(self, children):
540        """Instantiate one _XdrValue object"""
541        if isinstance(children[0], _XdrIdentifier):
542            return _XdrValue(children[0].symbol)
543        return _XdrValue(children[0].children[0].value)
544
545    def constant(self, children):
546        """Instantiate one _XdrConstantValue object"""
547        match children[0].data:
548            case "decimal_constant":
549                value = int(children[0].children[0].value, base=10)
550            case "hexadecimal_constant":
551                value = int(children[0].children[0].value, base=16)
552            case "octal_constant":
553                value = int(children[0].children[0].value, base=8)
554        return _XdrConstantValue(value)
555
556    def type_specifier(self, children):
557        """Instantiate one _XdrTypeSpecifier object"""
558        if isinstance(children[0], _XdrIdentifier):
559            name = children[0].symbol
560            return _XdrDefinedType(type_name=name)
561
562        name = children[0].data.value
563        return _XdrBuiltInType(type_name=name)
564
565    def constant_def(self, children):
566        """Instantiate one _XdrConstant object"""
567        name = children[0].symbol
568        value = children[1].value
569        return _XdrConstant(name, value)
570
571    # cel: Python can compute a min() and max() for the enumerator values
572    #      so that the generated code can perform proper range checking.
573    def enum(self, children):
574        """Instantiate one _XdrEnum object"""
575        enum_name = children[0].symbol
576
577        i = 0
578        enumerators = []
579        body = children[1]
580        while i < len(body.children):
581            name = body.children[i].symbol
582            value = body.children[i + 1].value
583            enumerators.append(_XdrEnumerator(name, value))
584            i = i + 2
585
586        return _XdrEnum(enum_name, 0, 0, enumerators)
587
588    def fixed_length_opaque(self, children):
589        """Instantiate one _XdrFixedLengthOpaque declaration object"""
590        name = children[0].symbol
591        size = children[1].value
592
593        return _XdrFixedLengthOpaque(name, size)
594
595    def variable_length_opaque(self, children):
596        """Instantiate one _XdrVariableLengthOpaque declaration object"""
597        name = children[0].symbol
598        if children[1] is not None:
599            maxsize = children[1].value
600        else:
601            maxsize = "0"
602
603        return _XdrVariableLengthOpaque(name, maxsize)
604
605    def string(self, children):
606        """Instantiate one _XdrString declaration object"""
607        name = children[0].symbol
608        if children[1] is not None:
609            maxsize = children[1].value
610        else:
611            maxsize = "0"
612
613        return _XdrString(name, maxsize)
614
615    def fixed_length_array(self, children):
616        """Instantiate one _XdrFixedLengthArray declaration object"""
617        spec = children[0]
618        name = children[1].symbol
619        size = children[2].value
620
621        return _XdrFixedLengthArray(name, spec, size)
622
623    def variable_length_array(self, children):
624        """Instantiate one _XdrVariableLengthArray declaration object"""
625        spec = children[0]
626        name = children[1].symbol
627        if children[2] is not None:
628            maxsize = children[2].value
629        else:
630            maxsize = "0"
631
632        return _XdrVariableLengthArray(name, spec, maxsize)
633
634    def optional_data(self, children):
635        """Instantiate one _XdrOptionalData declaration object"""
636        spec = children[0]
637        name = children[1].symbol
638
639        return _XdrOptionalData(name, spec)
640
641    def basic(self, children):
642        """Instantiate one _XdrBasic object"""
643        spec = children[0]
644        name = children[1].symbol
645
646        return _XdrBasic(name, spec)
647
648    def void(self, children):
649        """Instantiate one _XdrVoid declaration object"""
650
651        return _XdrVoid()
652
653    def struct(self, children):
654        """Instantiate one _XdrStruct object"""
655        name = children[0].symbol
656        fields = children[1].children
657
658        last_field = fields[-1]
659        if (
660            isinstance(last_field, _XdrOptionalData)
661            and name == last_field.spec.type_name
662        ):
663            return _XdrPointer(name, fields)
664
665        return _XdrStruct(name, fields)
666
667    def typedef(self, children):
668        """Instantiate one _XdrTypedef object"""
669        new_type = children[0]
670
671        return _XdrTypedef(new_type)
672
673    def case_spec(self, children):
674        """Instantiate one _XdrCaseSpec object"""
675        values = []
676        for item in children[0:-1]:
677            values.append(item.value)
678        arm = children[-1]
679
680        return _XdrCaseSpec(values, arm)
681
682    def default_spec(self, children):
683        """Instantiate one _XdrDefaultSpec object"""
684        arm = children[0]
685
686        return _XdrDefaultSpec(arm)
687
688    def union(self, children):
689        """Instantiate one _XdrUnion object"""
690        name = children[0].symbol
691
692        body = children[1]
693        discriminant = body.children[0].children[0]
694        cases = body.children[1:-1]
695        default = body.children[-1]
696
697        return _XdrUnion(name, discriminant, cases, default)
698
699    def procedure_def(self, children):
700        """Instantiate one _RpcProcedure object"""
701        result = children[0]
702        name = children[1].symbol
703        argument = children[2]
704        number = children[3].value
705
706        return _RpcProcedure(name, number, argument, result)
707
708    def version_def(self, children):
709        """Instantiate one _RpcVersion object"""
710        name = children[0].symbol
711        number = children[-1].value
712        procedures = children[1:-1]
713
714        return _RpcVersion(name, number, procedures)
715
716    def program_def(self, children):
717        """Instantiate one _RpcProgram object"""
718        name = children[0].symbol
719        number = children[-1].value
720        versions = children[1:-1]
721
722        return _RpcProgram(name, number, versions)
723
724    def pragma_def(self, children):
725        """Instantiate one _Pragma object"""
726        directive = children[0].children[0].data
727        match directive:
728            case "big_endian_directive":
729                big_endian.append(children[1].symbol)
730            case "exclude_directive":
731                excluded_apis.append(children[1].symbol)
732            case "header_directive":
733                global header_name
734                header_name = children[1].symbol
735            case "public_directive":
736                public_apis.append(children[1].symbol)
737            case _:
738                raise NotImplementedError("Directive not supported")
739        return _Pragma()
740
741
# Combine the AST classes found in this module with the production
# handlers in ParseToAst into a single lark transformer.
transformer = ast_utils.create_transformer(
    this_module,
    ParseToAst(),
)
743
744
def transform_parse_tree(parse_tree):
    """Convert a lark parse tree into a tree of XDR AST nodes"""
    result = transformer.transform(parse_tree)
    return result
749
750
def get_header_name() -> str:
    """Return the name set by a header pragma ("none" until one is seen)"""
    name = header_name
    return name
754