Skip to content

Commit 14ef28c

Browse files
committed
chore: cleanup and refactor
1 parent 1164083 commit 14ef28c

File tree

9 files changed

+349
-351
lines changed

9 files changed

+349
-351
lines changed
Lines changed: 35 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
import asyncio
22
import json
3-
from enum import Enum
43
from uuid import UUID
54

65
from aiohttp import web
76
from aiohttp_sse import sse_response
8-
9-
# import grpc
107
from forestadmin.agent_rpc.options import RpcOptions
11-
12-
# from forestadmin.agent_rpc.services.datasource import DatasourceService
138
from forestadmin.agent_toolkit.agent import Agent
9+
from forestadmin.agent_toolkit.options import Options
1410
from forestadmin.datasource_toolkit.interfaces.fields import PrimitiveType
1511
from forestadmin.datasource_toolkit.utils.schema import SchemaUtils
1612
from forestadmin.rpc_common.hmac import is_valid_hmac
@@ -30,48 +26,19 @@
3026
from forestadmin.rpc_common.serializers.schema.schema import SchemaSerializer
3127
from forestadmin.rpc_common.serializers.utils import CallerSerializer
3228

33-
# from concurrent import futures
34-
35-
36-
# from forestadmin.rpc_common.proto import datasource_pb2_grpc
37-
38-
39-
class RcpJsonEncoder(json.JSONEncoder):
40-
def default(self, o):
41-
if isinstance(o, Enum):
42-
return o.value
43-
if isinstance(o, set):
44-
return list(sorted(o, key=lambda x: x.value if isinstance(x, Enum) else str(x)))
45-
if isinstance(o, set):
46-
return list(sorted(o, key=lambda x: x.value if isinstance(x, Enum) else str(x)))
47-
48-
try:
49-
return super().default(o)
50-
except Exception as exc:
51-
print(f"error on seriliaze {o}, {type(o)}: {exc}")
52-
5329

5430
class RpcAgent(Agent):
55-
# TODO: options to add:
56-
# * listen addr
5731
def __init__(self, options: RpcOptions):
58-
# self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
59-
# self.server = grpc.aio.server()
6032
self.listen_addr, self.listen_port = options["listen_addr"].rsplit(":", 1)
33+
agent_options: Options = {**options} # type:ignore
34+
agent_options["skip_schema_update"] = True
35+
agent_options["env_secret"] = "f" * 64
36+
agent_options["server_url"] = "http://fake"
37+
agent_options["schema_path"] = "./.forestadmin-schema.json"
38+
super().__init__(agent_options)
39+
6140
self.app = web.Application(middlewares=[self.hmac_middleware])
62-
# self.server.add_insecure_port(options["listen_addr"])
63-
options["skip_schema_update"] = True
64-
options["env_secret"] = "f" * 64
65-
options["server_url"] = "http://fake"
66-
# options["auth_secret"] = "f48186505a3c5d62c27743126d6a76c1dd8b3e2d8897de19"
67-
options["schema_path"] = "./.forestadmin-schema.json"
68-
super().__init__(options)
69-
70-
self.aes_key = self.options["auth_secret"][:16].encode()
71-
self.aes_iv = self.options["auth_secret"][-16:].encode()
72-
self._server_stop = False
7341
self.setup_routes()
74-
# signal.signal(signal.SIGUSR1, self.stop_handler)
7542

7643
@web.middleware
7744
async def hmac_middleware(self, request: web.Request, handler):
@@ -80,11 +47,10 @@ async def hmac_middleware(self, request: web.Request, handler):
8047
if not is_valid_hmac(
8148
self.options["auth_secret"].encode(), body, request.headers.get("X-FOREST-HMAC", "").encode("utf-8")
8249
):
83-
return web.Response(status=401)
50+
return web.Response(status=401, text="Unauthorized from HMAC verification")
8451
return await handler(request)
8552

8653
def setup_routes(self):
87-
# self.app.middlewares.append(self.hmac_middleware)
8854
self.app.router.add_route("GET", "/sse", self.sse_handler)
8955
self.app.router.add_route("GET", "/schema", self.schema)
9056
self.app.router.add_route("POST", "/collection/list", self.collection_list)
@@ -98,11 +64,11 @@ def setup_routes(self):
9864

9965
self.app.router.add_route("POST", "/execute-native-query", self.native_query)
10066
self.app.router.add_route("POST", "/render-chart", self.render_chart)
101-
self.app.router.add_route("GET", "/", lambda _: web.Response(text="OK"))
67+
self.app.router.add_route("GET", "/", lambda _: web.Response(text="OK")) # type: ignore
10268

