xref: /freebsd/crypto/openssl/test/mldsa_wycheproof_parse.py (revision e7be843b4a162e68651d3911f0357ed464915629)
1#!/usr/bin/env python
2# Copyright 2025 The OpenSSL Project Authors. All Rights Reserved.
3#
4# Licensed under the Apache License 2.0 (the "License").  You may not use
5# this file except in compliance with the License.  You can obtain a copy
6# in the file LICENSE in the source distribution or at
7# https://www.openssl.org/source/license.html
8
# A Python program written to parse version 1 of the Wycheproof test vectors
# for ML-DSA. The 6 files that can be processed by this utility can be
# downloaded from
#  https://github.com/C2SP/wycheproof/blob/8e7fa6f87e6993d7b613cf48b46512a32df8084a/testvectors_v1/mldsa_*_standard_*_test.json
# and the output of this utility (written to stdout) is redirected to
14# test/recipes/30-test_evp_data/evppkey_ml_dsa_44_wycheproof_sign.txt
15# test/recipes/30-test_evp_data/evppkey_ml_dsa_65_wycheproof_sign.txt
16# test/recipes/30-test_evp_data/evppkey_ml_dsa_87_wycheproof_sign.txt
17# test/recipes/30-test_evp_data/evppkey_ml_dsa_44_wycheproof_verify.txt
18# test/recipes/30-test_evp_data/evppkey_ml_dsa_65_wycheproof_verify.txt
19# test/recipes/30-test_evp_data/evppkey_ml_dsa_87_wycheproof_verify.txt
20#
21# e.g. python3 ./test/mldsa_wycheproof_parse.py -alg ML-DSA-44 ./wycheproof/testvectors_v1/mldsa_44_standard_sign_test.json > test/recipes/30-test_evp_data/evppkey_ml_dsa_44_wycheproof_sign.txt
22
import argparse
import datetime
import json
import sys
# NOTE(review): `Or` is not referenced anywhere in this script; it looks like a
# stray IDE auto-import from the private `_ast` module - candidate for removal.
from _ast import Or
27
def print_label(label, value):
    """Write a single ``label = value`` line to stdout.

    Both arguments must be strings; a non-string raises ``TypeError``.
    """
    line = " = ".join((label, value))
    print(line)
30
def print_hexlabel(label, tag, value):
    """Write a ``label = hex<tag>:<value>`` line to stdout.

    Example: ``print_hexlabel("Ctrl", "priv", "00ff")`` emits
    ``Ctrl = hexpriv:00ff``. The callers pass hex strings taken from the
    Wycheproof JSON (keys, context strings) as ``value``. All three arguments
    must be strings; a non-string raises ``TypeError``.
    """
    pieces = (label, " = hex", tag, ":", value)
    print("".join(pieces))
33
def parse_ml_dsa_sig_gen(alg, groups):
    """Print evp_test stanzas for a Wycheproof ML-DSA signature-generation file.

    Output is written to stdout, one block per test group.

    * Normally the group's private key is declared once as a ``PrivateKeyRaw``
      named ``<alg with '-' -> '_'>_<group number>``. Each test case then
      becomes a deterministic ``Sign-Message`` stanza using that key.
    * If the group's *first* test is flagged ``IncorrectPrivateKeyLength`` or
      ``InvalidPrivateKey``, no key is declared. Instead, every test in the
      group becomes a ``KeyFromData`` stanza expected to fail with
      ``KEY_FROMDATA_ERROR``.
    * Sign tests whose ``result`` is ``"invalid"`` are expected to fail with
      ``PKEY_CTRL_ERROR``. Any other result adds no ``Result`` line.

    :param alg: OpenSSL algorithm name given via ``-alg``, e.g. "ML-DSA-44".
    :param groups: the ``testGroups`` list from the Wycheproof JSON document.
    """
    name = alg.replace('-', '_')
    # Group numbers count every group (including empty ones), starting at 1.
    for grp_id, grp in enumerate(groups, start=1):
        tests = grp['tests']
        if not tests:
            continue
        keyname = name + "_" + str(grp_id)

        # Key validity is judged from the first test's flags only; the verdict
        # then applies to every test in the group.
        first_flags = tests[0].get('flags', ())
        key_only = ('IncorrectPrivateKeyLength' in first_flags
                    or 'InvalidPrivateKey' in first_flags)
        if not key_only:
            print("")
            print_label("PrivateKeyRaw", keyname + ":" + alg + ":" + grp['privateKey'])

        for tst in tests:
            print("\n# " + str(tst['tcId']) + " " + tst['comment'])
            print_label("FIPSversion", ">=3.5.0")

            if key_only:
                # The key itself must be rejected when it is imported.
                print_label("KeyFromData", alg)
                print_hexlabel("Ctrl", "priv", grp['privateKey'])
                print_label("Result", "KEY_FROMDATA_ERROR")
                continue

            print_label("Sign-Message", alg + ":" + keyname)
            print_label("Input", tst['msg'])
            print_label("Output", tst['sig'])
            if 'ctx' in tst:
                print_hexlabel("Ctrl", "context-string", tst['ctx'])
            print_label("Ctrl", "message-encoding:1")
            # Deterministic signing lets the produced signature be compared
            # byte-for-byte against the expected 'sig'.
            print_label("Ctrl", "deterministic:1")
            if tst['result'] == "invalid":
                # NOTE(review): presumably the rejection happens while the
                # controls (e.g. an over-long context string) are applied - confirm.
                print_label("Result", "PKEY_CTRL_ERROR")
