#!/usr/bin/env python3

# Copyright 2026 The OpenSSL Project Authors. All Rights Reserved.
#
# Licensed under the Apache License 2.0 (the "License").  You may not use
# this file except in compliance with the License.  You can obtain a copy
# in the file LICENSE in the source distribution or at
# https://www.openssl.org/source/license.html

# This script generates missing-kdf.der - a password-encrypted CMS message
# without the keyDerivationAlgorithm field, which is used in the
# “PWRI missing keyDerivationAlgorithm regression” test.
#
# Usage: python3 make_missing_kdf_der.py valid.der missing-kdf.der

from __future__ import annotations

import argparse
import sys
from dataclasses import dataclass
from pathlib import Path


@dataclass
class Node:
    """One parsed DER TLV element, located by offsets into the input buffer."""

    off: int  # offset of the identifier (tag) octet
    tag: int  # the identifier octet (only single-octet tags are supported)
    hdr_len: int  # number of tag + length octets
    length: int  # number of content octets
    end: int  # offset one past the last content octet
    children: list["Node"]  # parsed sub-elements; empty for primitive encodings

    @property
    def content_off(self) -> int:
        """Offset of the first content octet."""
        return self.off + self.hdr_len


def read_len(data: bytes, off: int) -> tuple[int, int]:
    """Decode the definite-form DER length field that starts at ``data[off]``.

    Returns ``(length, number_of_length_octets)``.

    Raises ValueError for the indefinite form (0x80), for length fields
    wider than four octets, and for a length field cut off by the end of
    ``data``.
    """
    if off >= len(data):
        raise ValueError(f"truncated DER length at {off}")
    first = data[off]
    if first < 0x80:
        return first, 1
    n = first & 0x7F
    if n == 0 or n > 4:
        raise ValueError(f"unsupported DER length form at {off}")
    if off + 1 + n > len(data):
        raise ValueError(f"truncated DER length at {off}")
    return int.from_bytes(data[off + 1 : off + 1 + n], "big"), 1 + n


def parse_node(data: bytes, off: int) -> Node:
    """Parse the TLV at ``data[off]`` and, recursively, any constructed children.

    Raises ValueError if the element uses a multi-octet (high-tag-number)
    identifier, runs past the end of ``data``, or its children do not
    exactly fill its content octets.
    """
    if off >= len(data):
        raise ValueError(f"truncated DER element at {off}")
    tag = data[off]
    if (tag & 0x1F) == 0x1F:
        raise ValueError(f"unsupported high-tag-number form at {off}")
    length, len_len = read_len(data, off + 1)
    hdr_len = 1 + len_len
    end = off + hdr_len + length
    if end > len(data):
        raise ValueError(f"element at {off} ends at {end}, past input end {len(data)}")
    children: list[Node] = []
    if tag & 0x20:  # constructed encoding: content is itself a run of TLVs
        cur = off + hdr_len
        while cur < end:
            child = parse_node(data, cur)
            children.append(child)
            cur = child.end
        if cur != end:
            raise ValueError(f"child parse ended at {cur}, expected {end}")
    return Node(off=off, tag=tag, hdr_len=hdr_len, length=length, end=end, children=children)


