Skip to content

Commit 8c58fe9

Browse files
Merge pull request #7 from AMD-AGI/add_pip_package
Convert GEAK-Agent to installable pip package
2 parents 035b3ff + 85df96c commit 8c58fe9

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+495
-182
lines changed

.gitignore

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,21 @@
1-
src/__pycache__/
2-
src/*/__pycache__/
3-
src/*/*/__pycache__/
1+
# Python
2+
__pycache__/
3+
*.py[cod]
4+
*$py.class
5+
*.egg-info/
6+
dist/
7+
build/
8+
*.egg
9+
10+
# Virtual environments
11+
venv/
12+
.venv/
13+
14+
# IDE
15+
.idea/
16+
.vscode/
17+
*.swp
18+
*.swo
19+
20+
# Jupyter
21+
.ipynb_checkpoints/

geak_agent/__init__.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright(C) [2025] Advanced Micro Devices, Inc. All rights reserved.
2+
3+
"""
4+
GEAK-Agent: An LLM-based multi-agent framework for generating functional and efficient GPU kernels.
5+
"""
6+
7+
__version__ = "0.1.0"
8+
9+
# Lazy imports - only import when accessed to avoid loading tb_eval dependency at import time
10+
__all__ = [
11+
# Version
12+
"__version__",
13+
# Agents
14+
"BaseAgent",
15+
"SequentialBaseAgent",
16+
"GaAgent",
17+
"Reflexion",
18+
"Reflexion_Oneshot",
19+
"OptimAgent",
20+
"DirectPrompt",
21+
# Models
22+
"BaseModel",
23+
"OpenAIModel",
24+
"StandardOpenAIModel",
25+
"ClaudeModel",
26+
"StandardClaudeModel",
27+
"GeminiModel",
28+
# Dataloaders
29+
"TritonBench",
30+
"ROCm",
31+
"ProblemState",
32+
"ProblemStateROCm",
33+
# Utils
34+
"load_config",
35+
]
36+
37+
38+
def __getattr__(name):
39+
"""Lazy import to avoid loading tb_eval at package import time."""
40+
if name in ("BaseAgent", "SequentialBaseAgent"):
41+
from geak_agent.agents.Base import BaseAgent, SequentialBaseAgent
42+
return BaseAgent if name == "BaseAgent" else SequentialBaseAgent
43+
elif name == "GaAgent":
44+
from geak_agent.agents.GaAgent import GaAgent
45+
return GaAgent
46+
elif name == "Reflexion":
47+
from geak_agent.agents.Reflexion import Reflexion
48+
return Reflexion
49+
elif name == "Reflexion_Oneshot":
50+
from geak_agent.agents.reflexion_oneshot import Reflexion_Oneshot
51+
return Reflexion_Oneshot
52+
elif name == "OptimAgent":
53+
from geak_agent.agents.OptimAgent import OptimAgent
54+
return OptimAgent
55+
elif name == "DirectPrompt":
56+
from geak_agent.agents.DirectPrompt import DirectPrompt
57+
return DirectPrompt
58+
elif name == "BaseModel":
59+
from geak_agent.models.Base import BaseModel
60+
return BaseModel
61+
elif name in ("OpenAIModel", "StandardOpenAIModel"):
62+
from geak_agent.models.OpenAI import OpenAIModel, StandardOpenAIModel
63+
return OpenAIModel if name == "OpenAIModel" else StandardOpenAIModel
64+
elif name in ("ClaudeModel", "StandardClaudeModel"):
65+
from geak_agent.models.Claude import ClaudeModel, StandardClaudeModel
66+
return ClaudeModel if name == "ClaudeModel" else StandardClaudeModel
67+
elif name == "GeminiModel":
68+
from geak_agent.models.Gemini import GeminiModel
69+
return GeminiModel
70+
elif name == "TritonBench":
71+
from geak_agent.dataloaders.TritonBench import TritonBench
72+
return TritonBench
73+
elif name == "ROCm":
74+
from geak_agent.dataloaders.ROCm import ROCm
75+
return ROCm
76+
elif name in ("ProblemState", "ProblemStateROCm"):
77+
from geak_agent.dataloaders.ProblemState import ProblemState, ProblemStateROCm
78+
return ProblemState if name == "ProblemState" else ProblemStateROCm
79+
elif name == "load_config":
80+
from geak_agent.args_config import load_config
81+
return load_config
82+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
from loguru import logger
77
from dataclasses import dataclass, asdict
88
from concurrent.futures import ThreadPoolExecutor, as_completed
9-
from models.Base import BaseModel
10-
from dataloaders.ProblemState import ProblemState
11-
from memories.Memory import BaseMemory
9+
from geak_agent.models.Base import BaseModel
10+
from geak_agent.dataloaders.ProblemState import ProblemState
11+
from geak_agent.memories.Memory import BaseMemory
1212

1313

1414

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Copyright(C) [2025] Advanced Micro Devices, Inc. All rights reserved.
22

3-
from agents.Base import BaseAgent
4-
from utils.utils import clear_code
3+
from geak_agent.agents.Base import BaseAgent
4+
from geak_agent.utils.utils import clear_code
55

