xref: /linux/tools/net/sunrpc/xdrgen/generators/program.py (revision 23b0f90ba871f096474e1c27c3d14f455189d2d9)
1#!/usr/bin/env python3
2# ex: set filetype=python:
3
4"""Generate code for an RPC program's procedures"""
5
6from jinja2 import Environment
7
8from generators import SourceGenerator, create_jinja2_environment, get_jinja2_template
9from xdr_ast import _RpcProgram, _RpcVersion, excluded_apis
10from xdr_ast import max_widths, get_header_name
11
12
def emit_version_definitions(
    environment: Environment, program: str, version: _RpcVersion
) -> None:
    """Print the procedure-number definitions for one RPC program version.

    Renders "definition/open.j2" with the upper-cased program name, then
    "definition/procedure.j2" once per procedure whose name is not in
    excluded_apis (passing the procedure's name and number), and finally
    "definition/close.j2". Every rendering is written to stdout with print().
    """
    opener = environment.get_template("definition/open.j2")
    print(opener.render(program=program.upper()))

    entry = environment.get_template("definition/procedure.j2")
    for proc in version.procedures:
        if proc.name in excluded_apis:
            continue
        print(entry.render(name=proc.name, value=proc.number))

    closer = environment.get_template("definition/close.j2")
    print(closer.render())
32
33
def emit_version_declarations(
    environment: Environment, program: str, version: _RpcVersion
) -> None:
    """Emit argument and result declarations for one RPC version's procedures

    Procedures whose name is in excluded_apis are ignored. Each distinct
    argument type name is rendered once, in first-seen order, through
    "declaration/argument.j2". The distinct result type names are then
    rendered the same way through "declaration/result.j2".

    A blank line is printed before each group that has at least one entry.
    An empty group prints nothing and its template is not loaded.
    """
    # dict.fromkeys() de-duplicates while preserving first-seen order, so a
    # type shared by several procedures is declared exactly once.
    arguments = dict.fromkeys(
        procedure.argument.type_name
        for procedure in version.procedures
        if procedure.name not in excluded_apis
    )
    if arguments:
        print("")
        template = environment.get_template("declaration/argument.j2")
        for argument in arguments:
            print(template.render(program=program, argument=argument))

    results = dict.fromkeys(
        procedure.result.type_name
        for procedure in version.procedures
        if procedure.name not in excluded_apis
    )
    if results:
        print("")
        template = environment.get_template("declaration/result.j2")
        for result in results:
            print(template.render(program=program, result=result))
57
58
def emit_version_argument_decoders(
    environment: Environment, program: str, version: _RpcVersion
) -> None:
    """Emit server argument decoders for one RPC version's procedures

    Renders "decoder/argument.j2" once per distinct argument type name, in
    first-seen order, skipping procedures whose name is in excluded_apis.
    The template is loaded even when there is nothing to render.
    """
    # dict.fromkeys() drops duplicate type names but keeps their order.
    arguments = dict.fromkeys(
        procedure.argument.type_name
        for procedure in version.procedures
        if procedure.name not in excluded_apis
    )

    template = environment.get_template("decoder/argument.j2")
    for argument in arguments:
        print(template.render(program=program, argument=argument))
71
72
def emit_version_result_decoders(
    environment: Environment, program: str, version: _RpcVersion
) -> None:
    """Emit client result decoders for one RPC version's procedures

    Renders "decoder/result.j2" once per distinct result type name, in
    first-seen order, skipping procedures whose name is in excluded_apis.
    The template is loaded even when there is nothing to render.
    """
    # dict.fromkeys() drops duplicate type names but keeps their order.
    results = dict.fromkeys(
        procedure.result.type_name
        for procedure in version.procedures
        if procedure.name not in excluded_apis
    )

    template = environment.get_template("decoder/result.j2")
    for result in results:
        print(template.render(program=program, result=result))
