Skip to content

feat(datasource_rpc): add_datasource_rpc #321

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 34 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
26a717c
chore: add proto files
jbarreau Mar 12, 2025
e53cace
feat(agent): add a reload method to the agent
jbarreau Mar 12, 2025
3ea33fb
chore: cleanup agent reload code
jbarreau Mar 14, 2025
ff934d1
fix(rename_collection): issues while reloading
jbarreau Mar 14, 2025
036a139
fix(validation): validation rules were duplicated in schema
jbarreau Mar 14, 2025
c39355c
fix(live_queries): connections weren't refresh correctly
jbarreau Mar 14, 2025
8253318
chore: improve naming
jbarreau Mar 14, 2025
185969e
chore: add common rpc package
jbarreau Mar 20, 2025
8c62df0
chore: remove forgotten code
jbarreau Mar 20, 2025
678ec10
fix: add match operator forgotten from literal operators
jbarreau Mar 20, 2025
f4afd08
chore: add piece of datasource and rpc agent
jbarreau Mar 20, 2025
d9ca0b9
chore: all methods now exists in rpc datasource/agent
jbarreau Mar 25, 2025
e2856c3
chore: more tolerant with aggregation typing
jbarreau Mar 25, 2025
bb2bcc5
fix(django): don't create agent in the case it won't start
jbarreau Mar 25, 2025
dadea2d
chore: fix linting
jbarreau Mar 25, 2025
e104ee3
chore: add serializer for files in actions
jbarreau Mar 26, 2025
6968375
chore: simplify the reloading
jbarreau Mar 26, 2025
aedaf69
chore: cleanup old grpc files
jbarreau Mar 26, 2025
e9ba2e9
chore: remove json dump and load from serializer
jbarreau Mar 26, 2025
ab1b218
chore: continue cleanup
jbarreau Mar 26, 2025
031d6d5
chore: forgot to commit this
jbarreau Mar 26, 2025
96cf7af
chore: add segments from rpc datasource
jbarreau Mar 26, 2025
bb1cd9d
chore: handle searchable
jbarreau Mar 27, 2025
6e10e3e
chore: second solution of search
jbarreau Mar 27, 2025
8eec33d
chore: finalize search behavior
jbarreau Mar 27, 2025
76e4c19
chore: cleanup and refactor
jbarreau Apr 2, 2025
51f963d
chore: rpc route are the same as ruby
jbarreau Apr 2, 2025
5171eec
chore: simplify PR diff
jbarreau Apr 7, 2025
4018b83
chore: same hmac and urls as ruby
jbarreau Apr 8, 2025
93f2872
chore: remove old files
jbarreau Apr 8, 2025
42d5474
chore: be concilient if no live query connections
jbarreau Apr 8, 2025
efa4124
chore: missing attribute on backup stack
jbarreau Apr 8, 2025
fd7687f
chore(rpc_serilization): use same case as ruby
jbarreau Apr 8, 2025
818d283
chore: cleanup old code
jbarreau Apr 9, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,9 @@ src/datasource_toolkit/poetry.lock
src/flask_agent/poetry.lock
src/django_agent/poetry.lock
src/datasource_django/poetry.lock
src/datasource_rpc/poetry.lock
src/agent_rpc/poetry.lock
src/rpc_common/poetry.lock

# generate file during tests
src/django_agent/tests/test_project_agent/.forestadmin-schema.json
Empty file added src/agent_rpc/README.md
Empty file.
Empty file.
242 changes: 242 additions & 0 deletions src/agent_rpc/forestadmin/agent_rpc/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
import asyncio
import json
from uuid import UUID

