1
1
import asyncio
2
2
import json
3
- from enum import Enum
4
3
from uuid import UUID
5
4
6
5
from aiohttp import web
7
6
from aiohttp_sse import sse_response
8
-
9
- # import grpc
10
7
from forestadmin .agent_rpc .options import RpcOptions
11
-
12
- # from forestadmin.agent_rpc.services.datasource import DatasourceService
13
8
from forestadmin .agent_toolkit .agent import Agent
9
+ from forestadmin .agent_toolkit .options import Options
14
10
from forestadmin .datasource_toolkit .interfaces .fields import PrimitiveType
15
11
from forestadmin .datasource_toolkit .utils .schema import SchemaUtils
16
12
from forestadmin .rpc_common .hmac import is_valid_hmac
30
26
from forestadmin .rpc_common .serializers .schema .schema import SchemaSerializer
31
27
from forestadmin .rpc_common .serializers .utils import CallerSerializer
32
28
33
- # from concurrent import futures
34
-
35
-
36
- # from forestadmin.rpc_common.proto import datasource_pb2_grpc
37
-
38
-
39
- class RcpJsonEncoder (json .JSONEncoder ):
40
- def default (self , o ):
41
- if isinstance (o , Enum ):
42
- return o .value
43
- if isinstance (o , set ):
44
- return list (sorted (o , key = lambda x : x .value if isinstance (x , Enum ) else str (x )))
45
- if isinstance (o , set ):
46
- return list (sorted (o , key = lambda x : x .value if isinstance (x , Enum ) else str (x )))
47
-
48
- try :
49
- return super ().default (o )
50
- except Exception as exc :
51
- print (f"error on seriliaze { o } , { type (o )} : { exc } " )
52
-
53
29
54
30
class RpcAgent (Agent ):
55
- # TODO: options to add:
56
- # * listen addr
57
31
def __init__ (self , options : RpcOptions ):
58
- # self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
59
- # self.server = grpc.aio.server()
60
32
self .listen_addr , self .listen_port = options ["listen_addr" ].rsplit (":" , 1 )
33
+ agent_options : Options = {** options } # type:ignore
34
+ agent_options ["skip_schema_update" ] = True
35
+ agent_options ["env_secret" ] = "f" * 64
36
+ agent_options ["server_url" ] = "http://fake"
37
+ agent_options ["schema_path" ] = "./.forestadmin-schema.json"
38
+ super ().__init__ (agent_options )
39
+
61
40
self .app = web .Application (middlewares = [self .hmac_middleware ])
62
- # self.server.add_insecure_port(options["listen_addr"])
63
- options ["skip_schema_update" ] = True
64
- options ["env_secret" ] = "f" * 64
65
- options ["server_url" ] = "http://fake"
66
- # options["auth_secret"] = "f48186505a3c5d62c27743126d6a76c1dd8b3e2d8897de19"
67
- options ["schema_path" ] = "./.forestadmin-schema.json"
68
- super ().__init__ (options )
69
-
70
- self .aes_key = self .options ["auth_secret" ][:16 ].encode ()
71
- self .aes_iv = self .options ["auth_secret" ][- 16 :].encode ()
72
- self ._server_stop = False
73
41
self .setup_routes ()
74
- # signal.signal(signal.SIGUSR1, self.stop_handler)
75
42
76
43
@web .middleware
77
44
async def hmac_middleware (self , request : web .Request , handler ):
@@ -80,11 +47,10 @@ async def hmac_middleware(self, request: web.Request, handler):
80
47
if not is_valid_hmac (
81
48
self .options ["auth_secret" ].encode (), body , request .headers .get ("X-FOREST-HMAC" , "" ).encode ("utf-8" )
82
49
):
83
- return web .Response (status = 401 )
50
+ return web .Response (status = 401 , text = "Unauthorized from HMAC verification" )
84
51
return await handler (request )
85
52
86
53
def setup_routes (self ):
87
- # self.app.middlewares.append(self.hmac_middleware)
88
54
self .app .router .add_route ("GET" , "/sse" , self .sse_handler )
89
55
self .app .router .add_route ("GET" , "/schema" , self .schema )
90
56
self .app .router .add_route ("POST" , "/collection/list" , self .collection_list )
@@ -98,11 +64,11 @@ def setup_routes(self):
98
64
99
65
self .app .router .add_route ("POST" , "/execute-native-query" , self .native_query )
100
66
self .app .router .add_route ("POST" , "/render-chart" , self .render_chart )
101
- self .app .router .add_route ("GET" , "/" , lambda _ : web .Response (text = "OK" ))
67
+ self .app .router .add_route ("GET" , "/" , lambda _ : web .Response (text = "OK" )) # type: ignore
102
68
103
69
async def sse_handler (self , request : web .Request ) -> web .StreamResponse :
104
70
async with sse_response (request ) as resp :
105
- while resp .is_connected () and not self . _server_stop :
71
+ while resp .is_connected ():
106
72
await resp .send ("" , event = "heartbeat" )
107
73
await asyncio .sleep (1 )
108
74
data = json .dumps ({"event" : "RpcServerStop" })
@@ -112,44 +78,42 @@ async def sse_handler(self, request: web.Request) -> web.StreamResponse:
112
78
async def schema (self , request ):
113
79
await self .customizer .get_datasource ()
114
80
115
- return web .Response ( text = json . dumps ( await SchemaSerializer (await self .customizer .get_datasource ()).serialize () ))
81
+ return web .json_response ( await SchemaSerializer (await self .customizer .get_datasource ()).serialize ())
116
82
117
83
async def collection_list (self , request : web .Request ):
118
84
body_params = await request .json ()
119
85
ds = await self .customizer .get_datasource ()
120
86
collection = ds .get_collection (body_params ["collectionName" ])
121
87
caller = CallerSerializer .deserialize (body_params ["caller" ])
122
- filter_ = PaginatedFilterSerializer .deserialize (body_params ["filter" ], collection )
88
+ filter_ = PaginatedFilterSerializer .deserialize (body_params ["filter" ], collection ) # type:ignore
123
89
projection = ProjectionSerializer .deserialize (body_params ["projection" ])
124
90
125
91
records = await collection .list (caller , filter_ , projection )
126
92
records = [RecordSerializer .serialize (record ) for record in records ]
127
- return web .Response ( text = aes_encrypt ( json . dumps ( records ), self . aes_key , self . aes_iv ) )
93
+ return web .json_response ( records )
128
94
129
95
async def collection_create (self , request : web .Request ):
130
96
body_params = await request .text ()
131
- body_params = aes_decrypt (body_params , self .aes_key , self .aes_iv )
132
97
body_params = json .loads (body_params )
133
98
ds = await self .customizer .get_datasource ()
134
99
135
100
collection = ds .get_collection (body_params ["collectionName" ])
136
101
caller = CallerSerializer .deserialize (body_params ["caller" ])
137
- data = [RecordSerializer .deserialize (r , collection ) for r in body_params ["data" ]]
102
+ data = [RecordSerializer .deserialize (r , collection ) for r in body_params ["data" ]] # type:ignore
138
103
139
104
records = await collection .create (caller , data )
140
105
records = [RecordSerializer .serialize (record ) for record in records ]
141
106
return web .json_response (records )
142
107
143
108
async def collection_update (self , request : web .Request ):
144
109
body_params = await request .text ()
145
- body_params = aes_decrypt (body_params , self .aes_key , self .aes_iv )
146
110
body_params = json .loads (body_params )
147
111
148
112
ds = await self .customizer .get_datasource ()
149
113
collection = ds .get_collection (body_params ["collectionName" ])
150
114
caller = CallerSerializer .deserialize (body_params ["caller" ])
151
- filter_ = FilterSerializer .deserialize (body_params ["filter" ], collection )
152
- patch = RecordSerializer .deserialize (body_params ["patch" ], collection )
115
+ filter_ = FilterSerializer .deserialize (body_params ["filter" ], collection ) # type:ignore
116
+ patch = RecordSerializer .deserialize (body_params ["patch" ], collection ) # type:ignore
153
117
154
118
await collection .update (caller , filter_ , patch )
155
119
return web .Response (text = "OK" )
@@ -159,7 +123,7 @@ async def collection_delete(self, request: web.Request):
159
123
ds = await self .customizer .get_datasource ()
160
124
collection = ds .get_collection (body_params ["collectionName" ])
161
125
caller = CallerSerializer .deserialize (body_params ["caller" ])
162
- filter_ = FilterSerializer .deserialize (body_params ["filter" ], collection )
126
+ filter_ = FilterSerializer .deserialize (body_params ["filter" ], collection ) # type:ignore
163
127
164
128
await collection .delete (caller , filter_ )
165
129
return web .Response (text = "OK" )
@@ -169,53 +133,51 @@ async def collection_aggregate(self, request: web.Request):
169
133
ds = await self .customizer .get_datasource ()
170
134
collection = ds .get_collection (body_params ["collectionName" ])
171
135
caller = CallerSerializer .deserialize (body_params ["caller" ])
172
- filter_ = FilterSerializer .deserialize (body_params ["filter" ], collection )
136
+ filter_ = FilterSerializer .deserialize (body_params ["filter" ], collection ) # type:ignore
173
137
aggregation = AggregationSerializer .deserialize (body_params ["aggregation" ])
174
138
175
139
records = await collection .aggregate (caller , filter_ , aggregation )
176
- # records = [RecordSerializer.serialize(record) for record in records]
177
- return web .Response (text = aes_encrypt (json .dumps (records ), self .aes_key , self .aes_iv ))
140
+ return web .json_response (records )
178
141
179
142
async def collection_get_form (self , request : web .Request ):
180
143
body_params = await request .text ()
181
- body_params = aes_decrypt (body_params , self .aes_key , self .aes_iv )
182
144
body_params = json .loads (body_params )
183
145
184
146
ds = await self .customizer .get_datasource ()
185
147
collection = ds .get_collection (body_params ["collectionName" ])
186
148
187
- caller = CallerSerializer .deserialize (body_params ["caller" ]) if body_params [ "caller" ] else None
149
+ caller = CallerSerializer .deserialize (body_params ["caller" ])
188
150
action_name = body_params ["actionName" ]
189
- filter_ = FilterSerializer .deserialize (body_params ["filter" ], collection ) if body_params ["filter" ] else None
151
+ if body_params ["filter" ]:
152
+ filter_ = FilterSerializer .deserialize (body_params ["filter" ], collection ) # type:ignore
153
+ else :
154
+ filter_ = None
190
155
data = ActionFormValuesSerializer .deserialize (body_params ["data" ])
191
156
meta = body_params ["meta" ]
192
157
193
158
form = await collection .get_form (caller , action_name , data , filter_ , meta )
194
- return web .Response (
195
- text = aes_encrypt (json .dumps (ActionFormSerializer .serialize (form )), self .aes_key , self .aes_iv )
196
- )
159
+ return web .json_response (ActionFormSerializer .serialize (form ))
197
160
198
161
async def collection_execute (self , request : web .Request ):
199
162
body_params = await request .text ()
200
- body_params = aes_decrypt (body_params , self .aes_key , self .aes_iv )
201
163
body_params = json .loads (body_params )
202
164
203
165
ds = await self .customizer .get_datasource ()
204
166
collection = ds .get_collection (body_params ["collectionName" ])
205
167
206
- caller = CallerSerializer .deserialize (body_params ["caller" ]) if body_params [ "caller" ] else None
168
+ caller = CallerSerializer .deserialize (body_params ["caller" ])
207
169
action_name = body_params ["actionName" ]
208
- filter_ = FilterSerializer .deserialize (body_params ["filter" ], collection ) if body_params ["filter" ] else None
170
+ if body_params ["filter" ]:
171
+ filter_ = FilterSerializer .deserialize (body_params ["filter" ], collection ) # type:ignore
172
+ else :
173
+ filter_ = None
209
174
data = ActionFormValuesSerializer .deserialize (body_params ["data" ])
210
175
211
176
result = await collection .execute (caller , action_name , data , filter_ )
212
- return web .Response (
213
- text = aes_encrypt (json .dumps (ActionResultSerializer .serialize (result )), self .aes_key , self .aes_iv )
214
- )
177
+ return web .json_response (ActionResultSerializer .serialize (result )) # type:ignore
215
178
216
179
async def collection_render_chart (self , request : web .Request ):
217
180
body_params = await request .text ()
218
- body_params = aes_decrypt (body_params , self .aes_key , self .aes_iv )
219
181
body_params = json .loads (body_params )
220
182
221
183
ds = await self .customizer .get_datasource ()
@@ -244,11 +206,10 @@ async def collection_render_chart(self, request: web.Request):
244
206
ret .append (value )
245
207
246
208
result = await collection .render_chart (caller , name , record_id )
247
- return web .Response ( text = aes_encrypt ( json . dumps ( result ), self . aes_key , self . aes_iv ) )
209
+ return web .json_response ( result )
248
210
249
211
async def render_chart (self , request : web .Request ):
250
212
body_params = await request .text ()
251
- body_params = aes_decrypt (body_params , self .aes_key , self .aes_iv )
252
213
body_params = json .loads (body_params )
253
214
254
215
ds = await self .customizer .get_datasource ()
@@ -257,11 +218,10 @@ async def render_chart(self, request: web.Request):
257
218
name = body_params ["name" ]
258
219
259
220
result = await ds .render_chart (caller , name )
260
- return web .Response ( text = aes_encrypt ( json . dumps ( result ), self . aes_key , self . aes_iv ) )
221
+ return web .json_response ( result )
261
222
262
223
async def native_query (self , request : web .Request ):
263
224
body_params = await request .text ()
264
- body_params = aes_decrypt (body_params , self .aes_key , self .aes_iv )
265
225
body_params = json .loads (body_params )
266
226
267
227
ds = await self .customizer .get_datasource ()
@@ -270,7 +230,7 @@ async def native_query(self, request: web.Request):
270
230
parameters = body_params ["parameters" ]
271
231
272
232
result = await ds .execute_native_query (connection_name , native_query , parameters )
273
- return web .Response ( text = aes_encrypt ( json . dumps ( result ), self . aes_key , self . aes_iv ) )
233
+ return web .json_response ( result )
274
234
275
235
def start (self ):
276
236
web .run_app (self .app , host = self .listen_addr , port = int (self .listen_port ))
0 commit comments