|
| 1 | +import json |
| 2 | +import logging |
| 3 | +import time |
| 4 | +from collections import defaultdict |
| 5 | +from datetime import timedelta |
| 6 | +from pathlib import Path |
| 7 | +from subprocess import run |
| 8 | +from typing import Any, Dict, List, Optional, Set |
| 9 | + |
| 10 | +from aws_doc_sdk_examples_tools.lliam.domain.commands import RunAilly |
| 11 | +from aws_doc_sdk_examples_tools.lliam.config import ( |
| 12 | + AILLY_DIR_PATH, |
| 13 | + BATCH_PREFIX, |
| 14 | +) |
| 15 | + |
# Module logger. NOTE(review): getLogger(__file__) keys the logger by file path;
# getLogger(__name__) is the usual convention — confirm before changing, since
# logging configuration elsewhere may reference the path-based name.
logger = logging.getLogger(__file__)
| 17 | + |
| 18 | + |
def handle_run_ailly(cmd: RunAilly, uow: None):
    """Run every requested Ailly batch and log the total wall-clock time.

    Args:
        cmd: Command carrying the list of requested batch names.
        uow: Unused; presumably kept for handler-signature uniformity — confirm.
    """
    batches = resolve_requested_batches(cmd.batches)
    if not batches:
        return

    started = time.time()
    for batch in batches:
        run_ailly_single_batch(batch)
    elapsed = time.time() - started

    logger.info(
        f"[TIMECHECK] {len(batches)} batches took {format_duration(elapsed)} to run"
    )
| 34 | + |
| 35 | + |
def resolve_requested_batches(batch_names: List[str]) -> List[Path]:
    """Map requested batch names to batch directories under AILLY_DIR_PATH.

    When *batch_names* is empty, discover every directory whose name starts
    with BATCH_PREFIX instead.

    Args:
        batch_names: Batch directory names to resolve (may be empty).

    Returns:
        List of existing batch directory paths.

    Raises:
        FileNotFoundError: if a requested batch path does not exist.
        NotADirectoryError: if a requested batch path exists but is not a directory.
    """
    if not batch_names:
        # Default: discover all batch directories by prefix.
        return [
            p
            for p in AILLY_DIR_PATH.iterdir()
            if p.is_dir() and p.name.startswith(BATCH_PREFIX)
        ]

    batch_paths: List[Path] = []
    for batch_name in batch_names:
        # `/` on a Path already yields a Path; the old Path(...) wrapper was redundant.
        batch_path = AILLY_DIR_PATH / batch_name
        if not batch_path.exists():
            raise FileNotFoundError(batch_path)
        if not batch_path.is_dir():
            raise NotADirectoryError(batch_path)
        batch_paths.append(batch_path)

    return batch_paths
| 57 | + |
| 58 | + |
def run_ailly_single_batch(batch: Path) -> None:
    """Run ailly and process files for a single batch.

    Invokes the external `ailly` CLI on the batch directory, logs the elapsed
    time, then parses the generated files into a JSON updates file.

    Args:
        batch: Path to the batch directory (a child of AILLY_DIR_PATH).
    """
    batch_start_time = time.time()
    iam_updates_path = AILLY_DIR_PATH / f"updates_{batch.name}.json"

    cmd = [
        "ailly",
        "--max-depth",
        "10",
        "--root",
        str(AILLY_DIR_PATH),
        batch.name,
    ]
    logger.info(f"Running {cmd}")
    completed = run(cmd)
    if completed.returncode != 0:
        # Surface CLI failures instead of silently processing stale/missing output.
        logger.error(
            f"ailly exited with code {completed.returncode} for {batch.name}"
        )

    batch_end_time = time.time()
    batch_duration = batch_end_time - batch_start_time
    logger.info(
        f"[TIMECHECK] {batch.name} took {format_duration(batch_duration)} to run"
    )

    logger.info(f"Processing generated content for {batch.name}")
    process_ailly_files(input_dir=batch, output_file=iam_updates_path)
| 83 | + |
| 84 | + |
# Keys every parsed Ailly block must contain (set literal, not set([...])).
EXPECTED_KEYS: Set[str] = {"title", "title_abbrev"}
# Optional prefix prepended to each parsed value, keyed by field name.
VALUE_PREFIXES: Dict[str, str] = {"title": "", "title_abbrev": "", "synopsis": ""}
| 87 | + |
| 88 | + |
class MissingExpectedKeys(Exception):
    """Raised when a parsed block lacks one or more of the required keys."""
| 91 | + |
| 92 | + |
def parse_fenced_blocks(content: str, fence="===") -> List[List[str]]:
    """Collect the lines lying between consecutive pairs of *fence* marker lines.

    A line counts as a fence when it equals *fence* after stripping whitespace.
    Lines outside any open fence are ignored; an unclosed trailing fence
    discards its partial block.

    Args:
        content: Full text to scan.
        fence: Marker string delimiting blocks (default "===").

    Returns:
        One list of lines per completed fenced block.
    """
    collected: List[List[str]] = []
    buffer: List[str] = []
    in_block = False

    for raw_line in content.splitlines():
        if raw_line.strip() == fence:
            if in_block:
                collected.append(buffer)
                buffer = []
            in_block = not in_block
        elif in_block:
            buffer.append(raw_line)

    return collected