10369
async def sse_handler(self, request: web.Request) -> web.StreamResponse:
10470
async with sse_response(request) as resp:
105-
while resp.is_connected() and not self._server_stop:
71+
while resp.is_connected():
10672
await resp.send("", event="heartbeat")
10773
await asyncio.sleep(1)
10874
data = json.dumps({"event": "RpcServerStop"})
@@ -112,44 +78,42 @@ async def sse_handler(self, request: web.Request) -> web.StreamResponse:
11278
async def schema(self, request):
11379
await self.customizer.get_datasource()
11480

115-
return web.Response(text=json.dumps(await SchemaSerializer(await self.customizer.get_datasource()).serialize()))
81+
return web.json_response(await SchemaSerializer(await self.customizer.get_datasource()).serialize())
11682

11783
async def collection_list(self, request: web.Request):
11884
body_params = await request.json()
11985
ds = await self.customizer.get_datasource()
12086
collection = ds.get_collection(body_params["collectionName"])
12187
caller = CallerSerializer.deserialize(body_params["caller"])
122-
filter_ = PaginatedFilterSerializer.deserialize(body_params["filter"], collection)
88+
filter_ = PaginatedFilterSerializer.deserialize(body_params["filter"], collection) # type:ignore
12389
projection = ProjectionSerializer.deserialize(body_params["projection"])
12490

12591
records = await collection.list(caller, filter_, projection)
12692
records = [RecordSerializer.serialize(record) for record in records]
127-
return web.Response(text=aes_encrypt(json.dumps(records), self.aes_key, self.aes_iv))
93+
return web.json_response(records)
12894

12995
async def collection_create(self, request: web.Request):
13096
body_params = await request.text()
131-
body_params = aes_decrypt(body_params, self.aes_key, self.aes_iv)
13297
body_params = json.loads(body_params)
13398
ds = await self.customizer.get_datasource()
13499

135100
collection = ds.get_collection(body_params["collectionName"])
136101
caller = CallerSerializer.deserialize(body_params["caller"])
137-
data = [RecordSerializer.deserialize(r, collection) for r in body_params["data"]]
102+
data = [RecordSerializer.deserialize(r, collection) for r in body_params["data"]] # type:ignore
138103

139104
records = await collection.create(caller, data)
140105
records = [RecordSerializer.serialize(record) for record in records]
141106
return web.json_response(records)
142107

143108
async def collection_update(self, request: web.Request):
144109
body_params = await request.text()
145-
body_params = aes_decrypt(body_params, self.aes_key, self.aes_iv)
146110
body_params = json.loads(body_params)
147111

148112
ds = await self.customizer.get_datasource()
149113
collection = ds.get_collection(body_params["collectionName"])
150114
caller = CallerSerializer.deserialize(body_params["caller"])
151-
filter_ = FilterSerializer.deserialize(body_params["filter"], collection)
152-
patch = RecordSerializer.deserialize(body_params["patch"], collection)
115+
filter_ = FilterSerializer.deserialize(body_params["filter"], collection) # type:ignore
116+
patch = RecordSerializer.deserialize(body_params["patch"], collection) # type:ignore
153117

154118
await collection.update(caller, filter_, patch)
155119
return web.Response(text="OK")
@@ -159,7 +123,7 @@ async def collection_delete(self, request: web.Request):
159123
ds = await self.customizer.get_datasource()
160124
collection = ds.get_collection(body_params["collectionName"])
161125
caller = CallerSerializer.deserialize(body_params["caller"])
162-
filter_ = FilterSerializer.deserialize(body_params["filter"], collection)
126+
filter_ = FilterSerializer.deserialize(body_params["filter"], collection) # type:ignore
163127

164128
await collection.delete(caller, filter_)
165129
return web.Response(text="OK")
@@ -169,53 +133,51 @@ async def collection_aggregate(self, request: web.Request):
169133
ds = await self.customizer.get_datasource()
170134
collection = ds.get_collection(body_params["collectionName"])
171135
caller = CallerSerializer.deserialize(body_params["caller"])
172-
filter_ = FilterSerializer.deserialize(body_params["filter"], collection)
136+
filter_ = FilterSerializer.deserialize(body_params["filter"], collection) # type:ignore
173137
aggregation = AggregationSerializer.deserialize(body_params["aggregation"])
174138

175139
records = await collection.aggregate(caller, filter_, aggregation)
176-
# records = [RecordSerializer.serialize(record) for record in records]
177-
return web.Response(text=aes_encrypt(json.dumps(records), self.aes_key, self.aes_iv))
140+
return web.json_response(records)
178141

