Skip to content

Stream output of a user project with serve #292

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 17 commits into
base: deploy-command
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 5 additions & 120 deletions agentstack/cli/run.py
Original file line number Diff line number Diff line change
@@ -1,111 +1,12 @@
from typing import Optional, List, Dict
import sys
import asyncio
import traceback
from pathlib import Path
import importlib.util
from dotenv import load_dotenv

from typing import Optional
from agentstack import conf, log
from agentstack.exceptions import ValidationError
from agentstack import inputs
from agentstack import frameworks
from agentstack.utils import get_framework, verify_agentstack_project

MAIN_FILENAME: Path = Path("src/main.py")
MAIN_MODULE_NAME = "main"


def format_friendly_error_message(exception: Exception):
"""
Projects will throw various errors, especially on first runs, so we catch
them here and print a more helpful message.

In order to prevent us from having to import all possible backend exceptions
we do string matching on the exception type and traceback contents.
"""
# TODO These end up being pretty framework-specific; consider individual implementations.
COMMON_LLM_ENV_VARS = (
'OPENAI_API_KEY',
'ANTHROPIC_API_KEY',
)

name = exception.__class__.__name__
message = str(exception)
tracebacks = traceback.format_exception(type(exception), exception, exception.__traceback__)

match (name, message, tracebacks):
# The user doesn't have an environment variable set for the LLM provider.
case ('AuthenticationError', m, t) if 'litellm.AuthenticationError' in t[-1]:
variable_name = [k for k in COMMON_LLM_ENV_VARS if k in message] or ["correct"]
return (
"We were unable to connect to the LLM provider. "
f"Ensure your .env file has the {variable_name[0]} variable set."
)
# This happens when the LLM configured for an agent is invalid.
case ('BadRequestError', m, t) if 'LLM Provider NOT provided' in t[-1]:
return (
"An invalid LLM was configured for an agent. "
"Ensure the 'llm' attribute of the agent in the agents.yaml file is in the format <provider>/<model>."
)
# The user has not configured the correct agent name in the tasks.yaml file.
case ('KeyError', m, t) if 'self.tasks_config[task_name]["agent"]' in t[-2]:
return (
f"The agent {message} is not defined in your agents file. "
"Ensure the 'agent' fields in your tasks.yaml correspond to an entry in the agents.yaml file."
)
# The user does not have an agent defined in agents.yaml file, but it does
# exist in the entrypoint code.
case ('KeyError', m, t) if 'config=self.agents_config[' in t[-2]:
return (
f"The agent {message} is not defined in your agents file. "
"Ensure all agents referenced in your code are defined in the agents.yaml file."
)
# The user does not have a task defined in tasks.yaml file, but it does
# exist in the entrypoint code.
case ('KeyError', m, t) if 'config=self.tasks_config[' in t[-2]:
return (
f"The task {message} is not defined in your tasks. "
"Ensure all tasks referenced in your code are defined in the tasks.yaml file."
)
case (_, _, _):
log.debug(
f"Unhandled exception; if this is a common error, consider adding it to "
f"`cli.run._format_friendly_error_message`. Exception: {exception}"
)
raise exception # re-raise the original exception so we preserve context
from agentstack import run


def _import_project_module(path: Path):
"""
Import `main` from the project path.

We do it this way instead of spawning a subprocess so that we can share
state with the user's project.
"""
spec = importlib.util.spec_from_file_location(MAIN_MODULE_NAME, str(path / MAIN_FILENAME))

assert spec is not None # appease type checker
assert spec.loader is not None # appease type checker

project_module = importlib.util.module_from_spec(spec)
sys.path.insert(0, str((path / MAIN_FILENAME).parent))
spec.loader.exec_module(project_module)
return project_module


def run_project(command: str = 'run', cli_args: Optional[List[str]] = None):
def run_project(command: str = 'run', cli_args: Optional[list[str]] = None):
"""Validate that the project is ready to run and then run it."""
conf.assert_project()
verify_agentstack_project()

if conf.get_framework() not in frameworks.SUPPORTED_FRAMEWORKS:
raise ValidationError(f"Framework {conf.get_framework()} is not supported by agentstack.")

try:
frameworks.validate_project()
except ValidationError as e:
raise e
run.preflight()

# Parse extra --input-* arguments for runtime overrides of the project's inputs
if cli_args:
Expand All @@ -116,21 +17,5 @@ def run_project(command: str = 'run', cli_args: Optional[List[str]] = None):
log.debug(f"Using CLI input override: {key}={value}")
inputs.add_input_for_run(key, value)