from aiohttp import web
from aiohttp_sse import sse_response
from forestadmin.agent_rpc.hmac_middleware import HmacValidationError, HmacValidator
from forestadmin.agent_rpc.options import RpcOptions
from forestadmin.agent_toolkit.agent import Agent
from forestadmin.agent_toolkit.options import Options
from forestadmin.datasource_toolkit.interfaces.fields import PrimitiveType
from forestadmin.datasource_toolkit.utils.schema import SchemaUtils
from forestadmin.rpc_common.serializers.actions import (
ActionFormSerializer,
ActionFormValuesSerializer,
ActionResultSerializer,
)
from forestadmin.rpc_common.serializers.collection.aggregation import AggregationSerializer
from forestadmin.rpc_common.serializers.collection.filter import (
FilterSerializer,
PaginatedFilterSerializer,
ProjectionSerializer,
)
from forestadmin.rpc_common.serializers.collection.record import RecordSerializer
from forestadmin.rpc_common.serializers.schema.schema import SchemaSerializer
from forestadmin.rpc_common.serializers.utils import CallerSerializer


class RpcAgent(Agent):
def __init__(self, options: RpcOptions):
self.listen_addr, self.listen_port = options["listen_addr"].rsplit(":", 1)
agent_options: Options = {**options} # type:ignore
agent_options["skip_schema_update"] = True
agent_options["env_secret"] = "f" * 64
agent_options["server_url"] = "http://fake"
agent_options["schema_path"] = "./.forestadmin-schema.json"
super().__init__(agent_options)
self.hmac_validator = HmacValidator(options["auth_secret"])

self.app = web.Application(middlewares=[self.hmac_middleware])
self.setup_routes()

@web.middleware
async def hmac_middleware(self, request: web.Request, handler):
# TODO: hmc on SSE ?
if request.url.path in ["/", "/forest/rpc/sse"]:
return await handler(request)

header_sign = request.headers.get("X_SIGNATURE")
header_timestamp = request.headers.get("X_TIMESTAMP")
try:
self.hmac_validator.validate_hmac(header_sign, header_timestamp)
except HmacValidationError:
return web.Response(status=401, text="Unauthorized from HMAC verification")
return await handler(request)

def setup_routes(self):
self.app.router.add_route("GET", "/", lambda _: web.Response(text="OK")) # type: ignore
self.app.router.add_route("GET", "/forest/rpc/sse", self.sse_handler)
self.app.router.add_route("GET", "/forest/rpc-schema", self.schema)

# self.app.router.add_route("POST", "/execute-native-query", self.native_query)
self.app.router.add_route("POST", "/forest/rpc/datasource-chart", self.render_chart)

self.app.router.add_route("POST", "/forest/rpc/{collection_name}/list", self.collection_list)
self.app.router.add_route("POST", "/forest/rpc/{collection_name}/create", self.collection_create)
self.app.router.add_route("POST", "/forest/rpc/{collection_name}/update", self.collection_update)
self.app.router.add_route("POST", "/forest/rpc/{collection_name}/delete", self.collection_delete)
self.app.router.add_route("POST", "/forest/rpc/{collection_name}/aggregate", self.collection_aggregate)
self.app.router.add_route("POST", "/forest/rpc/{collection_name}/action-form", self.collection_get_form)
self.app.router.add_route("POST", "/forest/rpc/{collection_name}/action-execute", self.collection_execute)
self.app.router.add_route("POST", "/forest/rpc/{collection_name}/chart", self.collection_render_chart)

async def sse_handler(self, request: web.Request) -> web.StreamResponse:
async with sse_response(request) as resp:
while resp.is_connected():
await resp.send("", event="heartbeat")
await asyncio.sleep(1)
data = json.dumps({"event": "RpcServerStop"})
await resp.send(data, event="RpcServerStop")
return resp

async def schema(self, request):
await self.customizer.get_datasource()

return web.json_response(await SchemaSerializer(await self.customizer.get_datasource()).serialize())

async def collection_list(self, request: web.Request):
body_params = await request.json()
ds = await self.customizer.get_datasource()
collection = ds.get_collection(request.match_info["collection_name"])
caller = CallerSerializer.deserialize(body_params["caller"])
filter_ = PaginatedFilterSerializer.deserialize(body_params["filter"], collection) # type:ignore
projection = ProjectionSerializer.deserialize(body_params["projection"])

records = await collection.list(caller, filter_, projection)
records = [RecordSerializer.serialize(record) for record in records]
return web.json_response(records)

async def collection_create(self, request: web.Request):
body_params = await request.text()
body_params = json.loads(body_params)
ds = await self.customizer.get_datasource()