def encode_len(length: int, existing_len_len: int | None = None) -> bytes:
    """Encode ``length`` as a DER length field.

    With ``existing_len_len`` omitted (the default) the minimal DER encoding
    is returned: short form below 0x80, otherwise the fewest long-form
    octets.  When ``existing_len_len`` is given, the field is forced to that
    exact width (tag-less octet count, including the 0x8N prefix) so it can
    be overwritten in place; such a field may be non-minimal (BER).

    Raises ValueError for negative lengths or when ``length`` does not fit
    in the requested width.
    """
    if length < 0:
        raise ValueError("negative DER length")
    if existing_len_len is None:
        if length < 0x80:
            return bytes([length])
        payload = length.to_bytes((length.bit_length() + 7) // 8, "big")
        return bytes([0x80 | len(payload)]) + payload
    if existing_len_len == 1:
        if length >= 0x80:
            raise ValueError("new length no longer fits in short-form DER")
        return bytes([length])
    payload_len = existing_len_len - 1
    if length > (1 << (payload_len * 8)) - 1:
        raise ValueError("new length no longer fits in existing long-form DER")
    return bytes([0x80 | payload_len]) + length.to_bytes(payload_len, "big")


def encode_tlv(tag: int, content: bytes) -> bytes:
    """Return ``tag || minimal DER length || content``."""
    return bytes([tag]) + encode_len(len(content)) + content


def patch_length_field(buf: bytearray, node: Node, delta: int) -> None:
    """Overwrite ``node``'s length octets in ``buf`` with ``node.length + delta``.

    The field keeps its original width so no bytes shift; the result can
    therefore be a non-minimal (BER, not DER) length.  Retained for callers
    of the in-place approach; ``main`` now re-encodes via ``splice_out``.
    """
    new_len = node.length + delta
    if new_len < 0:
        raise ValueError("negative patched length")
    len_bytes = encode_len(new_len, node.hdr_len - 1)
    start = node.off + 1
    buf[start : start + len(len_bytes)] = len_bytes


def splice_out(data: bytes, path: list[Node], target: Node) -> bytes:
    """Return ``data`` with ``target`` removed and every node in ``path`` re-encoded.

    ``path`` must list the ancestors of ``target`` from the outermost down to
    its direct parent.  Each ancestor is re-emitted with a minimal DER length,
    so headers may shrink (e.g. ``81 9D`` -> ``7B``) and the output stays DER.
    Bytes outside the outermost node are copied unchanged.
    """
    replacement = b""  # what replaces the current span: nothing for target
    old_start, old_end = target.off, target.end
    for node in reversed(path):
        content = data[node.content_off : old_start] + replacement + data[old_end : node.end]
        replacement = encode_tlv(node.tag, content)
        old_start, old_end = node.off, node.end
    return bytes(data[:old_start]) + replacement + bytes(data[old_end:])


def _child(node: Node, index: int, tag: int, what: str) -> Node:
    """Return ``node.children[index]``, checking that it exists and carries ``tag``."""
    if index >= len(node.children):
        raise ValueError(f"{what} missing")
    child = node.children[index]
    if child.tag != tag:
        raise ValueError(f"expected {what} tag 0x{tag:02x}, found 0x{child.tag:02x}")
    return child


def strip_pwri_kdf(data: bytes) -> tuple[bytes, int, int]:
    """Remove the PWRI keyDerivationAlgorithm from a DER CMS EnvelopedData blob.

    Returns ``(new_der, remove_start, remove_end)``, where the last two give
    the half-open byte range of the removed ``[0]`` element in ``data``.

    Raises ValueError if the structure is not the expected
    ContentInfo -> [0] -> SEQUENCE -> SET -> [3] PWRI shape, or if the first
    PWRI has no ``[0]`` keyDerivationAlgorithm after its version.
    """
    root = parse_node(data, 0)
    if root.tag != 0x30:
        raise ValueError(f"expected ContentInfo SEQUENCE tag 0x30, found 0x{root.tag:02x}")

    # ContentInfo ::= SEQUENCE { contentType OID, content [0] EXPLICIT ... }
    _child(root, 0, 0x06, "ContentInfo contentType OID")
    ed_wrapper = _child(root, 1, 0xA0, "ContentInfo [0] content")
    env_seq = _child(ed_wrapper, 0, 0x30, "EnvelopedData SEQUENCE")

    # EnvelopedData ::= SEQUENCE { version, originatorInfo [0] IMPLICIT
    #                              OPTIONAL, recipientInfos SET, ... }
    ri_index = 1
    if len(env_seq.children) > 1 and env_seq.children[1].tag == 0xA0:
        ri_index = 2  # skip the optional originatorInfo
    recipient_set = _child(env_seq, ri_index, 0x31, "RecipientInfos SET")

    # RecipientInfo CHOICE: pwri is [3] IMPLICIT, i.e. constructed tag 0xA3.
    pwri_choice = next((ri for ri in recipient_set.children if ri.tag == 0xA3), None)
    if pwri_choice is None:
        found = ", ".join(f"0x{ri.tag:02x}" for ri in recipient_set.children) or "none"
        raise ValueError(f"expected PWRI choice tag 0xA3, found {found}")
    if len(pwri_choice.children) < 3:
        raise ValueError("unexpected PWRI child count")

    version, maybe_kdf, keyenc = pwri_choice.children[:3]
    if version.tag != 0x02:
        raise ValueError("PWRI version is not INTEGER")
    if maybe_kdf.tag != 0xA0:
        raise ValueError(f"PWRI child after version is not [0] keyDerivationAlgorithm: 0x{maybe_kdf.tag:02x}")
    if keyenc.tag != 0x30:
        raise ValueError("PWRI keyEncryptionAlgorithm is not SEQUENCE")

    out = splice_out(data, [root, ed_wrapper, env_seq, recipient_set, pwri_choice], maybe_kdf)
    return out, maybe_kdf.off, maybe_kdf.end


def main() -> int:
    """CLI entry point: read input_der, strip the PWRI KDF, write output_der."""
    ap = argparse.ArgumentParser(description="Remove PWRI keyDerivationAlgorithm from a CMS DER blob.")
    ap.add_argument("input_der")
    ap.add_argument("output_der")
    args = ap.parse_args()

    data = Path(args.input_der).read_bytes()
    out, remove_start, remove_end = strip_pwri_kdf(data)
    remove_len = remove_end - remove_start

    Path(args.output_der).write_bytes(out)
    print(f"removed {remove_len} bytes at [{remove_start}, {remove_end})")
    return 0


if __name__ == "__main__":
    sys.exit(main())