Skip to content

Add support for websocket proxy #701

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions examples/nginx.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
daemon off;

events {
worker_connections 1024;
}

http {
access_log /dev/stdout;
error_log /dev/stderr;

upstream backend {
server 127.0.0.1:8080; # Your WebSocket server
}

# Map to handle WebSocket upgrade
map $http_upgrade $connection_upgrade {
default upgrade;
'' close;
}

server {
listen 8888;
server_name localhost;

# Handle both HTTP and WebSocket at root
location / {
proxy_pass http://backend;
proxy_http_version 1.1;

# WebSocket headers
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection $connection_upgrade;

# Standard proxy headers
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;

# Optional: Increase timeouts for long-lived WebSocket connections
proxy_read_timeout 86400;
}
}
}
80 changes: 80 additions & 0 deletions examples/websocket_proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#!/usr/bin/env python3

"""
NATS WebSocket proxy example.

Usage:
python websocket_proxy_example.py [--proxy http://localhost:8888] [--proxy-user user] [--proxy-password pass]
"""

import asyncio
import argparse
import nats


async def main():
"""Connect to NATS WebSocket server with optional proxy"""

# Parse command line arguments
parser = argparse.ArgumentParser(description="NATS WebSocket Proxy Example")
parser.add_argument("--server", default="ws://localhost:8080",
help="NATS WebSocket server URL (default: ws://localhost:8080)")
parser.add_argument("--proxy", help="HTTP proxy URL (e.g., http://localhost:8888)")
parser.add_argument("--proxy-user", help="Proxy username for authentication")
parser.add_argument("--proxy-password", help="Proxy password for authentication")

args = parser.parse_args()

# Build connection options
connect_options = {
"servers": [args.server]
}

if args.proxy:
connect_options["proxy"] = args.proxy
print(f"Proxy: {args.proxy}")

if args.proxy_user and args.proxy_password:
connect_options["proxy_user"] = args.proxy_user
connect_options["proxy_password"] = args.proxy_password
print(f"Auth: {args.proxy_user}")
elif args.proxy_user or args.proxy_password:
print("Error: Both user and password required")
return 1

try:
print(f"Connecting to {args.server}...")
nc = await nats.connect(**connect_options)
print("Connected")

# Test pub/sub
messages = []

async def handler(msg):
messages.append(msg.data.decode())

await nc.subscribe("test", cb=handler)
await nc.flush()

# Send test message
for i in range(10):
await nc.publish("test", f"hello {i}".encode())
await nc.flush()
await asyncio.sleep(0.1)

expected = [f"hello {i}" for i in range(10)]
assert messages == expected, f"Expected {expected}, got {messages}"
print(f"Success: All {len(messages)} messages received correctly")

await nc.close()

except Exception as e:
print(f"Failed: {e}")
return 1

return 0


if __name__ == "__main__":
exit_code = asyncio.run(main())
exit(exit_code)
15 changes: 14 additions & 1 deletion nats/aio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,9 @@ async def connect(
inbox_prefix: Union[str, bytes] = DEFAULT_INBOX_PREFIX,
pending_size: int = DEFAULT_PENDING_SIZE,
flush_timeout: Optional[float] = None,
proxy: Optional[str] = None,
proxy_user: Optional[str] = None,
proxy_password: Optional[str] = None,
) -> None:
"""
Establishes a connection to NATS.
Expand All @@ -370,6 +373,9 @@ async def connect(
:param discovered_server_cb: Callback to report when a new server joins the cluster.
:param pending_size: Max size of the pending buffer for publishing commands.
:param flush_timeout: Max duration to wait for a forced flush to occur.
:param proxy: Proxy URL for WebSocket connections (e.g., 'http://proxy.example.com:8080')
:param proxy_user: Username for proxy authentication
:param proxy_password: Password for proxy authentication

Connecting setting all callbacks::

Expand Down Expand Up @@ -495,6 +501,9 @@ async def subscribe_handler(msg):
self.options["connect_timeout"] = connect_timeout
self.options["drain_timeout"] = drain_timeout
self.options["tls_handshake_first"] = tls_handshake_first
self.options["proxy"] = proxy
self.options["proxy_user"] = proxy_user
self.options["proxy_password"] = proxy_password

if tls:
self.options["tls"] = tls
Expand Down Expand Up @@ -1380,7 +1389,11 @@ async def _select_next_server(self) -> None:
s.last_attempt = time.monotonic()
if not self._transport:
if s.uri.scheme in ("ws", "wss"):
self._transport = WebSocketTransport()
self._transport = WebSocketTransport(
proxy=self.options.get("proxy"),
proxy_user=self.options.get("proxy_user"),
proxy_password=self.options.get("proxy_password"),
)
else:
# use TcpTransport as a fallback
self._transport = TcpTransport()
Expand Down
33 changes: 26 additions & 7 deletions nats/aio/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@

try:
import aiohttp
from aiohttp import BasicAuth
except ImportError:
aiohttp = None # type: ignore[assignment]
BasicAuth = None # type: ignore[assignment]

from nats.errors import ProtocolError

Expand Down Expand Up @@ -192,12 +194,21 @@ def __bool__(self):

class WebSocketTransport(Transport):

def __init__(self):
def __init__(
self,
proxy: Optional[str] = None,
proxy_user: Optional[str] = None,
proxy_password: Optional[str] = None
):
if not aiohttp:
raise ImportError(
"Could not import aiohttp transport, please install it with `pip install aiohttp`"
)
self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
self._proxy = proxy
self._proxy_auth = None
if proxy_user and proxy_password:
self._proxy_auth = BasicAuth(proxy_user, proxy_password)
self._client: aiohttp.ClientSession = aiohttp.ClientSession()
self._pending = asyncio.Queue()
self._close_task = asyncio.Future()
Expand All @@ -207,9 +218,13 @@ async def connect(
self, uri: ParseResult, buffer_size: int, connect_timeout: int
):
# for websocket library, the uri must contain the scheme already
self._ws = await self._client.ws_connect(
uri.geturl(), timeout=connect_timeout
)
kwargs = {"timeout": connect_timeout}
if self._proxy:
kwargs["proxy"] = self._proxy
if self._proxy_auth:
kwargs["proxy_auth"] = self._proxy_auth

self._ws = await self._client.ws_connect(uri.geturl(), **kwargs)
self._using_tls = False

async def connect_tls(
Expand All @@ -224,10 +239,14 @@ async def connect_tls(
return
raise ProtocolError("ws: cannot upgrade to TLS")

kwargs = {"ssl": ssl_context, "timeout": connect_timeout}
if self._proxy:
kwargs["proxy"] = self._proxy
if self._proxy_auth:
kwargs["proxy_auth"] = self._proxy_auth

self._ws = await self._client.ws_connect(
uri if isinstance(uri, str) else uri.geturl(),
ssl=ssl_context,
timeout=connect_timeout,
uri if isinstance(uri, str) else uri.geturl(), **kwargs
)
self._using_tls = True

Expand Down
Loading