66
class DirectPrompt(BaseAgent):
77
def run_single_pass(self, mem, verbose=False):
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
import os
55
import json
66
from concurrent.futures import ThreadPoolExecutor, as_completed
7-
from agents.reflexion_oneshot import Reflexion_Oneshot
8-
from utils.utils import clear_code, extract_function_signatures, clear_json
9-
from memories.Memory import MemoryClassMeta
10-
from prompts import prompt_for_generation, prompt_for_reflection
7+
from geak_agent.agents.reflexion_oneshot import Reflexion_Oneshot
8+
from geak_agent.utils.utils import clear_code, extract_function_signatures, clear_json
9+
from geak_agent.memories.Memory import MemoryClassMeta
10+
from geak_agent.prompts import prompt_for_generation, prompt_for_reflection
1111
from loguru import logger
1212
from tenacity import RetryError
13-
from dataloaders.ProblemState import tempCode
13+
from geak_agent.dataloaders.ProblemState import tempCode
1414
from typing import List, Optional
1515

1616
class GaAgent(Reflexion_Oneshot):
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
import os
55
import json
66
from concurrent.futures import ThreadPoolExecutor, as_completed
7-
from agents.reflexion_oneshot import Reflexion_Oneshot
8-
from utils.utils import clear_code, extract_function_signatures, clear_json
9-
from memories.Memory import MemoryClassMeta
10-
from prompts import prompt_for_generation, prompt_for_reflection
7+
from geak_agent.agents.reflexion_oneshot import Reflexion_Oneshot
8+
from geak_agent.utils.utils import clear_code, extract_function_signatures, clear_json
9+
from geak_agent.memories.Memory import MemoryClassMeta
10+
from geak_agent.prompts import prompt_for_generation, prompt_for_reflection
1111
from loguru import logger
1212
from tenacity import RetryError
13-
from dataloaders.ProblemState import tempCode
13+
from geak_agent.dataloaders.ProblemState import tempCode
1414
from typing import List, Optional
1515
from tb_eval.evaluators.interface import get_evaluators
1616
class GaAgent(Reflexion_Oneshot):
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
import os
55
import json
66
from concurrent.futures import ThreadPoolExecutor, as_completed
7-
from agents.reflexion_oneshot import Reflexion_Oneshot
8-
from utils.utils import clear_code, extract_function_signatures, clear_json
9-
from memories.Memory import MemoryClassMeta
10-
from prompts import prompt_for_generation, prompt_for_reflection
7+
from geak_agent.agents.reflexion_oneshot import Reflexion_Oneshot
8+
from geak_agent.utils.utils import clear_code, extract_function_signatures, clear_json
9+
from geak_agent.memories.Memory import MemoryClassMeta
10+
from geak_agent.prompts import prompt_for_generation, prompt_for_reflection
1111
from loguru import logger
1212
from tenacity import RetryError
1313

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# Copyright(C) [2025] Advanced Micro Devices, Inc. All rights reserved.
22

3-
from agents.Base import SequentialBaseAgent, BaseAgent
4-
from utils.utils import clear_code
5-
from prompts import prompt_for_reflection
6-
from memories.Memory import ReflexionMemory
7-
from models.Base import BaseModel
3+
from geak_agent.agents.Base import SequentialBaseAgent, BaseAgent
4+
from geak_agent.utils.utils import clear_code
5+
from geak_agent.prompts import prompt_for_reflection
6+
from geak_agent.memories.Memory import ReflexionMemory
7+
from geak_agent.models.Base import BaseModel
88

99

1010

geak_agent/agents/__init__.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright(C) [2025] Advanced Micro Devices, Inc. All rights reserved.
2+
3+
from geak_agent.agents.Base import BaseAgent, SequentialBaseAgent
4+
from geak_agent.agents.GaAgent import GaAgent
5+
from geak_agent.agents.GaAgent_ROCm import GaAgent as GaAgentROCm
6+
from geak_agent.agents.Reflexion import Reflexion
7+
from geak_agent.agents.reflexion_oneshot import Reflexion_Oneshot
8+
from geak_agent.agents.OptimAgent import OptimAgent
9+
from geak_agent.agents.DirectPrompt import DirectPrompt
10+
11+
__all__ = [
12+
"BaseAgent",
13+
"SequentialBaseAgent",
14+
"GaAgent",
15+
"GaAgentROCm",
16+
"Reflexion",
17+
"Reflexion_Oneshot",
18+
"OptimAgent",
19+
"DirectPrompt",
20+
]
21+
Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
from loguru import logger
66
import json
77
from dataclasses import asdict
8-
from agents.Reflexion import Reflexion
9-
from utils.utils import extract_function_signatures, clear_code, extract_function_calls
10-
from prompts import prompt_for_reflection
11-
from memories.Memory import MemoryClassMeta
12-
from models.Base import BaseModel
13-
from retrievers.retriever import BM25Retriever
14-
from prompts import prompt_for_generation
8+
from geak_agent.agents.Reflexion import Reflexion
9+
from geak_agent.utils.utils import extract_function_signatures, clear_code, extract_function_calls
10+
from geak_agent.prompts import prompt_for_reflection
11+
from geak_agent.memories.Memory import MemoryClassMeta
12+
from geak_agent.models.Base import BaseModel
13+
from geak_agent.retrievers.retriever import BM25Retriever
14+
from geak_agent.prompts import prompt_for_generation
1515
from concurrent.futures import ThreadPoolExecutor, as_completed
1616

1717

0 commit comments

Comments
 (0)