85
86
def emit_version_argument_encoders(
    environment: Environment, program: str, version: _RpcVersion
) -> None:
    """Emit client argument encoders for one RPC version's procedures

    Renders "encoder/argument.j2" once per distinct argument type name, in
    first-seen order, skipping procedures whose name is in excluded_apis.
    The template is loaded even when there is nothing to render.
    """
    # dict.fromkeys() drops duplicate type names but keeps their order.
    arguments = dict.fromkeys(
        procedure.argument.type_name
        for procedure in version.procedures
        if procedure.name not in excluded_apis
    )

    template = environment.get_template("encoder/argument.j2")
    for argument in arguments:
        print(template.render(program=program, argument=argument))
99
100
def emit_version_result_encoders(
    environment: Environment, program: str, version: _RpcVersion
) -> None:
    """Emit server result encoders for one RPC version's procedures

    Renders "encoder/result.j2" once per distinct result type name, in
    first-seen order, skipping procedures whose name is in excluded_apis.
    The template is loaded even when there is nothing to render.
    """
    # dict.fromkeys() drops duplicate type names but keeps their order.
    results = dict.fromkeys(
        procedure.result.type_name
        for procedure in version.procedures
        if procedure.name not in excluded_apis
    )

    template = environment.get_template("encoder/result.j2")
    for result in results:
        print(template.render(program=program, result=result))
113
114
115class XdrProgramGenerator(SourceGenerator):
116    """Generate source code for an RPC program's procedures"""
117
118    def __init__(self, language: str, peer: str):
119        """Initialize an instance of this class"""
120        self.environment = create_jinja2_environment(language, "program")
121        self.peer = peer
122
123    def emit_definition(self, node: _RpcProgram) -> None:
124        """Emit procedure numbers for each of an RPC programs's procedures"""
125        raw_name = node.name
126        program = raw_name.lower().removesuffix("_program").removesuffix("_prog")
127
128        for version in node.versions:
129            emit_version_definitions(self.environment, program, version)
130
131        template = self.environment.get_template("definition/program.j2")
132        print(template.render(name=raw_name, value=node.number))
133
134    def emit_declaration(self, node: _RpcProgram) -> None:
135        """Emit a declaration pair for each of an RPC programs's procedures"""
136        raw_name = node.name
137        program = raw_name.lower().removesuffix("_program").removesuffix("_prog")
138
139        for version in node.versions:
140            emit_version_declarations(self.environment, program, version)
141
142    def emit_decoder(self, node: _RpcProgram) -> None:
143        """Emit all decoder functions for an RPC program's procedures"""
144        raw_name = node.name
145        program = raw_name.lower().removesuffix("_program").removesuffix("_prog")
146        match self.peer:
147            case "server":
148                for version in node.versions:
149                    emit_version_argument_decoders(
150                        self.environment, program, version,
151                    )
152            case "client":
153                for version in node.versions:
154                    emit_version_result_decoders(
155                        self.environment, program, version,
156                    )
157
158    def emit_encoder(self, node: _RpcProgram) -> None:
159        """Emit all encoder functions for an RPC program's procedures"""
160        raw_name = node.name
161        program = raw_name.lower().removesuffix("_program").removesuffix("_prog")
162        match self.peer:
163            case "server":
164                for version in node.versions:
165                    emit_version_result_encoders(
166                        self.environment, program, version,
167                    )
168            case "client":
169                for version in node.versions:
170                    emit_version_argument_encoders(
171                        self.environment, program, version,
172                    )
173
174    def emit_maxsize(self, node: _RpcProgram) -> None:
175        """Emit maxsize macro for maximum RPC argument size"""
176        header = get_header_name().upper()
177
178        # Find the largest argument across all versions
179        max_arg_width = 0
180        max_arg_name = None
181        for version in node.versions:
182            for procedure in version.procedures:
183                if procedure.name in excluded_apis:
184                    continue
185                arg_name = procedure.argument.type_name
186                if arg_name == "void":
187                    continue
188                if arg_name not in max_widths:
189                    continue
190                if max_widths[arg_name] > max_arg_width:
191                    max_arg_width = max_widths[arg_name]
192                    max_arg_name = arg_name
193
194        if max_arg_name is None:
195            return
196
197        macro_name = header + "_MAX_ARGS_SZ"
198        template = get_jinja2_template(self.environment, "maxsize", "max_args")
199        print(
200            template.render(
201                macro=macro_name,
202                width=header + "_" + max_arg_name + "_sz",
203            )
204        )
205