Skip to content

Commit 6d46ddb

Browse files
feat: cache NodeJS proxied requests (#12558)
* NodeProxyCache * add changeset * Last touch-ups * add changeset * Move NodeProxyCache to route_utils + last touch-ups * Fix race condition * Actually safer * Immutable: no need dataclass --------- Co-authored-by: gradio-pr-bot <[email protected]>
1 parent f1d83fa commit 6d46ddb

File tree

3 files changed

+85
-9
lines changed

3 files changed

+85
-9
lines changed

.changeset/floppy-dingos-act.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"gradio": patch
3+
---
4+
5+
fix:feat: cache NodeJS proxied requests

gradio/route_utils.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
TYPE_CHECKING,
2424
Any,
2525
BinaryIO,
26+
NamedTuple,
2627
Union,
2728
)
2829
from urllib.parse import urlparse
@@ -1057,3 +1058,71 @@ def slugify(value):
10571058
)
10581059
value = re.sub(r"[^\w\s-]", "", value.lower())
10591060
return re.sub(r"[-\s]+", "-", value).strip("-_")
1061+
1062+
1063+
class NodeProxyCache:
    """
    Fan-out streaming cache for NodeJS requests proxying.

    Requests that map to the same cache key while an upstream request is
    still in flight share that single upstream request: every caller gets
    the bytes received so far followed by the remaining chunks as they
    arrive. An entry is removed from the cache as soon as its upstream
    stream ends, so only in-flight responses are shared.
    """

    class CacheEntry(NamedTuple):
        # Raw body bytes received so far; replayed to callers that join late.
        head: bytearray
        # One queue per caller; a ``None`` item marks the end of the stream.
        subs: list[asyncio.Queue[bytes | None]]
        # Resolves to the upstream response once it is available, or to
        # ``None`` when the upstream request could not be sent.
        resp: asyncio.Future[httpx.Response | None]

    class ProxyReq(NamedTuple):
        method: str
        url: str
        headers: dict[str, str]

    def __init__(self, client: httpx.AsyncClient):
        self.client = client
        self.cache: dict[str, NodeProxyCache.CacheEntry | None] = {}
        # The event loop only keeps weak references to tasks, so in-flight
        # fetch tasks are held here until they complete.
        self._tasks: set[asyncio.Task[None]] = set()

    @staticmethod
    def _cache_key(req: ProxyReq) -> str:
        """Build an unambiguous, header-order-independent key for ``req``."""
        # One component per line: the request line first, then the sorted
        # headers. The delimiter keeps "url + header" from colliding with a
        # longer URL that has no headers.
        parts = [f"{req.method} {req.url}"]
        parts.extend(f"{name}:{value}" for name, value in sorted(req.headers.items()))
        return "\n".join(parts)

    def _evict(self, key: str, entry: CacheEntry) -> None:
        """Remove ``entry`` from the cache if it is still the one stored under ``key``."""
        if self.cache.get(key) is entry:
            del self.cache[key]

    async def get(self, req: ProxyReq):
        """
        Return ``(status_code, headers, body_iterator)`` for ``req``.

        Starts a new upstream request if none is in flight for the same key,
        otherwise joins the existing one.

        Raises:
            Error: if the upstream request could not be sent.
        """
        key = self._cache_key(req)
        res: asyncio.Queue[bytes | None] = asyncio.Queue()
        entry = self.cache.get(key)
        if entry is None:
            loop = asyncio.get_running_loop()
            entry = NodeProxyCache.CacheEntry(bytearray(), [], loop.create_future())
            self.cache[key] = entry
            task = asyncio.create_task(self.fetch(key, entry, req))
            self._tasks.add(task)
            task.add_done_callback(self._tasks.discard)
        # Subscribe and snapshot with no ``await`` in between, so every chunk
        # is seen exactly once: either in ``head`` or through the queue.
        entry.subs.append(res)
        head = bytes(entry.head)
        try:
            # Shield the shared future: cancelling this caller must not cancel
            # the future that other callers and ``fetch`` depend on.
            resp = await asyncio.shield(entry.resp)
        except BaseException:
            # This caller is gone; stop queueing chunks for it.
            if res in entry.subs:
                entry.subs.remove(res)
            raise
        if resp is None:
            raise Error("Error while proxying request to Node server")
        return resp.status_code, resp.headers, NodeProxyCache.iter_body(head, res)

    async def fetch(self, key: str, entry: CacheEntry, req: ProxyReq):
        """
        Send ``req`` upstream and fan the streamed body out to ``entry.subs``.

        Resolves ``entry.resp``, appends every chunk to ``entry.head`` and to
        each subscriber queue, then signals end of stream with ``None``,
        evicts the entry and closes the response.
        """
        try:
            response = await self.client.send(
                self.client.build_request(
                    method=req.method,
                    url=httpx.URL(req.url),
                    headers=req.headers,
                ),
                stream=True,
            )
        except BaseException:
            # Also covers cancellation: waiters must be released and the
            # entry evicted, otherwise later callers would join a dead entry.
            if not entry.resp.done():
                entry.resp.set_result(None)
            self._evict(key, entry)
            raise
        try:
            if not entry.resp.done():
                entry.resp.set_result(response)
            async for bytes_chunk in response.aiter_raw():
                entry.head.extend(bytes_chunk)
                for sub in entry.subs:
                    sub.put_nowait(bytes_chunk)
        finally:
            # Runs on success and on failure so that subscribers never block
            # forever and the upstream response is always closed.
            for sub in entry.subs:
                sub.put_nowait(None)
            self._evict(key, entry)
            await response.aclose()

    @staticmethod
    async def iter_body(head: bytes, queue: asyncio.Queue[bytes | None]):
        """Yield ``head`` (if non-empty), then queued chunks until a ``None`` item."""
        if head:
            yield head
        while (chunk := await queue.get()) is not None:
            yield chunk

