@@ -1521,6 +1521,9 @@ def handle_streaming_diffs(
15211521
15221522 return data
15231523
1524+ def run_fn_batch (self , fn , batch , fn_index , state ):
1525+ return [fn (fn_index , list (i ), state ) for i in zip (* batch )]
1526+
15241527 async def process_api (
15251528 self ,
15261529 fn_index : int ,
@@ -1565,10 +1568,14 @@ async def process_api(
15651568 raise ValueError (
15661569 f"Batch size ({ batch_size } ) exceeds the max_batch_size for this function ({ max_batch_size } )"
15671570 )
1568-
1569- inputs = [
1570- self .preprocess_data (fn_index , list (i ), state ) for i in zip (* inputs )
1571- ]
1571+ inputs = await anyio .to_thread .run_sync (
1572+ self .run_fn_batch ,
1573+ self .preprocess_data ,
1574+ inputs ,
1575+ fn_index ,
1576+ state ,
1577+ limiter = self .limiter ,
1578+ )
15721579 result = await self .call_function (
15731580 fn_index ,
15741581 list (zip (* inputs )),
@@ -1579,17 +1586,24 @@ async def process_api(
15791586 in_event_listener ,
15801587 )
15811588 preds = result ["prediction" ]
1582- data = [
1583- self .postprocess_data (fn_index , list (o ), state ) for o in zip (* preds )
1584- ]
1589+ data = await anyio .to_thread .run_sync (
1590+ self .run_fn_batch ,
1591+ self .postprocess_data ,
1592+ preds ,
1593+ fn_index ,
1594+ state ,
1595+ limiter = self .limiter ,
1596+ )
15851597 data = list (zip (* data ))
15861598 is_generating , iterator = None , None
15871599 else :
15881600 old_iterator = iterator
15891601 if old_iterator :
15901602 inputs = []
15911603 else :
1592- inputs = self .preprocess_data (fn_index , inputs , state )
1604+ inputs = await anyio .to_thread .run_sync (
1605+ self .preprocess_data , fn_index , inputs , state , limiter = self .limiter
1606+ )
15931607 was_generating = old_iterator is not None
15941608 result = await self .call_function (
15951609 fn_index ,
@@ -1600,7 +1614,13 @@ async def process_api(
16001614 event_data ,
16011615 in_event_listener ,
16021616 )
1603- data = self .postprocess_data (fn_index , result ["prediction" ], state )
1617+ data = await anyio .to_thread .run_sync (
1618+ self .postprocess_data ,
1619+ fn_index , # type: ignore
1620+ result ["prediction" ],
1621+ state ,
1622+ limiter = self .limiter ,
1623+ )
16041624 is_generating , iterator = result ["is_generating" ], result ["iterator" ]
16051625 if is_generating or was_generating :
16061626 run = id (old_iterator ) if was_generating else id (iterator )
0 commit comments