collection = ds.get_collection(request.match_info["collection_name"])
caller = CallerSerializer.deserialize(body_params["caller"])
data = [RecordSerializer.deserialize(r, collection) for r in body_params["data"]] # type:ignore

records = await collection.create(caller, data)
records = [RecordSerializer.serialize(record) for record in records]
return web.json_response(records)

async def collection_update(self, request: web.Request):
body_params = await request.text()
body_params = json.loads(body_params)

ds = await self.customizer.get_datasource()
collection = ds.get_collection(request.match_info["collection_name"])
caller = CallerSerializer.deserialize(body_params["caller"])
filter_ = FilterSerializer.deserialize(body_params["filter"], collection) # type:ignore
patch = RecordSerializer.deserialize(body_params["patch"], collection) # type:ignore

await collection.update(caller, filter_, patch)
return web.Response(text="OK")

async def collection_delete(self, request: web.Request):
body_params = await request.json()
ds = await self.customizer.get_datasource()
collection = ds.get_collection(request.match_info["collection_name"])
caller = CallerSerializer.deserialize(body_params["caller"])
filter_ = FilterSerializer.deserialize(body_params["filter"], collection) # type:ignore

await collection.delete(caller, filter_)
return web.Response(text="OK")

async def collection_aggregate(self, request: web.Request):
body_params = await request.json()
ds = await self.customizer.get_datasource()
collection = ds.get_collection(request.match_info["collection_name"])
caller = CallerSerializer.deserialize(body_params["caller"])
filter_ = FilterSerializer.deserialize(body_params["filter"], collection) # type:ignore
aggregation = AggregationSerializer.deserialize(body_params["aggregation"])

records = await collection.aggregate(caller, filter_, aggregation)
return web.json_response(records)

async def collection_get_form(self, request: web.Request):
body_params = await request.text()
body_params = json.loads(body_params)

ds = await self.customizer.get_datasource()
collection = ds.get_collection(request.match_info["collection_name"])

caller = CallerSerializer.deserialize(body_params["caller"])
action_name = body_params["actionName"]
if body_params["filter"]:
filter_ = FilterSerializer.deserialize(body_params["filter"], collection) # type:ignore
else:
filter_ = None
data = ActionFormValuesSerializer.deserialize(body_params["data"])
meta = body_params["meta"]

form = await collection.get_form(caller, action_name, data, filter_, meta)
return web.json_response(ActionFormSerializer.serialize(form))

async def collection_execute(self, request: web.Request):
body_params = await request.text()
body_params = json.loads(body_params)

ds = await self.customizer.get_datasource()
collection = ds.get_collection(request.match_info["collection_name"])

caller = CallerSerializer.deserialize(body_params["caller"])
action_name = body_params["actionName"]
if body_params["filter"]:
filter_ = FilterSerializer.deserialize(body_params["filter"], collection) # type:ignore
else:
filter_ = None
data = ActionFormValuesSerializer.deserialize(body_params["data"])

result = await collection.execute(caller, action_name, data, filter_)
return web.json_response(ActionResultSerializer.serialize(result)) # type:ignore

async def collection_render_chart(self, request: web.Request):
body_params = await request.text()
body_params = json.loads(body_params)

ds = await self.customizer.get_datasource()
collection = ds.get_collection(request.match_info["collection_name"])

caller = CallerSerializer.deserialize(body_params["caller"])
name = body_params["name"]
record_id = body_params["recordId"]
ret = []
for i, value in enumerate(record_id):
type_record_id = collection.schema["fields"][SchemaUtils.get_primary_keys(collection.schema)[i]][
"column_type"
]

if type_record_id == PrimitiveType.DATE:
ret.append(value.fromisoformat())
elif type_record_id == PrimitiveType.DATE_ONLY:
ret.append(value.fromisoformat())
elif type_record_id == PrimitiveType.DATE:
ret.append(value.fromisoformat())
elif type_record_id == PrimitiveType.POINT:
ret.append((value[0], value[1]))
elif type_record_id == PrimitiveType.UUID:
ret.append(UUID(value))
else:
ret.append(value)

result = await collection.render_chart(caller, name, record_id)
return web.json_response(result)

