Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion examples/use_custom_dimensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def save_dimensions(dimensions: list[dict[str, Union[str, int]]]) -> None:
def save_dimension_list(
dimensions: list[dict[str, Union[str, int]]], list_name: str
) -> None:
Migrator().run()
dimension_list = CustomEvaluationDimensionList.find(
CustomEvaluationDimensionList.name == list_name
).all()
Expand Down
81 changes: 80 additions & 1 deletion sotopia/ui/fastapi_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
RelationshipProfile,
RelationshipType,
NonStreamingSimulationStatus,
CustomEvaluationDimensionList,
CustomEvaluationDimension,
)
from sotopia.envs.parallel import ParallelSotopiaEnv
from sotopia.envs.evaluators import (
Expand All @@ -33,7 +35,7 @@
)
from typing import Optional, Any
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, model_validator, field_validator
from pydantic import BaseModel, model_validator, field_validator, Field

from sotopia.ui.websocket_utils import (
WebSocketSotopiaSimulator,
Expand Down Expand Up @@ -112,6 +114,16 @@ class EnvironmentProfileWrapper(BaseModel):
tag: str = ""


class CustomEvaluationDimensionsWrapper(BaseModel):
pk: str = ""
name: str = Field(
default="", description="The name of the custom evaluation dimension list"
)
dimensions: list[CustomEvaluationDimension] = Field(
default=[], description="The dimensions of the custom evaluation dimension list"
)


class SimulationRequest(BaseModel):
env_id: str
agent_ids: list[str]
Expand Down Expand Up @@ -228,6 +240,20 @@ async def get_episodes(get_by: Literal["id", "tag"], value: str) -> list[Episode
return episodes


@app.get(
"/evaluation_dimensions/", response_model=dict[str, list[CustomEvaluationDimension]]
)
async def get_evaluation_dimensions() -> dict[str, list[CustomEvaluationDimension]]:
custom_evaluation_dimensions: dict[str, list[CustomEvaluationDimension]] = {}
custom_evaluation_dimension_list = CustomEvaluationDimensionList.all()
for custom_evaluation_dimension_list in custom_evaluation_dimension_list:
custom_evaluation_dimensions[custom_evaluation_dimension_list.name] = [
CustomEvaluationDimension.get(pk=pk)
for pk in custom_evaluation_dimension_list.dimension_pks
]
return custom_evaluation_dimensions


@app.post("/scenarios/", response_model=str)
async def create_scenario(scenario: EnvironmentProfileWrapper) -> str:
scenario_profile = EnvironmentProfile(**scenario.model_dump())
Expand Down Expand Up @@ -255,6 +281,49 @@ async def create_relationship(relationship: RelationshipWrapper) -> str:
return pk


@app.post("/evaluation_dimensions/", response_model=str)
async def create_evaluation_dimensions(
evaluation_dimensions: CustomEvaluationDimensionsWrapper,
) -> str:
dimension_list = CustomEvaluationDimensionList.find(
CustomEvaluationDimensionList.name == evaluation_dimensions.name
).all()

if len(dimension_list) == 0:
all_dimensions_pks = []
for dimension in evaluation_dimensions.dimensions:
find_dimension = CustomEvaluationDimension.find(
CustomEvaluationDimension.name == dimension.name
).all()
if len(find_dimension) == 0:
dimension.save()
all_dimensions_pks.append(dimension.pk)
elif len(find_dimension) == 1:
all_dimensions_pks.append(find_dimension[0].pk)
else:
raise HTTPException(
status_code=409,
detail=f"Evaluation dimension with name={dimension.name} already exists",
)

custom_evaluation_dimension_list = CustomEvaluationDimensionList(
pk=evaluation_dimensions.pk,
name=evaluation_dimensions.name,
dimension_pks=all_dimensions_pks,
)
custom_evaluation_dimension_list.save()
logger.info(f"Created evaluation dimension list {evaluation_dimensions.name}")
else:
raise HTTPException(
status_code=409,
detail=f"Evaluation dimension list with name={evaluation_dimensions.name} already exists",
)

pk = custom_evaluation_dimension_list.pk
assert pk is not None
return pk


async def run_simulation(
episode_pk: str,
simulation_request: SimulationRequest,
Expand Down Expand Up @@ -426,6 +495,16 @@ async def delete_episode(episode_id: str) -> str:
return episode_id


@app.delete("/evaluation_dimensions/{evaluation_dimension_list_pk}", response_model=str)
async def delete_evaluation_dimension_list(evaluation_dimension_list_pk: str) -> str:
for dimension_pk in CustomEvaluationDimensionList.get(
evaluation_dimension_list_pk
).dimension_pks:
CustomEvaluationDimension.delete(dimension_pk)
CustomEvaluationDimensionList.delete(evaluation_dimension_list_pk)
return evaluation_dimension_list_pk


active_simulations: Dict[
str, bool
] = {} # TODO check whether this is the correct way to store the active simulations
Expand Down
65 changes: 64 additions & 1 deletion tests/ui/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
AgentProfile,
EpisodeLog,
RelationshipProfile,
CustomEvaluationDimension,
CustomEvaluationDimensionList,
)
from sotopia.messages import SimpleMessage
from sotopia.ui.fastapi_server import app
Expand Down Expand Up @@ -68,7 +70,9 @@ def create_dummy_episode_log() -> None:


@pytest.fixture
def create_mock_data(for_posting: bool = False) -> Generator[None, None, None]:
def create_mock_data(request: pytest.FixtureRequest) -> Generator[None, None, None]:
for_posting = request.param if hasattr(request, "param") else False

def _create_mock_agent_profile() -> None:
AgentProfile(
first_name="John",
Expand Down Expand Up @@ -108,10 +112,26 @@ def _create_mock_relationship() -> None:
relationship=1.0,
).save()

def _create_mock_evaluation_dimension() -> None:
CustomEvaluationDimension(
pk="tmppk_evaluation_dimension",
name="test_dimension",
description="test_description",
range_high=10,
range_low=-10,
).save()
CustomEvaluationDimensionList(
pk="tmppk_evaluation_dimension_list",
name="test_dimension_list",
dimension_pks=["tmppk_evaluation_dimension"],
).save()

if not for_posting:
_create_mock_agent_profile()
_create_mock_env_profile()
_create_mock_relationship()
_create_mock_evaluation_dimension()
print("created mock data")
yield

try:
Expand Down Expand Up @@ -147,6 +167,15 @@ def _create_mock_relationship() -> None:
except Exception as e:
print(e)

try:
CustomEvaluationDimension.delete("tmppk_evaluation_dimension")
except Exception as e:
print(e)
try:
CustomEvaluationDimensionList.delete("tmppk_evaluation_dimension_list")
except Exception as e:
print(e)


def test_get_scenarios_all(create_mock_data: Callable[[], None]) -> None:
response = client.get("/scenarios")
Expand Down Expand Up @@ -221,6 +250,13 @@ def test_get_relationship(create_mock_data: Callable[[], None]) -> None:
assert response.json() == "1: know_by_name"


def test_get_evaluation_dimensions(create_mock_data: Callable[[], None]) -> None:
response = client.get("/evaluation_dimensions/")
assert response.status_code == 200
assert isinstance(response.json(), dict)
assert response.json()["test_dimension_list"][0]["name"] == "test_dimension"


@pytest.mark.parametrize("create_mock_data", [True], indirect=True)
def test_create_agent(create_mock_data: Callable[[], None]) -> None:
agent_data = {
Expand Down Expand Up @@ -260,6 +296,27 @@ def test_create_relationship(create_mock_data: Callable[[], None]) -> None:
assert isinstance(response.json(), str)


@pytest.mark.parametrize("create_mock_data", [True], indirect=True)
def test_create_evaluation_dimensions(create_mock_data: Callable[[], None]) -> None:
evaluation_dimension_data = {
"pk": "tmppk_evaluation_dimension_list",
"name": "test_dimension_list",
"dimensions": [
{
"pk": "tmppk_evaluation_dimension",
"name": "test_dimension",
"description": "test_description",
"range_high": 10,
"range_low": -10,
}
],
}
response = client.post("/evaluation_dimensions", json=evaluation_dimension_data)
print(response.json())
assert response.status_code == 200
assert isinstance(response.json(), str)


def test_delete_agent(create_mock_data: Callable[[], None]) -> None:
response = client.delete("/agents/tmppk_agent1")
assert response.status_code == 200
Expand All @@ -278,6 +335,12 @@ def test_delete_relationship(create_mock_data: Callable[[], None]) -> None:
assert isinstance(response.json(), str)


def test_delete_evaluation_dimension(create_mock_data: Callable[[], None]) -> None:
response = client.delete("/evaluation_dimensions/tmppk_evaluation_dimension_list")
assert response.status_code == 200
assert isinstance(response.json(), str)


# def test_simulate(create_mock_data: Callable[[], None]) -> None:
# response = client.post(
# "/simulate",
Expand Down
Loading