179142
async def collection_get_form(self, request: web.Request):
180143
body_params = await request.text()
181-
body_params = aes_decrypt(body_params, self.aes_key, self.aes_iv)
182144
body_params = json.loads(body_params)
183145

184146
ds = await self.customizer.get_datasource()
185147
collection = ds.get_collection(body_params["collectionName"])
186148

187-
caller = CallerSerializer.deserialize(body_params["caller"]) if body_params["caller"] else None
149+
caller = CallerSerializer.deserialize(body_params["caller"])
188150
action_name = body_params["actionName"]
189-
filter_ = FilterSerializer.deserialize(body_params["filter"], collection) if body_params["filter"] else None
151+
if body_params["filter"]:
152+
filter_ = FilterSerializer.deserialize(body_params["filter"], collection) # type:ignore
153+
else:
154+
filter_ = None
190155
data = ActionFormValuesSerializer.deserialize(body_params["data"])
191156
meta = body_params["meta"]
192157

193158
form = await collection.get_form(caller, action_name, data, filter_, meta)
194-
return web.Response(
195-
text=aes_encrypt(json.dumps(ActionFormSerializer.serialize(form)), self.aes_key, self.aes_iv)
196-
)
159+
return web.json_response(ActionFormSerializer.serialize(form))
197160

198161
async def collection_execute(self, request: web.Request):
199162
body_params = await request.text()
200-
body_params = aes_decrypt(body_params, self.aes_key, self.aes_iv)
201163
body_params = json.loads(body_params)
202164

203165
ds = await self.customizer.get_datasource()
204166
collection = ds.get_collection(body_params["collectionName"])
205167

206-
caller = CallerSerializer.deserialize(body_params["caller"]) if body_params["caller"] else None
168+
caller = CallerSerializer.deserialize(body_params["caller"])
207169
action_name = body_params["actionName"]
208-
filter_ = FilterSerializer.deserialize(body_params["filter"], collection) if body_params["filter"] else None
170+
if body_params["filter"]:
171+
filter_ = FilterSerializer.deserialize(body_params["filter"], collection) # type:ignore
172+
else:
173+
filter_ = None
209174
data = ActionFormValuesSerializer.deserialize(body_params["data"])
210175

211176
result = await collection.execute(caller, action_name, data, filter_)
212-
return web.Response(
213-
text=aes_encrypt(json.dumps(ActionResultSerializer.serialize(result)), self.aes_key, self.aes_iv)
214-
)
177+
return web.json_response(ActionResultSerializer.serialize(result)) # type:ignore
215178

216179
async def collection_render_chart(self, request: web.Request):
217180
body_params = await request.text()
218-
body_params = aes_decrypt(body_params, self.aes_key, self.aes_iv)
219181
body_params = json.loads(body_params)
220182

221183
ds = await self.customizer.get_datasource()
@@ -244,11 +206,10 @@ async def collection_render_chart(self, request: web.Request):
244206
ret.append(value)
245207

246208
result = await collection.render_chart(caller, name, record_id)
247-
return web.Response(text=aes_encrypt(json.dumps(result), self.aes_key, self.aes_iv))
209+
return web.json_response(result)
248210

249211
async def render_chart(self, request: web.Request):
250212
body_params = await request.text()
251-
body_params = aes_decrypt(body_params, self.aes_key, self.aes_iv)
252213
body_params = json.loads(body_params)
253214

254215
ds = await self.customizer.get_datasource()
@@ -257,11 +218,10 @@ async def render_chart(self, request: web.Request):
257218
name = body_params["name"]
258219

259220
result = await ds.render_chart(caller, name)
260-
return web.Response(text=aes_encrypt(json.dumps(result), self.aes_key, self.aes_iv))
221+
return web.json_response(result)
261222

262223
async def native_query(self, request: web.Request):
263224
body_params = await request.text()
264-
body_params = aes_decrypt(body_params, self.aes_key, self.aes_iv)
265225
body_params = json.loads(body_params)
266226

267227
ds = await self.customizer.get_datasource()
@@ -270,7 +230,7 @@ async def native_query(self, request: web.Request):
270230
parameters = body_params["parameters"]
271231

272232
result = await ds.execute_native_query(connection_name, native_query, parameters)
273-
return web.Response(text=aes_encrypt(json.dumps(result), self.aes_key, self.aes_iv))
233+
return web.json_response(result)
274234

275235
def start(self):
276236
web.run_app(self.app, host=self.listen_addr, port=int(self.listen_port))

0 commit comments

Comments
 (0)