async def render_chart(self, request: web.Request):
body_params = await request.text()
body_params = json.loads(body_params)

ds = await self.customizer.get_datasource()

caller = CallerSerializer.deserialize(body_params["caller"])
name = body_params["name"]

result = await ds.render_chart(caller, name)
return web.json_response(result)

# TODO: speak about; it's currently not implemented in ruby
# async def native_query(self, request: web.Request):
# body_params = await request.text()
# body_params = json.loads(body_params)

# ds = await self.customizer.get_datasource()
# connection_name = body_params["connectionName"]
# native_query = body_params["nativeQuery"]
# parameters = body_params["parameters"]

# result = await ds.execute_native_query(connection_name, native_query, parameters)
# return web.json_response(result)

def start(self):
web.run_app(self.app, host=self.listen_addr, port=int(self.listen_port))
56 changes: 56 additions & 0 deletions src/agent_rpc/forestadmin/agent_rpc/hmac_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import datetime

from forestadmin.datasource_toolkit.exceptions import ForestException
from forestadmin.rpc_common.hmac import generate_hmac, is_valid_hmac


class HmacValidationError(ForestException):
pass


class HmacValidator:
ALLOWED_TIME_DIFF = 300
SIGNATURE_REUSE_WINDOW = 5

def __init__(self, secret_key: str) -> None:
self.secret_key = secret_key
self.used_signatures = dict()

def validate_hmac(self, sign, timestamp):
"""Validate the HMAC signature."""
if not sign or not timestamp:
raise HmacValidationError("Missing HMAC signature or timestamp")

self.validate_timestamp(timestamp)

expected_sign = generate_hmac(self.secret_key.encode("utf-8"), timestamp.encode("utf-8"))
if not is_valid_hmac(self.secret_key.encode("utf-8"), timestamp.encode("utf-8"), expected_sign.encode("utf-8")):
raise HmacValidationError("Invalid HMAC signature")

if sign in self.used_signatures.keys():
last_used = self.used_signatures[sign]
if (datetime.datetime.now(datetime.timezone.utc) - last_used).total_seconds() > self.SIGNATURE_REUSE_WINDOW:
raise HmacValidationError("HMAC signature has already been used")

self.used_signatures[sign] = datetime.datetime.now(datetime.timezone.utc)
self._cleanup_old_signs()
return True

def validate_timestamp(self, timestamp):
try:
current_time = datetime.datetime.fromisoformat(timestamp)
except Exception:
raise HmacValidationError("Invalid timestamp format")

if (datetime.datetime.now(datetime.timezone.utc) - current_time).total_seconds() > self.ALLOWED_TIME_DIFF:
raise HmacValidationError("Timestamp is too old or in the future")

def _cleanup_old_signs(self):
now = datetime.datetime.now(datetime.timezone.utc)
to_rm = []
for sign, last_used in self.used_signatures.items():
if (now - last_used).total_seconds() > self.ALLOWED_TIME_DIFF:
to_rm.append(sign)

for sign in to_rm:
del self.used_signatures[sign]
6 changes: 6 additions & 0 deletions src/agent_rpc/forestadmin/agent_rpc/options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from typing_extensions import TypedDict


class RpcOptions(TypedDict):
listen_addr: str
auth_secret: str
28 changes: 28 additions & 0 deletions src/agent_rpc/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import asyncio
import logging

from forestadmin.agent_rpc.agent import RpcAgent
from forestadmin.datasource_sqlalchemy.datasource import SqlAlchemyDatasource
from sqlalchemy_models import DB_URI, Base


def main() -> None:
agent = RpcAgent({"listen_addr": "0.0.0.0:50051"})
agent.add_datasource(
SqlAlchemyDatasource(Base, DB_URI), {"rename": lambda collection_name: f"FROMRPCAGENT_{collection_name}"}
)
agent.customize_collection("FROMRPCAGENT_address").add_field(
"new_fieeeld",
{
"column_type": "String",
"dependencies": ["pk"],
"get_values": lambda records, ctx: ["v" for r in records],
},
)
agent.start()


if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
# asyncio.run(main())
main()
Loading
Loading