70
def parse_ml_dsa_sig_ver(alg, groups):
    """Print evp_test stanzas for a Wycheproof ML-DSA signature-verification file.

    Output is written to stdout, one block per test group.

    * Normally the group's public key is declared once as a ``PublicKeyRaw``
      named ``<alg with '-' -> '_'>_<group number>``. Each test case then
      becomes a ``Verify-Message-Public`` stanza using that key.
    * If the group's *first* test is flagged ``IncorrectPublicKeyLength`` or
      ``InvalidPublicKey``, no key is declared. Instead, every test in the
      group becomes a ``KeyFromData`` stanza expected to fail with
      ``KEY_FROMDATA_ERROR``.
    * Verify tests whose ``result`` is ``"invalid"`` are expected to fail with
      ``PKEY_CTRL_ERROR`` when flagged ``InvalidContext`` and with
      ``VERIFY_ERROR`` otherwise. Any other result adds no ``Result`` line.

    :param alg: OpenSSL algorithm name given via ``-alg``, e.g. "ML-DSA-44".
    :param groups: the ``testGroups`` list from the Wycheproof JSON document.
    """
    name = alg.replace('-', '_')
    # Group numbers count every group (including empty ones), starting at 1.
    for grp_id, grp in enumerate(groups, start=1):
        tests = grp['tests']
        if not tests:
            continue
        keyname = name + "_" + str(grp_id)

        # Key validity is judged from the first test's flags only; the verdict
        # then applies to every test in the group.
        first_flags = tests[0].get('flags', ())
        key_only = ('IncorrectPublicKeyLength' in first_flags
                    or 'InvalidPublicKey' in first_flags)
        if not key_only:
            print("")
            print_label("PublicKeyRaw", keyname + ":" + alg + ":" + grp['publicKey'])

        for tst in tests:
            print("\n# " + str(tst['tcId']) + " " + tst['comment'])
            print_label("FIPSversion", ">=3.5.0")

            if key_only:
                # The key itself must be rejected when it is imported.
                print_label("KeyFromData", alg)
                print_hexlabel("Ctrl", "pub", grp['publicKey'])
                print_label("Result", "KEY_FROMDATA_ERROR")
                continue

            print_label("Verify-Message-Public", alg + ":" + keyname)
            print_label("Input", tst['msg'])
            print_label("Output", tst['sig'])
            if 'ctx' in tst:
                print_hexlabel("Ctrl", "context-string", tst['ctx'])
            print_label("Ctrl", "message-encoding:1")
            # NOTE(review): emitted for parity with the sign output; presumably
            # has no effect on verification - confirm before dropping it.
            print_label("Ctrl", "deterministic:1")
            if tst['result'] == "invalid":
                # Not every test is guaranteed to carry a 'flags' key. An
                # invalid test without flags is a plain verification failure.
                if 'InvalidContext' in tst.get('flags', ()):
                    # Presumably the context-string control is rejected
                    # before verification runs - confirm.
                    print_label("Result", "PKEY_CTRL_ERROR")
                else:
                    print_label("Result", "VERIFY_ERROR")
110
def main():
    """Convert one Wycheproof ML-DSA JSON file to evp_test data on stdout.

    The normal output is meant to be redirected into a
    test/recipes/30-test_evp_data/*.txt file. Unsupported input is reported
    on stderr with a non-zero exit status, before anything is written to
    stdout, so an error message can never end up inside that data file.
    """
    parser = argparse.ArgumentParser(
        description="Convert Wycheproof ML-DSA test vectors (JSON) into "
                    "OpenSSL evp_test data written to stdout.")
    parser.add_argument('filename', type=str,
                        help="Wycheproof mldsa_*_standard_*_test.json file")
    # Required: the parsers call alg.replace(), which crashes on None.
    parser.add_argument('-alg', type=str, required=True,
                        help="OpenSSL algorithm name, e.g. ML-DSA-44")
    args = parser.parse_args()

    # Open and read the JSON file (JSON text is UTF-8 per RFC 8259).
    with open(args.filename, 'r', encoding='utf-8') as file:
        data = json.load(file)

    version = data['generatorVersion']
    algorithm = data['algorithm']
    # The file's mode is taken from its first test group.
    mode = data['testGroups'][0]['type']

    # Validate before printing anything. sys.exit(str) writes the message to
    # stderr and exits with status 1.
    if algorithm != "ML-DSA":
        sys.exit("Unsupported algorithm " + algorithm)
    handlers = {
        'MlDsaSign': parse_ml_dsa_sig_gen,
        'MlDsaVerify': parse_ml_dsa_sig_ver,
    }
    handler = handlers.get(mode)
    if handler is None:
        sys.exit("Unsupported mode " + mode)

    year = datetime.date.today().year
    print("# Copyright " + str(year) + " The OpenSSL Project Authors. All Rights Reserved.")
    print("#")
    print("# Licensed under the Apache License 2.0 (the \"License\").  You may not use")
    print("# this file except in compliance with the License.  You can obtain a copy")
    print("# in the file LICENSE in the source distribution or at")
    print("# https://www.openssl.org/source/license.html\n")
    print("# Wycheproof test data for " + algorithm + " " + mode + " generated from")
    print("# https://github.com/C2SP/wycheproof/blob/8e7fa6f87e6993d7b613cf48b46512a32df8084a/testvectors_v1/mldsa_*_standard_*_test.json")

    print("# [version " + str(version) + "]")

    handler(args.alg, data['testGroups'])


if __name__ == "__main__":
    main()
145