Skip to content

Commit d8d3faa

Browse files
authored
Add Lliam UoW (#177)
* Add UoW and adapters to Lliam
* Update Lliam repository to allow injectable FS
* Make batch number padding dynamic, and clean up a few things
1 parent 912b261 commit d8d3faa

File tree

8 files changed

+243
-1
lines changed

8 files changed

+243
-1
lines changed
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
from itertools import islice
2+
from math import ceil, floor, log
3+
from pathlib import Path
4+
from typing import Any, Dict, Generator, Iterable, List, Tuple
5+
6+
from aws_doc_sdk_examples_tools.doc_gen import DocGen, Example
7+
from aws_doc_sdk_examples_tools.fs import Fs, PathFs
8+
from aws_doc_sdk_examples_tools.lliam.domain.model import Prompt
9+
from aws_doc_sdk_examples_tools.lliam.config import BATCH_PREFIX
10+
11+
# Title/title_abbrev prefix that TCXContentAnalyzer puts on newly extracted
# metadata; used to select only new examples (see DocGenRepository).
DEFAULT_METADATA_PREFIX = "DEFAULT"
# Maximum number of prompts staged into a single batch directory.
DEFAULT_BATCH_SIZE = 150
# Key into Example.languages whose first snippet file feeds each prompt.
IAM_POLICY_LANGUAGE = "IAMPolicyGrammar"
14+
15+
16+
def batched(iterable: Iterable, n: int) -> Generator[Tuple, None, None]:
    """Batch data into tuples of length n. The last batch may be shorter.

    Example: batched('ABCDEFG', 3) --> ABC DEF G

    Args:
        iterable: Any iterable; consumed lazily.
        n: Batch size; must be at least 1.

    Raises:
        ValueError: If n is less than one.
    """
    # Send type is None (nothing is ever sent into this generator); the
    # original annotated it as Any.
    if n < 1:
        raise ValueError("n must be at least one")
    it = iter(iterable)
    # islice returns an empty tuple once the source is exhausted, ending the loop.
    while batch := tuple(islice(it, n)):
        yield batch
24+
25+
26+
class PromptRepository:
    """Stage-and-commit store for Prompt files on an injectable filesystem."""

    def __init__(self, fs: Fs = PathFs()):
        # NOTE(review): the default PathFs() is created once at import time and
        # shared by every caller that omits `fs`; PathFs is presumed stateless
        # -- confirm before adding state to it.
        self.fs = fs
        # Staged prompt contents keyed by prompt id (a relative file path).
        # Instance attribute: the original declared this as a CLASS attribute,
        # so every repository instance shared (and clobbered) one dict.
        self.to_write: Dict[str, str] = {}
        # Output partition used by commit(); empty string means the CWD.
        # The original never initialized this, so commit() raised
        # AttributeError when set_partition() had not been called.
        self.partition_name = ""

    def rollback(self):
        # TODO: This is not what rollback is for. We should be rolling back any
        # file changes
        self.to_write = {}

    def add(self, prompt: Prompt):
        """Stage a single prompt for the next commit()."""
        self.to_write[prompt.id] = prompt.content

    def add_all(self, prompts: Iterable[Prompt]):
        """Stage every prompt in `prompts`."""
        for prompt in prompts:
            self.add(prompt)

    # Backward-compatible alias: the original method was (mis)named `all_all`.
    all_all = add_all

    def batch(self, prompts: Iterable[Prompt]):
        """Stage prompts grouped into zero-padded batch directories.

        Each prompt id becomes "<BATCH_PREFIX><n>/<original id>", where n is
        1-based and padded to the width of the largest batch number. Mutates
        the Prompt objects' ids in place.
        """
        prompt_list = list(prompts)

        if not prompt_list:
            return

        batches_count = ceil(len(prompt_list) / DEFAULT_BATCH_SIZE)
        # Width of the largest batch number. len(str(...)) avoids the
        # floating-point edge cases of floor(log10(n)) + 1.
        padding = len(str(batches_count))
        # Iterate the materialized list: the original re-iterated `prompts`,
        # which is already exhausted here whenever a generator was passed in,
        # silently staging nothing.
        for batch_num, batch in enumerate(batched(prompt_list, DEFAULT_BATCH_SIZE)):
            batch_name = f"{BATCH_PREFIX}{(batch_num + 1):0{padding}}"
            for prompt in batch:
                prompt.id = f"{batch_name}/{prompt.id}"
                self.add(prompt)

    def commit(self):
        """Write every staged prompt with non-empty content under the partition."""
        base_path = Path(self.partition) if self.partition else Path(".")

        for id, content in self.to_write.items():
            if content:
                full_path = base_path / id
                self.fs.mkdir(full_path.parent)
                self.fs.write(full_path, content)

    def get(self, id: str):
        """Read one prompt; the id is treated as a filesystem path."""
        return Prompt(id, self.fs.read(Path(id)))

    def get_all(self, ids: List[str]) -> List[Prompt]:
        """Read each id in order; propagates any underlying read failure."""
        return [self.get(id) for id in ids]

    def set_partition(self, name: str):
        """Choose the base directory that commit() writes under."""
        self.partition_name = name

    @property
    def partition(self):
        return self.partition_name or ""
83+
84+
85+
class DocGenRepository:
    """Read-only adapter that derives Prompt objects from a DocGen project root."""

    def __init__(self, fs: Fs = PathFs()):
        self.fs = fs
        # Populated by get_new_prompts(); cleared by rollback(). The original
        # never initialized it, so calling the private helpers (or rollback
        # introspection) before get_new_prompts() raised AttributeError.
        self._doc_gen = None

    def rollback(self):
        # TODO: This is not what rollback is for. We should be rolling back any
        # file changes
        self._doc_gen = None

    def get_new_prompts(self, doc_gen_root: str) -> List[Prompt]:
        """Load DocGen from `doc_gen_root` and return one Prompt per new example."""
        # Right now this is the only instance of DocGen used in this Repository,
        # but if that changes we need to move it up.
        self._doc_gen = DocGen.from_root(Path(doc_gen_root), fs=self.fs)
        self._doc_gen.collect_snippets()
        new_examples = self._get_new_examples()
        prompts = self._examples_to_prompts(new_examples)
        return prompts

    def _get_new_examples(self) -> List[Tuple[str, Example]]:
        """Return (id, example) pairs whose titles carry the extractor prefix."""
        examples = self._doc_gen.examples

        filtered_examples: List[Tuple[str, Example]] = []
        for example_id, example in examples.items():
            # TCXContentAnalyzer prefixes new metadata title/title_abbrev entries with
            # the DEFAULT_METADATA_PREFIX. Checking this here to make sure we're only
            # running the LLM tool on new extractions.
            title = example.title or ""
            title_abbrev = example.title_abbrev or ""
            if title.startswith(DEFAULT_METADATA_PREFIX) and title_abbrev.startswith(
                DEFAULT_METADATA_PREFIX
            ):
                filtered_examples.append((example_id, example))
        return filtered_examples

    def _examples_to_prompts(self, examples: List[Tuple[str, Example]]) -> List[Prompt]:
        """Build an "<example_id>.md" Prompt from each example's first snippet."""
        snippets = self._doc_gen.snippets
        prompts = []
        for example_id, example in examples:
            # assumes every new example has the IAM policy language with at
            # least one version/excerpt/snippet file -- TODO confirm the
            # extractor guarantees this shape.
            key = (
                example.languages[IAM_POLICY_LANGUAGE]
                .versions[0]
                .excerpts[0]
                .snippet_files[0]
                .replace("/", ".")
            )
            snippet = snippets.get(key)
            if snippet is None:
                # Fail loudly with the offending key instead of the original's
                # opaque AttributeError on NoneType.
                raise KeyError(f"No snippet collected for {key}")
            prompts.append(Prompt(f"{example_id}.md", snippet.code))
        return prompts
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from pathlib import Path

# Directory name for generated prompt output -- presumably consumed by the
# Ailly tool; confirm against callers.
AILLY_DIR = ".ailly_iam_policy"
AILLY_DIR_PATH = Path(AILLY_DIR)
# Prefix for batch subdirectory names, e.g. "batch_1".
BATCH_PREFIX = "batch_"

aws_doc_sdk_examples_tools/lliam/domain/commands.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ class Command:
99

1010
@dataclass
class CreatePrompts(Command):
    # Root of the DocGen project to scan for new examples.
    doc_gen_root: str
    # Paths of the system prompt files merged into the Ailly config.
    system_prompts: List[str]
    # Output partition (directory) where committed prompts are written.
    out_dir: str
1515

aws_doc_sdk_examples_tools/lliam/service_layer/__init__.py

Whitespace-only changes.
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import logging
2+
3+
from aws_doc_sdk_examples_tools.lliam.domain.operations import build_ailly_config
4+
from aws_doc_sdk_examples_tools.lliam.domain.commands import CreatePrompts
5+
from aws_doc_sdk_examples_tools.lliam.service_layer.unit_of_work import FsUnitOfWork
6+
7+
logger = logging.getLogger(__name__)
8+
9+
10+
def create_prompts(cmd: CreatePrompts, uow: FsUnitOfWork):
    """Handle a CreatePrompts command inside a unit of work.

    Reads the configured system prompts, builds the Ailly config from them,
    derives new-example prompts from the DocGen root, stages everything under
    cmd.out_dir, and commits. Statement order matters: the partition must be
    set before commit() so writes land under cmd.out_dir.
    """
    with uow:
        system_prompts = uow.prompts.get_all(cmd.system_prompts)
        ailly_config = build_ailly_config(system_prompts)
        prompts = uow.doc_gen.get_new_prompts(cmd.doc_gen_root)
        uow.prompts.batch(prompts)
        uow.prompts.add(ailly_config)
        uow.prompts.set_partition(cmd.out_dir)
        uow.commit()
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from aws_doc_sdk_examples_tools.fs import Fs, PathFs
2+
from aws_doc_sdk_examples_tools.lliam.adapters.repository import (
3+
PromptRepository,
4+
DocGenRepository,
5+
)
6+
7+
8+
class FsUnitOfWork:
    """Unit of work binding the prompt and DocGen repositories to one filesystem."""

    def __init__(self, fs: Fs = PathFs()):
        # NOTE(review): the default PathFs() is shared across calls that omit
        # `fs`; presumed stateless.
        self.fs = fs

    def __enter__(self):
        # Fresh repositories per `with` block so staged state never leaks
        # between units of work.
        self.prompts = PromptRepository(fs=self.fs)
        self.doc_gen = DocGenRepository(fs=self.fs)
        # Return self so the conventional `with FsUnitOfWork() as uow:` form
        # works; the original returned None, binding `uow` to nothing.
        return self

    def __exit__(self, *args):
        # Always roll back on exit; a successful flow calls commit() explicitly
        # before leaving the block.
        self.rollback()

    def commit(self):
        """Flush staged prompt writes to the filesystem."""
        self.prompts.commit()

    def rollback(self):
        """Discard all staged state in both repositories."""
        self.prompts.rollback()
        self.doc_gen.rollback()
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from pathlib import Path
2+
3+
from aws_doc_sdk_examples_tools.fs import RecordFs
4+
from aws_doc_sdk_examples_tools.lliam.adapters.repository import PromptRepository
5+
from aws_doc_sdk_examples_tools.lliam.domain.model import Prompt
6+
7+
8+
def test_batch_naming_occurs_properly():
    """Test that batch naming occurs properly when batching prompts."""
    repo = PromptRepository(fs=RecordFs({}))

    staged = [Prompt(f"prompt_{i}.md", f"Content for prompt {i}") for i in range(300)]

    repo.batch(staged)

    # 300 prompts at the default batch size of 150 should split evenly in two.
    counts = {"batch_1/": 0, "batch_2/": 0}
    for prompt_id in repo.to_write:
        for prefix in counts:
            if prompt_id.startswith(prefix):
                counts[prefix] += 1

    assert counts["batch_1/"] == 150
    assert counts["batch_2/"] == 150
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from pathlib import Path
2+
3+
from aws_doc_sdk_examples_tools.fs import RecordFs
4+
from aws_doc_sdk_examples_tools.lliam.domain.commands import CreatePrompts
5+
from aws_doc_sdk_examples_tools.lliam.service_layer.create_prompts import create_prompts
6+
from aws_doc_sdk_examples_tools.lliam.service_layer.unit_of_work import FsUnitOfWork
7+
8+
9+
def test_create_prompts_writes_when_commit_called():
    """Test that create_prompts successfully writes prompts when commit is called."""
    seed_files = {
        Path("/system1.md"): "System prompt 1 content",
        Path("/system2.md"): "System prompt 2 content",
        Path("/fake/doc_gen_root"): "",
    }
    fs = RecordFs(seed_files)

    command = CreatePrompts(
        doc_gen_root="/fake/doc_gen_root",
        system_prompts=["/system1.md", "/system2.md"],
        out_dir="/fake/output",
    )

    create_prompts(command, FsUnitOfWork(fs=fs))

    # Ailly config should be in committed prompts
    ailly_config_path = Path("/fake/output/.aillyrc")
    assert fs.stat(ailly_config_path).exists
    written = fs.read(ailly_config_path)
    assert "System prompt 1 content" in written
    assert "System prompt 2 content" in written

0 commit comments

Comments
 (0)