load_dotenv(Path.home() / '.env') # load the user's .env file
load_dotenv(conf.PATH / '.env', override=True) # load the project's .env file
run.run_project(command=command)

# import src/main.py from the project path and run `command` from the project's main.py
try:
log.notify("Running your agent...")
project_main = _import_project_module(conf.PATH)
main = getattr(project_main, command)

# handle both async and sync entrypoints
if asyncio.iscoroutinefunction(main):
asyncio.run(main())
else:
main()
except ImportError as e:
raise ValidationError(f"Failed to import AgentStack project at: {conf.PATH.absolute()}\n{e}")
except Exception as e:
raise Exception(format_friendly_error_message(e))
41 changes: 38 additions & 3 deletions agentstack/frameworks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Optional, Union, Protocol, Callable
from typing import overload, runtime_checkable
from typing import Optional, Union, Protocol, Callable, Generator
from types import ModuleType
from abc import ABCMeta, abstractmethod
from importlib import import_module
from dataclasses import dataclass
from pathlib import Path
import ast
from agentstack import conf
from agentstack import conf, log
from agentstack.exceptions import ValidationError
from agentstack.generation import InsertionPoint
from agentstack.utils import get_framework
Expand All @@ -21,6 +22,7 @@
LANGGRAPH = 'langgraph'
OPENAI_SWARM = 'openai_swarm'
LLAMAINDEX = 'llamaindex'
CUSTOM = 'custom'
SUPPORTED_FRAMEWORKS = [
CREWAI,
LANGGRAPH,
Expand Down Expand Up @@ -110,6 +112,22 @@ def get_graph(self) -> list[graph.Edge]:
...


@runtime_checkable
class EntrypointProtocol(Protocol):
"""
Protocol defining the interface for a framework's entrypoint file.
"""
@overload
def run(self, inputs: dict[str, str]) -> None:
"""Run the entrypoint."""
...

@overload
def run(self, inputs: dict[str, str]) -> Generator[tuple[str, str], None, None]:
"""Run the entrypoint."""
...


class BaseEntrypointFile(asttools.File, metaclass=ABCMeta):
"""
This handles interactions with a Framework's entrypoint file that are common
Expand Down Expand Up @@ -169,7 +187,7 @@ def add_import(self, module_name: str, attributes: str):
def get_base_class(self) -> ast.ClassDef:
"""
A base class is the first class inside of the file that follows the
naming convention: `<FooBar>Graph`
naming convention defined by `base_class_pattern`.
"""
pattern = self.base_class_pattern
try:
Expand Down Expand Up @@ -296,6 +314,9 @@ def get_framework_module(framework: str) -> FrameworkModule:
"""
Get the module for a framework.
"""
if framework == CUSTOM:
raise Exception("Custom frameworks do not support modification.")

try:
return import_module(f".{framework}", package=__package__)
except ImportError:
Expand All @@ -315,6 +336,11 @@ def validate_project():
Validate that the user's project is ready to run.
"""
framework = get_framework()

if framework == CUSTOM:
log.debug("Skipping validation for custom framework.")
return

entrypoint_path = get_entrypoint_path(framework)
module = get_framework_module(framework)
entrypoint = module.get_entrypoint()
Expand Down Expand Up @@ -359,6 +385,15 @@ def validate_project():
for task_name in get_all_task_names():
if task_name not in task_method_names:
raise ValidationError(f"Task `{task_name}` defined in tasks.yaml but not in {entrypoint_path}")

# Verify that the entrypoint class follows the EntrypointProtocol definition
# TODO we need to actually import the user's code to reference the entrypoint class
# EntrypointClass =
# if not isinstance(EntrypointClass, EntrypointProtocol):
# raise ValidationError(
# f"Entrypoint class `{EntrypointClass.__name__}` does not follow the "
# "EntrypointProtocol definition."
# )


def add_tool(tool: ToolConfig, agent_name: str):
Expand Down
4 changes: 4 additions & 0 deletions agentstack/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def get_inputs() -> dict:
"""
Get the inputs configuration file and override with run_inputs.
"""
global run_inputs

config = InputsConfig().to_dict()
# run_inputs override saved inputs
for key, value in run_inputs.items():
Expand All @@ -89,4 +91,6 @@ def add_input_for_run(key: str, value: str):
Add an input override for the current run.
This is used by the CLI to allow inputs to be set at runtime.
"""
global run_inputs

run_inputs[key] = value
Loading