alexandrehtrb/AlexandreHtrb.WebSocketExtensions

AlexandreHtrb.WebSocketExtensions

This project is a custom abstraction layer built on top of the standard System.Net.WebSockets implementations. It handles the WebSocket lifecycle and parses and converts messages to and from byte arrays. The abstractions cover both the client side and the server side (ASP.NET).

It is fully compatible with NativeAOT and trimming.

Installation

Add the NuGet package to the project file:

<ItemGroup>
    <PackageReference Include="AlexandreHtrb.WebSocketExtensions" Version="1.0.2" />
</ItemGroup>

How to use

It's simple to use. The WebSocketServerSideConnector and WebSocketClientSideConnector classes receive native .NET WebSocket objects and take care of the whole WebSocket lifecycle: connecting and disconnecting, sending and receiving messages, and converting them to and from byte arrays.

Each connector has an ExchangedMessagesCollector, which collects the sent and received messages and exposes them through an IAsyncEnumerable. The conversation between client and server therefore runs inside an asynchronous await foreach loop, and when either party disconnects, execution leaves the loop.

Before entering the conversation loop, one of the parties must send the first message; otherwise both sides just wait for each other.
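The collect-then-iterate model described above can be pictured with System.Threading.Channels. This is a minimal sketch of the pattern, not the library's code: treating the collector as a channel is an assumption, suggested only by the ReadAllAsync name, which ChannelReader<T> also uses. The producer stands in for the connector's receive loop, and completing the writer stands in for a disconnection.

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Channels;
using System.Threading.Tasks;

// Hypothetical stand-in for ExchangedMessagesCollector: an unbounded channel
// whose reader exposes the collected messages as an IAsyncEnumerable<string>.
Channel<string> collector = Channel.CreateUnbounded<string>();

// Producer: simulates the connector's receive loop pushing incoming messages,
// then completing the channel when the remote party disconnects.
Task receiveLoop = Task.Run(async () =>
{
    foreach (string incoming in new[] { "Hello!", "What time is it?", "Thanks!" })
    {
        await collector.Writer.WriteAsync(incoming);
    }
    collector.Writer.Complete(); // "disconnection": ends the await foreach below
});

// Consumer: the conversation loop. It exits once the channel is completed
// and every buffered message has been read.
List<string> received = new();
await foreach (string msg in collector.Reader.ReadAllAsync())
{
    received.Add(msg);
    Console.WriteLine($"Received: {msg}");
}

await receiveLoop;
Console.WriteLine($"Loop ended after {received.Count} messages.");
```

Because the loop only ends when the producer completes the channel, code placed after the await foreach runs exactly once, after the connection is gone.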

Example code, server

WebSocketServerSideConnector wsc = new(ws, collectOnlyClientSideMessages: true);

await foreach (var msg in wsc.ExchangedMessagesCollector!.ReadAllAsync())
{
    await wsc.SendMessageAsync(WebSocketMessageType.Text, msg.ReadAsUtf8Text() switch
    {
        "Hello!" => "Hi!",
        "What time is it?" => "Now it's " + DateTime.Now.ToString("HH:mm:ss"),
        "Thanks!" => "You're welcome!",
        _ => "I don't understand your message!"
    }, false);
}

Example code, client

using var cws = MakeClientWebSocket();
using var hc = MakeHttpClient(disableSslVerification: true);
Uri uri = new("ws://localhost:5000/test/http1websocket");
WebSocketClientSideConnector wsc = new(collectOnlyServerSideMessages: true);

// Connecting
await wsc.ConnectAsync(cws, hc, uri, cancellationToken);

// Sending first message
await wsc.SendMessageAsync(WebSocketMessageType.Text, "Hello!", false);

// Conversation loop
await foreach (var msg in wsc.ExchangedMessagesCollector!.ReadAllAsync())
{
    string? replyTxt = msg.ReadAsUtf8Text() switch
    {
        "Hi!" => "What time is it?",
        string s when s.StartsWith("Now it's") => "Thanks!",
        _ => null
    };
    
    if (replyTxt != null)
        await wsc.SendMessageAsync(WebSocketMessageType.Text, replyTxt, false);
}
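The client example calls MakeClientWebSocket and MakeHttpClient without showing them. Below is one possible shape for those helpers, built only on standard .NET types; the helper bodies and the 30-second keep-alive are assumptions, not the project's code. Accepting any server certificate is only acceptable for local testing.

```csharp
using System;
using System.Net.Http;
using System.Net.Security;
using System.Net.WebSockets;

// Hypothetical helper: a fresh ClientWebSocket per connection,
// since a ClientWebSocket instance can only be connected once.
static ClientWebSocket MakeClientWebSocket()
{
    ClientWebSocket cws = new();
    cws.Options.KeepAliveInterval = TimeSpan.FromSeconds(30); // assumed value
    return cws;
}

// Hypothetical helper: an HttpClient over SocketsHttpHandler.
// disableSslVerification accepts ANY server certificate: local testing only.
static HttpClient MakeHttpClient(bool disableSslVerification = false)
{
    SocketsHttpHandler handler = new();
    if (disableSslVerification)
    {
        handler.SslOptions = new SslClientAuthenticationOptions
        {
            RemoteCertificateValidationCallback = (_, _, _, _) => true
        };
    }
    return new HttpClient(handler);
}

using var cws = MakeClientWebSocket();
using var hc = MakeHttpClient(disableSslVerification: true);
Console.WriteLine(cws.State); // None: not connected yet
```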

With the code above, the conversation goes like this:

Client: Hello!
Server: Hi!
Client: What time is it?
Server: Now it's 11:54:53
Client: Thanks!
Server: You're welcome!

WebSockets configuration on ASP.NET

  1. Enable app.UseWebSockets():
private static IApplicationBuilder ConfigureApp(this WebApplication app) =>
    app.MapTestEndpoints()
       .UseWebSockets(new()
       {
           KeepAliveInterval = TimeSpan.FromMinutes(2)
       });
  2. Map the WebSocket endpoint:
public static WebApplication MapTestEndpoints(this WebApplication app)
{
    app.MapGet("test/http1websocket", TestHttp1WebSocket);
    return app;
}

private static async Task TestHttp1WebSocket(HttpContext httpCtx, ILogger<BackgroundWebSocketsProcessor> logger)
{
    if (!httpCtx.WebSockets.IsWebSocketRequest)
    {
        byte[] txtBytes = Encoding.UTF8.GetBytes("Only WebSockets requests are accepted here!");
        httpCtx.Response.StatusCode = (int)HttpStatusCode.BadRequest;
        await httpCtx.Response.BodyWriter.WriteAsync(txtBytes);
    }
    else
    {
        using var webSocket = await httpCtx.WebSockets.AcceptWebSocketAsync();
        TaskCompletionSource<object> socketFinishedTcs = new();

        await BackgroundWebSocketsProcessor.RegisterAndProcessAsync(logger, webSocket, socketFinishedTcs);
        await socketFinishedTcs.Task;
    }
}
  3. Create a WebSocket processor:
using AlexandreHtrb.WebSocketExtensions;
using System.Net.WebSockets;

// not static: a static class cannot be used as the type argument of ILogger<T>
public sealed class BackgroundWebSocketsProcessor
{
    public static async Task RegisterAndProcessAsync(ILogger<BackgroundWebSocketsProcessor> logger, WebSocket ws, TaskCompletionSource<object> socketFinishedTcs)
    {
        WebSocketServerSideConnector wsc = new(ws, collectOnlyClientSideMessages: true);

        int msgCount = 0;
        await foreach (var msg in wsc.ExchangedMessagesCollector!.ReadAllAsync())
        {
            msgCount++;
            string msgText = msg.Type switch
            {
                WebSocketMessageType.Text or WebSocketMessageType.Close => msg.ReadAsUtf8Text()!,
                WebSocketMessageType.Binary when msg.BytesStream is MemoryStream ms => $"(binary, {ms.Length} bytes)",
                WebSocketMessageType.Binary when msg.BytesStream is not MemoryStream => $"(binary, ? bytes)",
                _ => "(unknown)"
            };
            logger.LogInformation("Message {msgCount}, {direction}: {msgText}", msgCount, msg.Direction, msgText);
                
            // handle messages here
        }
        
        socketFinishedTcs.SetResult(true); // finish connection
    }
}

Tips and tricks

Monitor connection state

wsc.OnConnectionChanged = (state, exception) =>
{
    logger.LogInformation("Connection state: {state}", state);
    // only log an error when the state change was actually caused by one
    if (exception is not null)
        logger.LogError(exception, "Connection exception");
};

This callback is also a good place to trigger connection retries.
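A retry policy itself can be kept independent of the connector. This is a minimal sketch of exponential backoff using a hypothetical ConnectWithRetriesAsync helper; how it would be wired into OnConnectionChanged is not shown, because the possible state values aren't documented here. The connect delegate should build a new ClientWebSocket on every attempt, since a ClientWebSocket can't be reconnected once used.

```csharp
using System;
using System.Net.WebSockets;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical helper: retries an async connect operation with exponential
// backoff (baseDelay, 2x, 4x, ...). Returns the number of attempts used;
// the last failure (or a cancellation) propagates to the caller.
static async Task<int> ConnectWithRetriesAsync(
    Func<CancellationToken, Task> connectAsync,
    int maxAttempts,
    TimeSpan baseDelay,
    CancellationToken ct = default)
{
    for (int attempt = 1; ; attempt++)
    {
        try
        {
            await connectAsync(ct);
            return attempt;
        }
        catch (Exception ex) when (attempt < maxAttempts && ex is not OperationCanceledException)
        {
            TimeSpan delay = baseDelay * Math.Pow(2, attempt - 1);
            Console.WriteLine($"Attempt {attempt} failed ({ex.Message}); retrying in {delay.TotalMilliseconds} ms");
            await Task.Delay(delay, ct);
        }
    }
}

// Usage with a fake connection that fails twice, then succeeds.
int calls = 0;
int attempts = await ConnectWithRetriesAsync(
    _ => ++calls < 3 ? Task.FromException(new WebSocketException("refused")) : Task.CompletedTask,
    maxAttempts: 5,
    baseDelay: TimeSpan.FromMilliseconds(10));
Console.WriteLine($"Connected after {attempts} attempts");
```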

Collect sent messages

WebSocketServerSideConnector wsc = new(ws, collectOnlyClientSideMessages: false);

WebSocketClientSideConnector wsc = new(collectOnlyServerSideMessages: false);

These booleans control whether only the messages from the opposite side are collected. Setting them to false also collects your own side's messages, which can be useful for logging.

Periodically send a message

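_ = Task.Run(async () =>
{
    // PeriodicTimer (.NET 6+) ticks every 15 seconds without spawning new tasks
    using PeriodicTimer timer = new(TimeSpan.FromSeconds(15));
    try
    {
        while (await timer.WaitForNextTickAsync(cancellationToken))
            await wsc.SendMessageAsync(WebSocketMessageType.Text, "Are you there?", false);
    }
    catch (OperationCanceledException)
    {
        // cancellation requested: stop sending
    }
});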

End conversation after a certain amount of time

_ = Task.Run(async () =>
{
    await Task.Delay(maximumLifetimePeriod);
    await wsc.DisconnectAsync();
});

Sending files

// don't use 'using': the connector disposes the stream after sending it
FileStream fs = new(@"C:\Files\my_image.jpg", FileMode.Open, FileAccess.Read);
await wsc.SendMessageAsync(WebSocketMessageType.Binary, fs, false);

When sending a Stream, don't wrap it in a using statement: the connector disposes of the Stream itself after sending it.
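To see why the sender, not the caller, should own the stream, here is a sketch of how sending a Stream as one binary message can work: read fixed-size chunks, flag the last chunk as end-of-message, and dispose the stream afterwards. SendStreamAsync, the sendFrameAsync callback (a stand-in for WebSocket.SendAsync) and the 4096-byte chunk size are all assumptions for illustration, not the connector's actual internals.

```csharp
using System;
using System.Collections.Generic;
using System.IO;
using System.Threading.Tasks;

// Hypothetical sketch: reads the stream in chunks, flags the last chunk as
// end-of-message, then disposes the stream (so callers must not 'using' it).
static async Task SendStreamAsync(Stream stream, Func<ReadOnlyMemory<byte>, bool, Task> sendFrameAsync, int chunkSize = 4096)
{
    await using (stream)
    {
        byte[] current = new byte[chunkSize];
        byte[] next = new byte[chunkSize];
        int currentLen = await stream.ReadAsync(current);
        while (true)
        {
            // Look one chunk ahead to know whether 'current' is the last one.
            int nextLen = currentLen == 0 ? 0 : await stream.ReadAsync(next);
            bool endOfMessage = nextLen == 0;
            await sendFrameAsync(current.AsMemory(0, currentLen), endOfMessage);
            if (endOfMessage) break;
            (current, next) = (next, current);
            currentLen = nextLen;
        }
    }
}

// Usage: a 10,000-byte payload and a fake sender that records each frame.
List<(int Length, bool End)> frames = new();
var data = new MemoryStream(new byte[10_000]); // no 'using': ownership moves to SendStreamAsync
await SendStreamAsync(data, (chunk, end) =>
{
    frames.Add((chunk.Length, end));
    return Task.CompletedTask;
});
foreach (var f in frames) Console.WriteLine($"{f.Length} bytes, endOfMessage={f.End}");
```

The frame callback receives a view over a reused buffer, so a real sender must finish with each chunk before returning, which awaiting WebSocket.SendAsync does.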

Retrieve HTTP status code and response headers

ClientWebSocket cws = new();
cws.Options.CollectHttpResponseDetails = true;

await wsc.ConnectAsync(cws, hc, uri, cancellationToken);

var wsHttpStatusCode = wsc.ConnectionHttpStatusCode;
var wsResponseHeaders = wsc.ConnectionHttpHeaders;

Authentication and request headers

ClientWebSocket cws = new();
cws.Options.SetRequestHeader("Authorization", "Bearer my_token");
cws.Options.SetRequestHeader("Header1", "Value1");

await wsc.ConnectAsync(cws, hc, uri, cancellationToken);

Subprotocols

Client-side

ClientWebSocket cws = new();
cws.Options.AddSubProtocol("subprotocol1");

await wsc.ConnectAsync(cws, hc, uri, cancellationToken);

Server-side

private static async Task TestHttp1WebSocket(HttpContext httpCtx, ILogger<BackgroundWebSocketsProcessor> logger)
{
    if (!httpCtx.WebSockets.IsWebSocketRequest)
    {
        // ...
    }
    else
    {
        string? subprotocol = httpCtx.WebSockets.WebSocketRequestedProtocols.FirstOrDefault();
        // passing the subprotocol to AcceptWebSocketAsync confirms it in the handshake response
        using var webSocket = await httpCtx.WebSockets.AcceptWebSocketAsync(subprotocol);
        TaskCompletionSource<object> socketFinishedTcs = new();

        await BackgroundWebSocketsProcessor.RegisterAndProcessAsync(logger, webSocket, subprotocol, socketFinishedTcs);
        await socketFinishedTcs.Task;
    }
}
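Taking the first requested subprotocol blindly accepts whatever the client asks for. A stricter server can pick the first protocol, in the client's order of preference, that it actually supports, and pass the result (or null) to httpCtx.WebSockets.AcceptWebSocketAsync. SelectSubProtocol and the supported list below are hypothetical, and the ordinal (case-sensitive) comparison is a choice made for this sketch.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical helper: the first requested subprotocol that the server also
// supports; null means "accept the connection without a subprotocol".
static string? SelectSubProtocol(IEnumerable<string> requested, IReadOnlyCollection<string> serverSupported) =>
    requested.FirstOrDefault(p => serverSupported.Contains(p, StringComparer.Ordinal));

string[] supportedByServer = { "subprotocol1", "subprotocol2" };

string? chosen = SelectSubProtocol(new[] { "other", "subprotocol2", "subprotocol1" }, supportedByServer);
Console.WriteLine(chosen ?? "(none)"); // subprotocol2

string? noMatch = SelectSubProtocol(new[] { "other" }, supportedByServer);
Console.WriteLine(noMatch ?? "(none)"); // (none)
```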

Message compression

ClientWebSocket cws = new();
cws.Options.DangerousDeflateOptions = new()
{
    ClientContextTakeover = true,
    ClientMaxWindowBits = 14,
    ServerContextTakeover = true,
    ServerMaxWindowBits = 14
};

await wsc.ConnectAsync(cws, hc, uri, cancellationToken);

Important: Don't send secrets or encrypted text in compressed messages, because compression exposes them to CRIME- and BREACH-style attacks. For such messages, disable compression:

await wsc.SendMessageAsync(
    WebSocketMessageType.Text,
    $"Encrypted token {token}",
    disableCompression: true);
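The risk comes from compressed size acting as an oracle. When a secret and attacker-influenced text are compressed together, a correct guess repeats the secret and compresses into a short back-reference, so the output gets smaller. This sketch uses .NET's DeflateStream to show the effect; the token value and CompressedLength helper are made up, and the exact byte counts depend on the DEFLATE implementation. With context takeover enabled, as in the options above, the compression window is shared across messages, so the secret and the guess do not even need to be in the same message.

```csharp
using System;
using System.IO;
using System.IO.Compression;
using System.Text;

// Size of a payload after raw DEFLATE compression.
static int CompressedLength(string text)
{
    using MemoryStream output = new();
    using (DeflateStream deflate = new(output, CompressionLevel.Optimal, leaveOpen: true))
    {
        byte[] bytes = Encoding.UTF8.GetBytes(text);
        deflate.Write(bytes, 0, bytes.Length);
    } // disposing the DeflateStream flushes the final compressed block
    return (int)output.Length;
}

// A (made-up) secret compressed together with attacker-controlled text.
const string secret = "token=7f3a9c1e5b2d";
int rightGuess = CompressedLength(secret + "&echo=token=7f3a9c1e5b2d");
int wrongGuess = CompressedLength(secret + "&echo=token=q8w4z0x6v2m1");
Console.WriteLine($"right guess: {rightGuess} bytes, wrong guess: {wrongGuess} bytes");
```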

WebSockets over HTTP/2

Client-side

ClientWebSocket cws = new();
cws.Options.HttpVersionPolicy = HttpVersionPolicy.RequestVersionExact;
cws.Options.HttpVersion = new(2,0);

await wsc.ConnectAsync(cws, hc, uri, cancellationToken);

Server-side

WebSockets over HTTP/2 use the CONNECT method (extended CONNECT, RFC 8441) instead of GET, so the endpoint must be mapped for CONNECT:

public static WebApplication MapTestEndpoints(this WebApplication app)
{
    app.MapMethods("test/http2websocket", new[] { HttpMethods.Connect }, TestHttp2WebSocket);
    return app;
}