diff --git a/libc/utils/hdrgen/enumeration.py b/libc/utils/hdrgen/enumeration.py index b9848c04ee632..198720826720c 100644 --- a/libc/utils/hdrgen/enumeration.py +++ b/libc/utils/hdrgen/enumeration.py @@ -6,12 +6,24 @@ # # ==-------------------------------------------------------------------------==# +from functools import total_ordering + +@total_ordering class Enumeration: def __init__(self, name, value): self.name = name self.value = value + def __eq__(self, other): + return self.name == other.name + + def __lt__(self, other): + return self.name < other.name + + def __hash__(self): + return self.name.__hash__() + def __str__(self): if self.value != None: return f"{self.name} = {self.value}" diff --git a/libc/utils/hdrgen/function.py b/libc/utils/hdrgen/function.py index 25a7fa80e4501..bccd2c2caa2f5 100644 --- a/libc/utils/hdrgen/function.py +++ b/libc/utils/hdrgen/function.py @@ -7,6 +7,7 @@ # ==-------------------------------------------------------------------------==# import re +from functools import total_ordering from type import Type @@ -36,6 +37,7 @@ NONIDENTIFIER = re.compile("[^a-zA-Z0-9_]+") +@total_ordering class Function: def __init__( self, return_type, name, arguments, standards, guard=None, attributes=[] @@ -51,6 +53,15 @@ def __init__( self.guard = guard self.attributes = attributes or [] + def __eq__(self, other): + return self.name == other.name + + def __lt__(self, other): + return self.name < other.name + + def __hash__(self): + return self.name.__hash__() + def signature_types(self): def collapse(type_string): assert type_string diff --git a/libc/utils/hdrgen/header.py b/libc/utils/hdrgen/header.py index 685e5730e681e..9ea9f98f8fc83 100644 --- a/libc/utils/hdrgen/header.py +++ b/libc/utils/hdrgen/header.py @@ -73,6 +73,7 @@ def __init__(self, name): self.objects = [] self.functions = [] self.standards = [] + self.merge_yaml_files = [] def add_macro(self, macro): self.macros.append(macro) @@ -89,6 +90,13 @@ def add_object(self, object): def add_function(self, function): self.functions.append(function) + def merge(self, other): + self.macros = sorted(set(self.macros) | set(other.macros)) + self.types = sorted(set(self.types) | set(other.types)) + self.enumerations = sorted(set(self.enumerations) | set(other.enumerations)) + self.objects = sorted(set(self.objects) | set(other.objects)) + self.functions = sorted(set(self.functions) | set(other.functions)) + def all_types(self): return reduce( lambda a, b: a | b, diff --git a/libc/utils/hdrgen/macro.py b/libc/utils/hdrgen/macro.py index 19c2b318b8f77..e42e82845694d 100644 --- a/libc/utils/hdrgen/macro.py +++ b/libc/utils/hdrgen/macro.py @@ -6,13 +6,25 @@ # # ==-------------------------------------------------------------------------==# +from functools import total_ordering + +@total_ordering class Macro: def __init__(self, name, value=None, header=None): self.name = name self.value = value self.header = header + def __eq__(self, other): + return self.name == other.name + + def __lt__(self, other): + return self.name < other.name + + def __hash__(self): + return self.name.__hash__() + def __str__(self): if self.header != None: return "" diff --git a/libc/utils/hdrgen/main.py b/libc/utils/hdrgen/main.py index af57845582e5e..27b21ce8ca44b 100755 --- a/libc/utils/hdrgen/main.py +++ b/libc/utils/hdrgen/main.py @@ -52,8 +52,7 @@ def main(): ) args = parser.parse_args() - [yaml_file] = args.yaml_file - files_read = {yaml_file} + files_read = set() def write_depfile(): if not args.depfile: @@ -63,7 +62,34 @@ def write_depfile(): with open(args.depfile, "w") as depfile: depfile.write(f"{args.output}: {deps}\n") - header = load_yaml_file(yaml_file, HeaderFile, args.entry_point) + def load_yaml(path): + files_read.add(path) + return load_yaml_file(path, HeaderFile, args.entry_point) + + merge_from_files = dict() + + def merge_from(paths): + for path in paths: + # Load each file exactly once, in case of redundant merges. + if path in merge_from_files: + continue + header = load_yaml(path) + merge_from_files[path] = header + merge_from(path.parent / f for f in header.merge_yaml_files) + + # Load the main file first. + [yaml_file] = args.yaml_file + header = load_yaml(yaml_file) + + # Now load all the merge_yaml_files, and any transitive merge_yaml_files. + merge_from(yaml_file.parent / f for f in header.merge_yaml_files) + + # Merge in all those files' contents. + for merge_from_path, merge_from_header in merge_from_files.items(): + if merge_from_header.name is not None: + print(f"{merge_from_path!s}: Merge file cannot have header field", stderr) + return 2 + header.merge(merge_from_header) # The header_template path is relative to the containing YAML file. template = header.template(yaml_file.parent, files_read) diff --git a/libc/utils/hdrgen/object.py b/libc/utils/hdrgen/object.py index f5214452f0349..a311c37168d60 100644 --- a/libc/utils/hdrgen/object.py +++ b/libc/utils/hdrgen/object.py @@ -6,11 +6,23 @@ # # ==-------------------------------------------------------------------------==# +from functools import total_ordering + +@total_ordering class Object: def __init__(self, name, type): self.name = name self.type = type + def __eq__(self, other): + return self.name == other.name + + def __lt__(self, other): + return self.name < other.name + + def __hash__(self): + return self.name.__hash__() + def __str__(self): return f"extern {self.type} {self.name};" diff --git a/libc/utils/hdrgen/tests/input/merge1.yaml b/libc/utils/hdrgen/tests/input/merge1.yaml new file mode 100644 index 0000000000000..950abd1770320 --- /dev/null +++ b/libc/utils/hdrgen/tests/input/merge1.yaml @@ -0,0 +1,19 @@ +macros: + - macro_name: MACRO_A + macro_value: 1 +types: + - type_name: type_a +enums: + - name: enum_a + value: value_1 +objects: + - object_name: object_1 + object_type: obj +functions: + - name: func_a + return_type: void + arguments: [] + standards: + - stdc + attributes: + - CONST_FUNC_A diff --git a/libc/utils/hdrgen/tests/input/merge2.yaml b/libc/utils/hdrgen/tests/input/merge2.yaml new file mode 100644 index 0000000000000..a5d741454c3d7 --- /dev/null +++ b/libc/utils/hdrgen/tests/input/merge2.yaml @@ -0,0 +1,18 @@ +macros: + - macro_name: MACRO_B + macro_value: 2 +types: + - type_name: type_b +enums: + - name: enum_b + value: value_2 +objects: + - object_name: object_2 + object_type: obj +functions: + - name: func_b + return_type: float128 + arguments: [] + standards: + - stdc + guard: LIBC_TYPES_HAS_FLOAT128 diff --git a/libc/utils/hdrgen/tests/input/test_small.yaml b/libc/utils/hdrgen/tests/input/test_small.yaml index d5bb2bbfe4468..e072239e9a02a 100644 --- a/libc/utils/hdrgen/tests/input/test_small.yaml +++ b/libc/utils/hdrgen/tests/input/test_small.yaml @@ -1,42 +1,15 @@ header: test_small.h header_template: test_small.h.def +merge_yaml_files: + - merge1.yaml + - merge2.yaml macros: - - macro_name: MACRO_A - macro_value: 1 - - macro_name: MACRO_B - macro_value: 2 - macro_name: MACRO_C - macro_name: MACRO_D macro_header: test_small-macros.h - macro_name: MACRO_E macro_header: test_more-macros.h -types: - - type_name: type_a - - type_name: type_b -enums: - - name: enum_a - value: value_1 - - name: enum_b - value: value_2 -objects: - - object_name: object_1 - object_type: obj - - object_name: object_2 - object_type: obj functions: - - name: func_a - return_type: void - arguments: [] - standards: - - stdc - attributes: - - CONST_FUNC_A - - name: func_b - return_type: float128 - arguments: [] - standards: - - stdc - guard: LIBC_TYPES_HAS_FLOAT128 - name: func_c return_type: _Float16 arguments: diff --git a/libc/utils/hdrgen/type.py b/libc/utils/hdrgen/type.py index 0dbd8a5837d15..0c0af8569c61e 100644 --- a/libc/utils/hdrgen/type.py +++ b/libc/utils/hdrgen/type.py @@ -6,7 +6,10 @@ # # ==-------------------------------------------------------------------------==# +from functools import total_ordering + +@total_ordering class Type: def __init__(self, type_name): assert type_name @@ -15,5 +18,8 @@ def __init__(self, type_name): def __eq__(self, other): return self.type_name == other.type_name + def __lt__(self, other): + return self.type_name < other.type_name + def __hash__(self): return self.type_name.__hash__() diff --git a/libc/utils/hdrgen/yaml_to_classes.py b/libc/utils/hdrgen/yaml_to_classes.py index d7a349648340b..14e1f0f32cbbf 100644 --- a/libc/utils/hdrgen/yaml_to_classes.py +++ b/libc/utils/hdrgen/yaml_to_classes.py @@ -37,6 +37,7 @@ def yaml_to_classes(yaml_data, header_class, entry_points=None): header = header_class(header_name) header.template_file = yaml_data.get("header_template") header.standards = yaml_data.get("standards", []) + header.merge_yaml_files = yaml_data.get("merge_yaml_files", []) for macro_data in yaml_data.get("macros", []): header.add_macro( @@ -126,7 +127,7 @@ def load_yaml_file(yaml_file, header_class, entry_points): Returns: HeaderFile: An instance of HeaderFile populated with the data. """ - with open(yaml_file, "r") as f: + with yaml_file.open() as f: yaml_data = yaml.safe_load(f) return yaml_to_classes(yaml_data, header_class, entry_points) @@ -264,7 +265,7 @@ def main(): add_function_to_yaml(args.yaml_file, args.add_function) header_class = GpuHeader if args.export_decls else HeaderFile - header = load_yaml_file(args.yaml_file, header_class, args.entry_points) + header = load_yaml_file(Path(args.yaml_file), header_class, args.entry_points) header_str = str(header)