gradio/routes.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@
102102
GradioMultiPartParser,
103103
GradioUploadFile,
104104
MultiPartException,
105+
NodeProxyCache,
105106
Request,
106107
compare_passwords_securely,
107108
create_lifespan_handler,
@@ -281,6 +282,7 @@ def __init__(
281282
# We're not overriding any defaults here
282283

283284
client = httpx.AsyncClient()
285+
proxy_cache = NodeProxyCache(client)
284286

285287
@staticmethod
286288
async def proxy_to_node(
@@ -305,23 +307,23 @@ async def proxy_to_node(
305307
if mounted_path:
306308
server_url += mounted_path
307309

308-
headers = dict(request.headers)
310+
headers = {} # Do not include arbitrary headers from original request so NodeProxyCache can be effective
309311
headers["x-gradio-server"] = server_url
310312
headers["x-gradio-port"] = str(python_port)
311313

312314
if os.getenv("GRADIO_LOCAL_DEV_MODE"):
313315
headers["x-gradio-local-dev-mode"] = "1"
314316

315-
new_request = App.client.build_request(
316-
request.method, httpx.URL(url), headers=headers
317-
)
318-
node_response = await App.client.send(new_request, stream=True)
317+
if (accept_language := request.headers.get("accept-language")) is not None:
318+
headers["accept-language"] = accept_language
319+
320+
proxy_req = App.proxy_cache.ProxyReq(request.method, url, headers)
321+
status, response_headers, aiter_raw = await App.proxy_cache.get(proxy_req)
319322

320323
return StreamingResponse(
321-
node_response.aiter_raw(),
322-
status_code=node_response.status_code,
323-
headers=node_response.headers,
324-
background=BackgroundTask(node_response.aclose),
324+
aiter_raw,
325+
status_code=status,
326+
headers=response_headers,
325327
)
326328

327329
def configure_app(self, blocks: gradio.Blocks) -> None:

0 commit comments

Comments
 (0)