Skip to content

Commit fb1f6be

Browse files
Run pre/post processing in threadpool (#7327)
* Add code * Add code * Add code * add changeset --------- Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
1 parent 7b84bc4 commit fb1f6be

4 files changed

Lines changed: 43 additions & 13 deletions

File tree

.changeset/hot-taxis-jump.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"gradio": patch
3+
---
4+
5+
fix:Run pre/post processing in threadpool

gradio/blocks.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1521,6 +1521,9 @@ def handle_streaming_diffs(
15211521

15221522
return data
15231523

1524+
def run_fn_batch(self, fn, batch, fn_index, state):
1525+
return [fn(fn_index, list(i), state) for i in zip(*batch)]
1526+
15241527
async def process_api(
15251528
self,
15261529
fn_index: int,
@@ -1565,10 +1568,14 @@ async def process_api(
15651568
raise ValueError(
15661569
f"Batch size ({batch_size}) exceeds the max_batch_size for this function ({max_batch_size})"
15671570
)
1568-
1569-
inputs = [
1570-
self.preprocess_data(fn_index, list(i), state) for i in zip(*inputs)
1571-
]
1571+
inputs = await anyio.to_thread.run_sync(
1572+
self.run_fn_batch,
1573+
self.preprocess_data,
1574+
inputs,
1575+
fn_index,
1576+
state,
1577+
limiter=self.limiter,
1578+
)
15721579
result = await self.call_function(
15731580
fn_index,
15741581
list(zip(*inputs)),
@@ -1579,17 +1586,24 @@ async def process_api(
15791586
in_event_listener,
15801587
)
15811588
preds = result["prediction"]
1582-
data = [
1583-
self.postprocess_data(fn_index, list(o), state) for o in zip(*preds)
1584-
]
1589+
data = await anyio.to_thread.run_sync(
1590+
self.run_fn_batch,
1591+
self.postprocess_data,
1592+
preds,
1593+
fn_index,
1594+
state,
1595+
limiter=self.limiter,
1596+
)
15851597
data = list(zip(*data))
15861598
is_generating, iterator = None, None
15871599
else:
15881600
old_iterator = iterator
15891601
if old_iterator:
15901602
inputs = []
15911603
else:
1592-
inputs = self.preprocess_data(fn_index, inputs, state)
1604+
inputs = await anyio.to_thread.run_sync(
1605+
self.preprocess_data, fn_index, inputs, state, limiter=self.limiter
1606+
)
15931607
was_generating = old_iterator is not None
15941608
result = await self.call_function(
15951609
fn_index,
@@ -1600,7 +1614,13 @@ async def process_api(
16001614
event_data,
16011615
in_event_listener,
16021616
)
1603-
data = self.postprocess_data(fn_index, result["prediction"], state)
1617+
data = await anyio.to_thread.run_sync(
1618+
self.postprocess_data,
1619+
fn_index, # type: ignore
1620+
result["prediction"],
1621+
state,
1622+
limiter=self.limiter,
1623+
)
16041624
is_generating, iterator = result["is_generating"], result["iterator"]
16051625
if is_generating or was_generating:
16061626
run = id(old_iterator) if was_generating else id(iterator)

gradio/components/gallery.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
from concurrent.futures import ThreadPoolExecutor
56
from pathlib import Path
67
from typing import Any, Callable, List, Literal, Optional, Tuple, Union
78
from urllib.parse import urlparse
@@ -165,7 +166,8 @@ def postprocess(
165166
if value is None:
166167
return GalleryData(root=[])
167168
output = []
168-
for img in value:
169+
170+
def _save(img):
169171
url = None
170172
caption = None
171173
orig_name = None
@@ -194,11 +196,14 @@ def postprocess(
194196
orig_name = img.name
195197
else:
196198
raise ValueError(f"Cannot process type as image: {type(img)}")
197-
entry = GalleryImage(
199+
return GalleryImage(
198200
image=FileData(path=file_path, url=url, orig_name=orig_name),
199201
caption=caption,
200202
)
201-
output.append(entry)
203+
204+
with ThreadPoolExecutor() as executor:
205+
for o in executor.map(_save, value):
206+
output.append(o)
202207
return GalleryData(root=output)
203208

204209
@staticmethod

gradio/processing_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def save_pil_to_cache(
135135
temp_dir = Path(cache_dir) / hash_bytes(bytes_data)
136136
temp_dir.mkdir(exist_ok=True, parents=True)
137137
filename = str((temp_dir / f"{name}.{format}").resolve())
138-
img.save(filename, pnginfo=get_pil_metadata(img))
138+
(temp_dir / f"{name}.{format}").resolve().write_bytes(bytes_data)
139139
return filename
140140

141141

0 commit comments

Comments
 (0)