| 108 | + |
| 109 | + |
def parse_block_lines(
    block: List[str], key_pairs: Dict[str, str], expected_keys=EXPECTED_KEYS
):
    """Extract `key => value` pairs from *block* into *key_pairs* (mutated in place).

    Args:
        block: Lines from one fenced block.
        key_pairs: Accumulator dict updated with each parsed pair.
        expected_keys: Keys that must be present once the block is parsed.

    Raises:
        MissingExpectedKeys: if any expected key is still absent afterwards.
    """
    for line in block:
        if "=>" in line:
            # partition always yields three parts here; the previous
            # `len(parts) > 1` fallback was a dead branch since the
            # membership test above guarantees a separator exists.
            key, _, value = line.partition("=>")
            key_pairs[key.strip()] = value.strip()
    if missing_keys := expected_keys - key_pairs.keys():
        raise MissingExpectedKeys(missing_keys)
| 121 | + |
| 122 | + |
def parse_ailly_file(
    file_path: str, value_prefixes: Dict[str, str] = VALUE_PREFIXES
) -> Dict[str, Any]:
    """Extract key-value pairs from an .md.ailly.md file.

    Pairs live between `===` fence markers, one per line, formatted as
    `key => value` — a convention entirely dependent on the LLM output
    written by Ailly. Configured prefixes are prepended to matching values,
    and bookkeeping fields ("id", "_source_file") are added.

    Args:
        file_path: Path to the .md.ailly.md file.

    Returns:
        Dictionary of extracted key-value pairs; empty on parse failure.
    """
    parsed: Dict[str, str] = {}

    try:
        text = Path(file_path).read_text(encoding="utf-8")

        for fenced_block in parse_fenced_blocks(text):
            parse_block_lines(fenced_block, parsed)

        for key, prefix in value_prefixes.items():
            if key in parsed:
                parsed[key] = f"{prefix}{parsed[key]}"

        # The id is the filename with the .md.ailly.md suffix removed.
        parsed["id"] = Path(file_path).name.split(".md.ailly.md")[0]
        parsed["_source_file"] = file_path

    except Exception as e:
        # Best-effort parse: log and return whatever was gathered.
        logger.error(f"Error parsing file {file_path}", exc_info=e)

    return parsed
| 159 | + |
| 160 | + |
def parse_package_name(policy_update: Dict[str, str]) -> Optional[str]:
    """Extract the package name from a policy update's "id" field.

    The id is expected to look like "iam-policies.<package>[.more]".

    Args:
        policy_update: Parsed policy update dict (may be falsy or malformed).

    Returns:
        The package name, or None when the id is missing or malformed.
    """
    if not policy_update or not isinstance(policy_update, dict):
        return None

    # Avoid shadowing the builtin `id`.
    update_id = policy_update.get("id")
    if not update_id:
        return None

    id_parts = [part.strip() for part in update_id.split(".")]

    # Require the prefix AND a second segment; the old code raised
    # IndexError when the id was exactly "iam-policies".
    if id_parts[0] != "iam-policies" or len(id_parts) < 2:
        return None

    return id_parts[1]  # The package name, hopefully.
| 177 | + |
| 178 | + |
def process_ailly_files(
    input_dir: Path, output_file: Path, file_pattern: str = "*.md.ailly.md"
) -> None:
    """
    Process all .md.ailly.md files in the input directory and write the results as JSON to the output file.

    Updates are grouped by the package name parsed from each update's id.

    Args:
        input_dir: Directory containing .md.ailly.md files
        output_file: Path to the output JSON file
        file_pattern: Pattern to match files (default: "*.md.ailly.md")
    """
    results = defaultdict(list)

    try:
        for file_path in input_dir.rglob(file_pattern):
            logger.info(f"Processing file: {file_path}")
            policy_update = parse_ailly_file(str(file_path))
            if policy_update:
                package_name = parse_package_name(policy_update)
                if not package_name:
                    # NOTE: caught by the broad except below, so this aborts
                    # processing and is logged rather than propagated.
                    raise TypeError("Could not get package name from policy update.")
                results[package_name].append(policy_update)

        with open(output_file, "w", encoding="utf-8") as out_file:
            json.dump(results, out_file, indent=2)

        logger.info(
            f"Successfully processed files. Output written to {output_file.name}"
        )

    except Exception as e:
        # Best-effort batch step: log the failure instead of crashing the run.
        logger.error("Error processing files", exc_info=e)
| 211 | + |
| 212 | + |
def format_duration(seconds: float) -> str:
    """Render *seconds* as timedelta text, left-padded with zeros to 8 chars.

    E.g. 65 -> "00:01:05"; longer renderings (days, fractions) pass through
    unpadded since they already exceed 8 characters.
    """
    return str(timedelta(seconds=seconds)).zfill(8)
0 commit comments