diff --git a/src/Components/Endpoints/src/Microsoft.AspNetCore.Components.Endpoints.csproj b/src/Components/Endpoints/src/Microsoft.AspNetCore.Components.Endpoints.csproj index 4fa9814ea77e..aa481325046c 100644 --- a/src/Components/Endpoints/src/Microsoft.AspNetCore.Components.Endpoints.csproj +++ b/src/Components/Endpoints/src/Microsoft.AspNetCore.Components.Endpoints.csproj @@ -23,6 +23,8 @@ + + diff --git a/src/Components/Server/src/CircuitOptions.cs b/src/Components/Server/src/CircuitOptions.cs index ec5443f01c9d..7c63f0ef361e 100644 --- a/src/Components/Server/src/CircuitOptions.cs +++ b/src/Components/Server/src/CircuitOptions.cs @@ -44,6 +44,22 @@ public sealed class CircuitOptions /// public TimeSpan DisconnectedCircuitRetentionPeriod { get; set; } = TimeSpan.FromMinutes(3); + /// + /// Gets or sets a value that determines the maximum number of persisted circuits state that + /// are retained in memory by the server when no distributed cache is configured. + /// + /// + /// When using a distributed cache like this value is ignored + /// and the configuration from + /// is used instead. + /// + public int PersistedCircuitInMemoryMaxRetained { get; set; } = 1000; + + /// + /// Gets or sets the duration for which a persisted circuit is retained in memory. + /// + public TimeSpan PersistedCircuitInMemoryRetentionPeriod { get; set; } = TimeSpan.FromHours(2); + /// /// Gets or sets a value that determines whether or not to send detailed exception messages to JavaScript when an unhandled exception /// happens on the circuit or when a .NET method invocation through JS interop results in an exception. diff --git a/src/Components/Server/src/Circuits/CircuitHost.cs b/src/Components/Server/src/Circuits/CircuitHost.cs index 38b50461ce3e..3f47d11a8f3a 100644 --- a/src/Components/Server/src/Circuits/CircuitHost.cs +++ b/src/Components/Server/src/Circuits/CircuitHost.cs @@ -32,6 +32,7 @@ internal partial class CircuitHost : IAsyncDisposable private bool _isFirstUpdate = true; private bool _disposed; private long _startTime; + private PersistedCircuitState _persistedCircuitState; // This event is fired when there's an unrecoverable exception coming from the circuit, and // it need so be torn down. The registry listens to this even so that the circuit can @@ -106,6 +107,8 @@ public CircuitHost( public IServiceProvider Services { get; } + internal bool HasPendingPersistedCircuitState => _persistedCircuitState != null; + // InitializeAsync is used in a fire-and-forget context, so it's responsible for its own // error handling. public Task InitializeAsync(ProtectedPrerenderComponentApplicationStore store, ActivityContext httpContext, CancellationToken cancellationToken) @@ -873,6 +876,23 @@ await HandleInboundActivityAsync(() => } } + internal void AttachPersistedState(PersistedCircuitState persistedCircuitState) + { + if (_persistedCircuitState != null) + { + throw new InvalidOperationException("Persisted state has already been attached to this circuit."); + } + + _persistedCircuitState = persistedCircuitState; + } + + internal PersistedCircuitState TakePersistedCircuitState() + { + var result = _persistedCircuitState; + _persistedCircuitState = null; + return result; + } + private static partial class Log { // 100s used for lifecycle stuff diff --git a/src/Components/Server/src/Circuits/CircuitPersistenceManager.cs b/src/Components/Server/src/Circuits/CircuitPersistenceManager.cs new file mode 100644 index 000000000000..df1d0c7f73ca --- /dev/null +++ b/src/Components/Server/src/Circuits/CircuitPersistenceManager.cs @@ -0,0 +1,172 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Text.Json; +using System.Text.Json.Serialization; +using Microsoft.AspNetCore.Components.Endpoints; +using Microsoft.AspNetCore.Components.Infrastructure; +using Microsoft.AspNetCore.Components.Web; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; + +namespace Microsoft.AspNetCore.Components.Server.Circuits; + +internal partial class CircuitPersistenceManager( + IOptions circuitOptions, + ServerComponentSerializer serverComponentSerializer, + ICircuitPersistenceProvider circuitPersistenceProvider) +{ + public async Task PauseCircuitAsync(CircuitHost circuit, CancellationToken cancellation = default) + { + var renderer = circuit.Renderer; + var persistenceManager = circuit.Services.GetRequiredService(); + var collector = new CircuitPersistenceManagerCollector(circuitOptions, serverComponentSerializer, circuit.Renderer); + using var subscription = persistenceManager.State.RegisterOnPersisting( + collector.PersistRootComponents, + RenderMode.InteractiveServer); + + await persistenceManager.PersistStateAsync(collector, renderer); + + await circuitPersistenceProvider.PersistCircuitAsync( + circuit.CircuitId, + collector.PersistedCircuitState, + cancellation); + } + + public async Task ResumeCircuitAsync(CircuitId circuitId, CancellationToken cancellation = default) + { + return await circuitPersistenceProvider.RestoreCircuitAsync(circuitId, cancellation); + } + + // We are going to construct a RootComponentOperationBatch but we are going to replace the descriptors from the client with the + // descriptors that we have persisted when pausing the circuit. + // The way pausing and resuming works is that when the client starts the resume process, it 'simulates' that an SSR has happened and + // queues an 'Add' operation for each server-side component that is on the document. + // That ends up calling UpdateRootComponents with the old descriptors and no application state. + // On the server side, we replace the descriptors with the ones that we have persisted. We can't use the original descriptors because + // those have a lifetime of ~ 5 minutes, after which we are not able to unprotect them anymore. + internal static RootComponentOperationBatch ToRootComponentOperationBatch( + IServerComponentDeserializer serverComponentDeserializer, + byte[] rootComponents, + string serializedComponentOperations) + { + // Deserialize the existing batch the client has sent but ignore the markers + if (!serverComponentDeserializer.TryDeserializeRootComponentOperations( + serializedComponentOperations, + out var batch, + deserializeDescriptors: false)) + { + return null; + } + + var persistedMarkers = TryDeserializeMarkers(rootComponents); + + if (persistedMarkers == null) + { + return null; + } + + if (batch.Operations.Length != persistedMarkers.Count) + { + return null; + } + + // Ensure that all operations in the batch are `Add` operations. + for (var i = 0; i < batch.Operations.Length; i++) + { + var operation = batch.Operations[i]; + if (operation.Type != RootComponentOperationType.Add) + { + return null; + } + + // Retrieve the marker from the persisted root components, replace it and deserialize the descriptor + if (!persistedMarkers.TryGetValue(operation.SsrComponentId, out var marker)) + { + return null; + } + operation.Marker = marker; + + if (!serverComponentDeserializer.TryDeserializeWebRootComponentDescriptor(operation.Marker.Value, out var descriptor)) + { + return null; + } + + operation.Descriptor = descriptor; + } + + return batch; + + static Dictionary TryDeserializeMarkers(byte[] rootComponents) + { + if (rootComponents == null || rootComponents.Length == 0) + { + return null; + } + + try + { + return JsonSerializer.Deserialize>( + rootComponents, + JsonSerializerOptionsProvider.Options); + } + catch + { + return null; + } + } + } + + private class CircuitPersistenceManagerCollector( + IOptions circuitOptions, + ServerComponentSerializer serverComponentSerializer, + RemoteRenderer renderer) + : IPersistentComponentStateStore + { + internal PersistedCircuitState PersistedCircuitState { get; private set; } + + public Task PersistRootComponents() + { + var persistedComponents = new Dictionary(); + var components = renderer.GetOrCreateWebRootComponentManager().GetRootComponents(); + var invocation = new ServerComponentInvocationSequence(); + foreach (var (id, componentKey, (componentType, parameters)) in components) + { + var distributedRetention = circuitOptions.Value.PersistedCircuitInMemoryRetentionPeriod; + var localRetention = circuitOptions.Value.PersistedCircuitInMemoryRetentionPeriod; + var maxRetention = distributedRetention > localRetention ? distributedRetention : localRetention; + + var marker = ComponentMarker.Create(ComponentMarker.ServerMarkerType, prerendered: false, componentKey); + serverComponentSerializer.SerializeInvocation(ref marker, invocation, componentType, parameters, maxRetention); + persistedComponents.Add(id, marker); + } + + PersistedCircuitState = new PersistedCircuitState + { + RootComponents = JsonSerializer.SerializeToUtf8Bytes( + persistedComponents, + CircuitPersistenceManagerSerializerContext.Default.DictionaryInt32ComponentMarker) + }; + + return Task.CompletedTask; + } + + // This store only support serializing the state + Task> IPersistentComponentStateStore.GetPersistedStateAsync() => throw new NotImplementedException(); + + // During the persisting phase the state is captured into a Dictionary, our implementation registers + // a callback so that it can run at the same time as the other components' state is persisted. + // We then are called to save the persisted state, at which point, we extract the component records + // and store them separately from the other state. + Task IPersistentComponentStateStore.PersistStateAsync(IReadOnlyDictionary state) + { + PersistedCircuitState.ApplicationState = state; + return Task.CompletedTask; + } + } + + [JsonSerializable(typeof(Dictionary))] + internal partial class CircuitPersistenceManagerSerializerContext : JsonSerializerContext + { + } +} diff --git a/src/Components/Server/src/Circuits/CircuitRegistry.cs b/src/Components/Server/src/Circuits/CircuitRegistry.cs index f686011da2a9..dcd5d8a3cccd 100644 --- a/src/Components/Server/src/Circuits/CircuitRegistry.cs +++ b/src/Components/Server/src/Circuits/CircuitRegistry.cs @@ -41,16 +41,19 @@ internal partial class CircuitRegistry private readonly CircuitOptions _options; private readonly ILogger _logger; private readonly CircuitIdFactory _circuitIdFactory; + private readonly CircuitPersistenceManager _circuitPersistenceManager; private readonly PostEvictionCallbackRegistration _postEvictionCallback; public CircuitRegistry( IOptions options, ILogger logger, - CircuitIdFactory CircuitHostFactory) + CircuitIdFactory CircuitHostFactory, + CircuitPersistenceManager circuitPersistenceManager) { _options = options.Value; _logger = logger; _circuitIdFactory = CircuitHostFactory; + _circuitPersistenceManager = circuitPersistenceManager; ConnectedCircuits = new ConcurrentDictionary(); DisconnectedCircuits = new MemoryCache(new MemoryCacheOptions @@ -265,7 +268,7 @@ protected virtual void OnEntryEvicted(object key, object value, EvictionReason r // Kick off the dispose in the background. var disconnectedEntry = (DisconnectedCircuitEntry)value; Log.CircuitEvicted(_logger, disconnectedEntry.CircuitHost.CircuitId, reason); - _ = DisposeCircuitEntry(disconnectedEntry); + _ = PauseAndDisposeCircuitEntry(disconnectedEntry); break; case EvictionReason.Removed: @@ -278,12 +281,23 @@ protected virtual void OnEntryEvicted(object key, object value, EvictionReason r } } - private async Task DisposeCircuitEntry(DisconnectedCircuitEntry entry) + private async Task PauseAndDisposeCircuitEntry(DisconnectedCircuitEntry entry) { DisposeTokenSource(entry); try { + if (!entry.CircuitHost.HasPendingPersistedCircuitState) + { + // Only pause and persist the circuit state if it has been active at some point, + // meaning that the client called UpdateRootComponents on it. + await _circuitPersistenceManager.PauseCircuitAsync(entry.CircuitHost); + } + else + { + Log.PersistedCircuitStateDiscarded(_logger, entry.CircuitHost.CircuitId); + } + entry.CircuitHost.UnhandledException -= CircuitHost_UnhandledException; await entry.CircuitHost.DisposeAsync(); } @@ -413,5 +427,8 @@ public static void ExceptionDisposingTokenSource(ILogger logger, Exception excep [LoggerMessage(115, LogLevel.Debug, "Reconnect to circuit with id {CircuitId} succeeded.", EventName = "ReconnectionSucceeded")] public static partial void ReconnectionSucceeded(ILogger logger, CircuitId circuitId); + + [LoggerMessage(116, LogLevel.Debug, "Circuit {CircuitId} was not resumed. Persisted circuit state for {CircuitId} discarded.", EventName = "PersistedCircuitStateDiscarded")] + public static partial void PersistedCircuitStateDiscarded(ILogger logger, CircuitId circuitId); } } diff --git a/src/Components/Server/src/Circuits/DefaultInMemoryCircuitPersistenceProvider.cs b/src/Components/Server/src/Circuits/DefaultInMemoryCircuitPersistenceProvider.cs new file mode 100644 index 000000000000..80679549690b --- /dev/null +++ b/src/Components/Server/src/Circuits/DefaultInMemoryCircuitPersistenceProvider.cs @@ -0,0 +1,170 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Internal; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Microsoft.Extensions.Primitives; + +namespace Microsoft.AspNetCore.Components.Server.Circuits; + +// Default implmentation of ICircuitPersistenceProvider that uses an in-memory cache +internal sealed partial class DefaultInMemoryCircuitPersistenceProvider : ICircuitPersistenceProvider +{ + private readonly Lock _lock = new(); + private readonly CircuitOptions _options; + private readonly MemoryCache _persistedCircuits; + private static readonly Task _noMatch = Task.FromResult(null); + private readonly ILogger _logger; + + public PostEvictionCallbackRegistration PostEvictionCallback { get; internal set; } + + public DefaultInMemoryCircuitPersistenceProvider( + ISystemClock clock, + ILogger logger, + IOptions options) + { + _options = options.Value; + _persistedCircuits = new MemoryCache(new MemoryCacheOptions + { + SizeLimit = _options.PersistedCircuitInMemoryMaxRetained, + Clock = clock + }); + + PostEvictionCallback = new PostEvictionCallbackRegistration + { + EvictionCallback = OnEntryEvicted + }; + + _logger = logger; + } + + public Task PersistCircuitAsync(CircuitId circuitId, PersistedCircuitState persistedCircuitState, CancellationToken cancellation = default) + { + Log.CircuitPauseStarted(_logger, circuitId); + + lock (_lock) + { + PersistCore(circuitId, persistedCircuitState); + } + + return Task.CompletedTask; + } + + private void PersistCore(CircuitId circuitId, PersistedCircuitState persistedCircuitState) + { + var cancellationTokenSource = new CancellationTokenSource(_options.PersistedCircuitInMemoryRetentionPeriod); + var options = new MemoryCacheEntryOptions + { + Size = 1, + PostEvictionCallbacks = { PostEvictionCallback }, + ExpirationTokens = { new CancellationChangeToken(cancellationTokenSource.Token) }, + }; + + var persistedCircuitEntry = new PersistedCircuitEntry + { + State = persistedCircuitState, + TokenSource = cancellationTokenSource, + CircuitId = circuitId + }; + + _persistedCircuits.Set(circuitId.Secret, persistedCircuitEntry, options); + } + + private void OnEntryEvicted(object key, object value, EvictionReason reason, object state) + { + switch (reason) + { + case EvictionReason.Expired: + case EvictionReason.TokenExpired: + // Happens after the circuit state times out, this is triggered by the CancellationTokenSource we register + // with the entry, which is what controls the expiration + case EvictionReason.Capacity: + // Happens when the cache is full + var persistedCircuitEntry = (PersistedCircuitEntry)value; + Log.CircuitStateEvicted(_logger, persistedCircuitEntry.CircuitId, reason); + break; + + case EvictionReason.Removed: + // Happens when the entry is explicitly removed as part of resuming a circuit. + return; + default: + Debug.Fail($"Unexpected {nameof(EvictionReason)} {reason}"); + break; + } + } + + public Task RestoreCircuitAsync(CircuitId circuitId, CancellationToken cancellation = default) + { + Log.CircuitResumeStarted(_logger, circuitId); + + lock (_lock) + { + var state = RestoreCore(circuitId); + if (state == null) + { + Log.FailedToFindCircuitState(_logger, circuitId); + return _noMatch; + } + + return Task.FromResult(state); + } + } + + private PersistedCircuitState RestoreCore(CircuitId circuitId) + { + if (_persistedCircuits.TryGetValue(circuitId.Secret, out var value) && value is PersistedCircuitEntry entry) + { + DisposeTokenSource(entry); + _persistedCircuits.Remove(circuitId.Secret); + Log.CircuitStateFound(_logger, circuitId); + return entry.State; + } + + return null; + } + + private void DisposeTokenSource(PersistedCircuitEntry entry) + { + try + { + entry.TokenSource.Dispose(); + } + catch (Exception ex) + { + Log.ExceptionDisposingTokenSource(_logger, ex); + } + } + + private class PersistedCircuitEntry + { + public PersistedCircuitState State { get; set; } + + public CancellationTokenSource TokenSource { get; set; } + + public CircuitId CircuitId { get; set; } + } + + private static partial class Log + { + [LoggerMessage(101, LogLevel.Debug, "Circuit state evicted for circuit {CircuitId} due to {Reason}", EventName = "CircuitStateEvicted")] + public static partial void CircuitStateEvicted(ILogger logger, CircuitId circuitId, EvictionReason reason); + + [LoggerMessage(102, LogLevel.Debug, "Resuming circuit with ID {CircuitId}", EventName = "CircuitResumeStarted")] + public static partial void CircuitResumeStarted(ILogger logger, CircuitId circuitId); + + [LoggerMessage(103, LogLevel.Debug, "Failed to find persisted circuit with ID {CircuitId}", EventName = "FailedToFindCircuitState")] + public static partial void FailedToFindCircuitState(ILogger logger, CircuitId circuitId); + + [LoggerMessage(104, LogLevel.Debug, "Circuit state found for circuit {CircuitId}", EventName = "CircuitStateFound")] + public static partial void CircuitStateFound(ILogger logger, CircuitId circuitId); + + [LoggerMessage(105, LogLevel.Error, "An exception occurred while disposing the token source.", EventName = "ExceptionDisposingTokenSource")] + public static partial void ExceptionDisposingTokenSource(ILogger logger, Exception exception); + + [LoggerMessage(106, LogLevel.Debug, "Pausing circuit with ID {CircuitId}", EventName = "CircuitPauseStarted")] + public static partial void CircuitPauseStarted(ILogger logger, CircuitId circuitId); + } +} diff --git a/src/Components/Server/src/Circuits/ICircuitPersistenceProvider.cs b/src/Components/Server/src/Circuits/ICircuitPersistenceProvider.cs new file mode 100644 index 000000000000..3149886347da --- /dev/null +++ b/src/Components/Server/src/Circuits/ICircuitPersistenceProvider.cs @@ -0,0 +1,12 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.Components.Server.Circuits; + +// Abstraction to support persisting and restoring circuit state +internal interface ICircuitPersistenceProvider +{ + Task PersistCircuitAsync(CircuitId circuitId, PersistedCircuitState persistedCircuitState, CancellationToken cancellation = default); + + Task RestoreCircuitAsync(CircuitId circuitId, CancellationToken cancellation = default); +} diff --git a/src/Components/Server/src/Circuits/IServerComponentDeserializer.cs b/src/Components/Server/src/Circuits/IServerComponentDeserializer.cs index 524028a3f98a..b118a0cd6034 100644 --- a/src/Components/Server/src/Circuits/IServerComponentDeserializer.cs +++ b/src/Components/Server/src/Circuits/IServerComponentDeserializer.cs @@ -10,5 +10,7 @@ internal interface IServerComponentDeserializer bool TryDeserializeComponentDescriptorCollection( string serializedComponentRecords, out List descriptors); - bool TryDeserializeRootComponentOperations(string serializedComponentOperations, [NotNullWhen(true)] out RootComponentOperationBatch? operationBatch); + bool TryDeserializeRootComponentOperations(string serializedComponentOperations, [NotNullWhen(true)] out RootComponentOperationBatch? operationBatch, bool deserializeDescriptors = true); + + bool TryDeserializeWebRootComponentDescriptor(ComponentMarker record, [NotNullWhen(true)] out WebRootComponentDescriptor? result); } diff --git a/src/Components/Server/src/Circuits/PersistedCircuitState.cs b/src/Components/Server/src/Circuits/PersistedCircuitState.cs new file mode 100644 index 000000000000..81f1277d66ad --- /dev/null +++ b/src/Components/Server/src/Circuits/PersistedCircuitState.cs @@ -0,0 +1,19 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; + +namespace Microsoft.AspNetCore.Components.Server.Circuits; + +[DebuggerDisplay($"{{{nameof(GetDebuggerDisplay)}(),nq}}")] +internal class PersistedCircuitState +{ + public IReadOnlyDictionary ApplicationState { get; internal set; } + + public byte[] RootComponents { get; internal set; } + + private string GetDebuggerDisplay() + { + return $"ApplicationStateCount={ApplicationState?.Count ?? 0}, RootComponentsLength={RootComponents?.Length ?? 0} bytes"; + } +} diff --git a/src/Components/Server/src/Circuits/ServerComponentDeserializer.cs b/src/Components/Server/src/Circuits/ServerComponentDeserializer.cs index ca292c529da8..eaebd8856968 100644 --- a/src/Components/Server/src/Circuits/ServerComponentDeserializer.cs +++ b/src/Components/Server/src/Circuits/ServerComponentDeserializer.cs @@ -291,7 +291,10 @@ private bool TryDeserializeServerComponent(ComponentMarker record, out ServerCom return (componentDescriptor, serverComponent); } - public bool TryDeserializeRootComponentOperations(string serializedComponentOperations, [NotNullWhen(true)] out RootComponentOperationBatch? result) + public bool TryDeserializeRootComponentOperations( + string serializedComponentOperations, + [NotNullWhen(true)] out RootComponentOperationBatch? result, + bool deserializeMarkers = true) { int[]? seenComponentIdsStorage = null; try @@ -329,6 +332,13 @@ public bool TryDeserializeRootComponentOperations(string serializedComponentOper return false; } + if (!deserializeMarkers) + { + // If we are not deserializing markers, we can skip the rest of the processing. + operation.Descriptor = null; + continue; + } + if (!TryDeserializeWebRootComponentDescriptor(operation.Marker.Value, out var descriptor)) { result = null; diff --git a/src/Components/Server/src/ComponentHub.cs b/src/Components/Server/src/ComponentHub.cs index 84561349ee48..d91942c7711a 100644 --- a/src/Components/Server/src/ComponentHub.cs +++ b/src/Components/Server/src/ComponentHub.cs @@ -42,6 +42,7 @@ internal sealed partial class ComponentHub : Hub private readonly ICircuitFactory _circuitFactory; private readonly CircuitIdFactory _circuitIdFactory; private readonly CircuitRegistry _circuitRegistry; + private readonly CircuitPersistenceManager _circuitPersistenceManager; private readonly ICircuitHandleRegistry _circuitHandleRegistry; private readonly ILogger _logger; private readonly ActivityContext _httpContext; @@ -52,6 +53,7 @@ public ComponentHub( ICircuitFactory circuitFactory, CircuitIdFactory circuitIdFactory, CircuitRegistry circuitRegistry, + CircuitPersistenceManager circuitPersistenceProvider, ICircuitHandleRegistry circuitHandleRegistry, ILogger logger) { @@ -60,6 +62,7 @@ public ComponentHub( _circuitFactory = circuitFactory; _circuitIdFactory = circuitIdFactory; _circuitRegistry = circuitRegistry; + _circuitPersistenceManager = circuitPersistenceProvider; _circuitHandleRegistry = circuitHandleRegistry; _logger = logger; _httpContext = ComponentsActivitySource.CaptureHttpContext(); @@ -172,21 +175,36 @@ public async Task UpdateRootComponents(string serializedComponentOperations, str return; } - if (!_serverComponentSerializer.TryDeserializeRootComponentOperations( - serializedComponentOperations, - out var operations)) + RootComponentOperationBatch operations; + ProtectedPrerenderComponentApplicationStore store; + var persistedState = circuitHost.TakePersistedCircuitState(); + if (persistedState != null) { - // There was an error, so kill the circuit. - await _circuitRegistry.TerminateAsync(circuitHost.CircuitId); - await NotifyClientError(Clients.Caller, "The list of component operations is not valid."); - Context.Abort(); + operations = CircuitPersistenceManager.ToRootComponentOperationBatch( + _serverComponentSerializer, + persistedState.RootComponents, + serializedComponentOperations); - return; + store = new ProtectedPrerenderComponentApplicationStore(persistedState.ApplicationState, _dataProtectionProvider); } + else + { + if (!_serverComponentSerializer.TryDeserializeRootComponentOperations( + serializedComponentOperations, + out operations)) + { + // There was an error, so kill the circuit. + await _circuitRegistry.TerminateAsync(circuitHost.CircuitId); + await NotifyClientError(Clients.Caller, "The list of component operations is not valid."); + Context.Abort(); + + return; + } - var store = !string.IsNullOrEmpty(applicationState) ? - new ProtectedPrerenderComponentApplicationStore(applicationState, _dataProtectionProvider) : - new ProtectedPrerenderComponentApplicationStore(_dataProtectionProvider); + store = !string.IsNullOrEmpty(applicationState) ? + new ProtectedPrerenderComponentApplicationStore(applicationState, _dataProtectionProvider) : + new ProtectedPrerenderComponentApplicationStore(_dataProtectionProvider); + } _ = circuitHost.UpdateRootComponents(operations, store, Context.ConnectionAborted); } @@ -220,6 +238,157 @@ public async ValueTask ConnectCircuit(string circuitIdSecret) return false; } + // This method drives the resumption of a circuit that has been previously paused and ejected out of memory. + // Resuming a circuit is very similar to starting a new circuit. + // We receive an existing circuit ID to look up the existing circuit state. + // We receive the base URI and the URI to perform the same checks that we do during start circuit. + // Upon resuming a circuit ID, its ID changes. This has some ramifications: + // * When a circuit is paused, the old circuit is gone. There's no way to bring it back. + // * Resuming a circuit means to essentially create a new circuit. One that "starts" from where the previous one "paused". + // * When a circuit is "paused" it might be stored either in the browser (the client holds all state) during "graceful pauses" or + // it can be stored in cache storage during "ungraceful pauses". + // * For the circuit to successfully resume, this call needs to succeed (returning a new circuit ID). + // * Retrieving and deleting the state for the old circuit is part of this process + // * Once we retrieve the state, we delete it, and we check that it's no longer there before we try to resume + // the new circuit + // * No other connection can get here while we are inside ResumeCircuit (SignalR only processes one message at a time, and we don't work if you change this setting). + // * In the unlikely event that the connection breaks, there are two things that could happen: + // * If the client was the one providing the circuit state, it could potentially resume elsewhere (for example another server). + // * In that case this circuit won't do anything. We don't consider the circuit fully resumed until we have attached and triggered a render + // into the DOM. If a failure happens before that, we directly discard the new circuit and its state. + // * If the state was stored on the server, then the state is gone after we retrieve it from the cache. Even if a client were to connect to + // two separate server instances (for example, server A, B, where it starts resuming on A, something fails and tries to start resuming on B) + // the state would either be ignored in one case or lost. + // * Two things can happen: + // * Both A and B are somehow able to read the same state. + // * Even if A gets the state, it doesn't complete the "resume" handshake, so its state gets discarded + // and not saved again. + // * B might complete the handshake and then the circuit will resume on B. + // * A deletes the state before B is able to read it. Then "resumption" fails, as the circuit state is gone. + + // On the server we are going to have a public method on Circuit.cs to trigger pausing a circuit from the server + // that returns the root components and application state as strings data-protected by the data protection provider. + // Those can be then passed to this method for resuming the circuit. + public async ValueTask ResumeCircuit( + string circuitIdSecret, + string baseUri, + string uri, + string rootComponents, + string applicationState) + { + // TryParseCircuitId will not throw. + if (!_circuitIdFactory.TryParseCircuitId(circuitIdSecret, out var circuitId)) + { + // Invalid id. + Log.ResumeInvalidCircuitId(_logger, circuitIdSecret); + return null; + } + + var circuitHost = _circuitHandleRegistry.GetCircuit(Context.Items, CircuitKey); + if (circuitHost != null) + { + // This is an error condition and an attempt to bind multiple circuits to a single connection. + // We can reject this and terminate the connection. + Log.CircuitAlreadyInitialized(_logger, circuitHost.CircuitId); + await NotifyClientError(Clients.Caller, $"The circuit host '{circuitHost.CircuitId}' has already been initialized."); + Context.Abort(); + return null; + } + + if (baseUri == null || + uri == null || + !Uri.TryCreate(baseUri, UriKind.Absolute, out _) || + !Uri.TryCreate(uri, UriKind.Absolute, out _)) + { + // We do some really minimal validation here to prevent obviously wrong data from getting in + // without duplicating too much logic. + // + // This is an error condition attempting to initialize the circuit in a way that would fail. + // We can reject this and terminate the connection. + Log.InvalidInputData(_logger); + await NotifyClientError(Clients.Caller, "The uris provided are invalid."); + Context.Abort(); + return null; + } + + PersistedCircuitState? persistedCircuitState; + if (RootComponentIsEmpty(rootComponents) && string.IsNullOrEmpty(applicationState)) + { + persistedCircuitState = await _circuitPersistenceManager.ResumeCircuitAsync(circuitId, Context.ConnectionAborted); + if (persistedCircuitState == null) + { + Log.InvalidInputData(_logger); + await NotifyClientError(Clients.Caller, "The circuit state could not be retrieved. It may have been deleted or expired."); + Context.Abort(); + return null; + } + } + else if (!RootComponentIsEmpty(rootComponents) || !string.IsNullOrEmpty(applicationState)) + { + Log.InvalidInputData(_logger); + await NotifyClientError( + Clients.Caller, + RootComponentIsEmpty(rootComponents) ? + "The root components provided are invalid." : + "The application state provided is invalid." + ); + Context.Abort(); + return null; + } + else + { + // For now abort, since we currently don't support resuming circuits persisted to the client. + Context.Abort(); + return null; + } + + try + { + var circuitClient = new CircuitClientProxy(Clients.Caller, Context.ConnectionId); + var resourceCollection = Context.GetHttpContext().GetEndpoint()?.Metadata.GetMetadata(); + circuitHost = await _circuitFactory.CreateCircuitHostAsync( + [], + circuitClient, + baseUri, + uri, + Context.User, + store: null, + resourceCollection); + + // Fire-and-forget the initialization process, because we can't block the + // SignalR message loop (we'd get a deadlock if any of the initialization + // logic relied on receiving a subsequent message from SignalR), and it will + // take care of its own errors anyway. + _ = circuitHost.InitializeAsync(store: null, _httpContext, Context.ConnectionAborted); + + circuitHost.AttachPersistedState(persistedCircuitState); + + // It's safe to *publish* the circuit now because nothing will be able + // to run inside it until after InitializeAsync completes. + _circuitRegistry.Register(circuitHost); + _circuitHandleRegistry.SetCircuit(Context.Items, CircuitKey, circuitHost); + + // Returning the secret here so the client can reconnect. + // + // Logging the secret and circuit ID here so we can associate them with just logs (if TRACE level is on). + Log.CreatedCircuit(_logger, circuitHost.CircuitId, circuitHost.CircuitId.Secret, Context.ConnectionId); + + return circuitHost.CircuitId.Secret; + } + catch (Exception ex) + { + // If the circuit fails to initialize synchronously we can notify the client immediately + // and shut down the connection. + Log.CircuitInitializationFailed(_logger, ex); + await NotifyClientError(Clients.Caller, "The circuit failed to initialize."); + Context.Abort(); + return null; + } + + static bool RootComponentIsEmpty(string rootComponents) => + string.IsNullOrEmpty(rootComponents) || rootComponents == "[]"; + } + public async ValueTask BeginInvokeDotNetFromJS(string callId, string assemblyName, string methodIdentifier, long dotNetObjectId, string argsJson) { var circuitHost = await GetActiveCircuitAsync(); @@ -409,6 +578,9 @@ public static void CreatedCircuit(ILogger logger, CircuitId circuitId, string ci [LoggerMessage(8, LogLevel.Debug, "ConnectAsync received an invalid circuit id '{CircuitIdSecret}'", EventName = "InvalidCircuitId")] private static partial void InvalidCircuitIdCore(ILogger logger, string circuitIdSecret); + [LoggerMessage(9, LogLevel.Debug, "ResumeCircuit received an invalid circuit id '{CircuitIdSecret}'", EventName = "ResumeInvalidCircuitId")] + private static partial void ResumeInvalidCircuitIdCore(ILogger logger, string circuitIdSecret); + public static void InvalidCircuitId(ILogger logger, string circuitSecret) { // Redact the secret unless tracing is on. @@ -419,5 +591,16 @@ public static void InvalidCircuitId(ILogger logger, string circuitSecret) InvalidCircuitIdCore(logger, circuitSecret); } + + public static void ResumeInvalidCircuitId(ILogger logger, string circuitSecret) + { + // Redact the secret unless tracing is on. + if (!logger.IsEnabled(LogLevel.Trace)) + { + circuitSecret = "(redacted)"; + } + + ResumeInvalidCircuitIdCore(logger, circuitSecret); + } } } diff --git a/src/Components/Server/src/DependencyInjection/ComponentServiceCollectionExtensions.cs b/src/Components/Server/src/DependencyInjection/ComponentServiceCollectionExtensions.cs index 4c6eb34d27f6..a93fc5a9da95 100644 --- a/src/Components/Server/src/DependencyInjection/ComponentServiceCollectionExtensions.cs +++ b/src/Components/Server/src/DependencyInjection/ComponentServiceCollectionExtensions.cs @@ -4,6 +4,7 @@ using System.Diagnostics.CodeAnalysis; using Microsoft.AspNetCore.Components; using Microsoft.AspNetCore.Components.Authorization; +using Microsoft.AspNetCore.Components.Endpoints; using Microsoft.AspNetCore.Components.Forms; using Microsoft.AspNetCore.Components.Routing; using Microsoft.AspNetCore.Components.Server; @@ -13,6 +14,7 @@ using Microsoft.AspNetCore.Components.Web; using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Extensions.Internal; using Microsoft.Extensions.Options; using Microsoft.JSInterop; @@ -68,6 +70,7 @@ public static IServerSideBlazorBuilder AddServerSideBlazor(this IServiceCollecti services.TryAddSingleton(); services.TryAddSingleton(); services.TryAddSingleton(); + services.TryAddSingleton(); services.TryAddScoped(); services.TryAddScoped(); services.TryAddScoped(); @@ -75,7 +78,10 @@ public static IServerSideBlazorBuilder AddServerSideBlazor(this IServiceCollecti services.TryAddScoped(s => s.GetRequiredService().Circuit); services.TryAddScoped(); + services.TryAddSingleton(); services.TryAddSingleton(); + services.TryAddSingleton(); + services.TryAddSingleton(); // Standard blazor hosting services implementations // diff --git a/src/Components/Server/src/Microsoft.AspNetCore.Components.Server.csproj b/src/Components/Server/src/Microsoft.AspNetCore.Components.Server.csproj index ed27d422aca6..f66329bd651a 100644 --- a/src/Components/Server/src/Microsoft.AspNetCore.Components.Server.csproj +++ b/src/Components/Server/src/Microsoft.AspNetCore.Components.Server.csproj @@ -52,6 +52,8 @@ + + diff --git a/src/Components/Server/src/PublicAPI.Unshipped.txt b/src/Components/Server/src/PublicAPI.Unshipped.txt index 7dc5c58110bf..a59d7ec1457c 100644 --- a/src/Components/Server/src/PublicAPI.Unshipped.txt +++ b/src/Components/Server/src/PublicAPI.Unshipped.txt @@ -1 +1,5 @@ #nullable enable +Microsoft.AspNetCore.Components.Server.CircuitOptions.PersistedCircuitInMemoryMaxRetained.get -> int +Microsoft.AspNetCore.Components.Server.CircuitOptions.PersistedCircuitInMemoryMaxRetained.set -> void +Microsoft.AspNetCore.Components.Server.CircuitOptions.PersistedCircuitInMemoryRetentionPeriod.get -> System.TimeSpan +Microsoft.AspNetCore.Components.Server.CircuitOptions.PersistedCircuitInMemoryRetentionPeriod.set -> void diff --git a/src/Components/Server/test/CircuitDisconnectMiddlewareTest.cs b/src/Components/Server/test/CircuitDisconnectMiddlewareTest.cs index ecf4c858745b..dddd2d7c2ceb 100644 --- a/src/Components/Server/test/CircuitDisconnectMiddlewareTest.cs +++ b/src/Components/Server/test/CircuitDisconnectMiddlewareTest.cs @@ -3,9 +3,11 @@ using System.Net.Http; using Microsoft.AspNetCore.Components.Server.Circuits; +using Microsoft.AspNetCore.DataProtection; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.Options; +using Moq; namespace Microsoft.AspNetCore.Components.Server; @@ -23,7 +25,8 @@ public async Task DisconnectMiddleware_OnlyAccepts_PostRequests(string httpMetho var registry = new CircuitRegistry( Options.Create(new CircuitOptions()), NullLogger.Instance, - circuitIdFactory); + circuitIdFactory, + CreatePersistenceManager()); var middleware = new CircuitDisconnectMiddleware( NullLogger.Instance, @@ -51,7 +54,8 @@ public async Task Returns400BadRequest_ForInvalidContentTypes(string contentType var registry = new CircuitRegistry( Options.Create(new CircuitOptions()), NullLogger.Instance, - circuitIdFactory); + circuitIdFactory, + CreatePersistenceManager()); var middleware = new CircuitDisconnectMiddleware( NullLogger.Instance, @@ -78,7 +82,8 @@ public async Task Returns400BadRequest_IfNoCircuitIdOnForm() var registry = new CircuitRegistry( Options.Create(new CircuitOptions()), NullLogger.Instance, - circuitIdFactory); + circuitIdFactory, + CreatePersistenceManager()); var middleware = new CircuitDisconnectMiddleware( NullLogger.Instance, @@ -105,7 +110,8 @@ public async Task Returns400BadRequest_InvalidCircuitId() var registry = new CircuitRegistry( Options.Create(new CircuitOptions()), NullLogger.Instance, - circuitIdFactory); + circuitIdFactory, + CreatePersistenceManager()); var middleware = new CircuitDisconnectMiddleware( NullLogger.Instance, @@ -138,7 +144,8 @@ public async Task Returns200OK_NonExistingCircuit() var registry = new CircuitRegistry( Options.Create(new CircuitOptions()), NullLogger.Instance, - circuitIdFactory); + circuitIdFactory, + CreatePersistenceManager()); var middleware = new CircuitDisconnectMiddleware( NullLogger.Instance, @@ -173,7 +180,8 @@ public async Task GracefullyTerminates_ConnectedCircuit() var registry = new CircuitRegistry( Options.Create(new CircuitOptions()), NullLogger.Instance, - circuitIdFactory); + circuitIdFactory, + CreatePersistenceManager()); registry.Register(testCircuitHost); @@ -210,7 +218,8 @@ public async Task GracefullyTerminates_DisconnectedCircuit() var registry = new CircuitRegistry( Options.Create(new CircuitOptions()), NullLogger.Instance, - circuitIdFactory); + circuitIdFactory, + CreatePersistenceManager()); registry.Register(circuitHost); await registry.DisconnectAsync(circuitHost, "1234"); @@ -236,4 +245,13 @@ public async Task GracefullyTerminates_DisconnectedCircuit() // Assert Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); } + + private static CircuitPersistenceManager CreatePersistenceManager() + { + var circuitPersistenceManager = new CircuitPersistenceManager( + Options.Create(new CircuitOptions()), + new Endpoints.ServerComponentSerializer(new EphemeralDataProtectionProvider()), + Mock.Of()); + return circuitPersistenceManager; + } } diff --git a/src/Components/Server/test/Circuits/CircuitHostTest.cs b/src/Components/Server/test/Circuits/CircuitHostTest.cs index b68bdf8286fa..135e4e4347a8 100644 --- a/src/Components/Server/test/Circuits/CircuitHostTest.cs +++ b/src/Components/Server/test/Circuits/CircuitHostTest.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.Reflection; using System.Text.Json; @@ -851,11 +852,17 @@ public bool TryDeserializeComponentDescriptorCollection(string serializedCompone return true; } - public bool TryDeserializeRootComponentOperations(string serializedComponentOperations, out RootComponentOperationBatch operationBatch) + public bool TryDeserializeRootComponentOperations(string serializedComponentOperations, out RootComponentOperationBatch operationBatch, bool deserializeDescriptors = true) { operationBatch = default; return true; } + + public bool TryDeserializeWebRootComponentDescriptor(ComponentMarker record, [NotNullWhen(true)] out WebRootComponentDescriptor result) + { + result = default; + return true; + } } private class DynamicallyAddedComponent : IComponent, IDisposable diff --git a/src/Components/Server/test/Circuits/CircuitPersistenceManagerTest.cs b/src/Components/Server/test/Circuits/CircuitPersistenceManagerTest.cs new file mode 100644 index 000000000000..51decd3cfb10 --- /dev/null +++ b/src/Components/Server/test/Circuits/CircuitPersistenceManagerTest.cs @@ -0,0 +1,509 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.Metrics; +using System.Globalization; +using System.Text; +using System.Text.Json; +using Microsoft.AspNetCore.Components.Endpoints; +using Microsoft.AspNetCore.Components.Infrastructure; +using Microsoft.AspNetCore.Components.Server.Circuits; +using Microsoft.AspNetCore.DataProtection; +using Microsoft.AspNetCore.InternalTesting; +using Microsoft.AspNetCore.SignalR; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Options; +using Microsoft.JSInterop; +using Moq; + +namespace Microsoft.AspNetCore.Components.Server.Tests.Circuits; +public class CircuitPersistenceManagerTest +{ + // Pause circuit registers with PersistentComponentStatemanager to persist root components. + // Do not try to generate code after this line. + + [Fact] + public async Task PauseCircuitAsync_PersistsRootComponents_WithTheirParameters() + { + // Arrange + var dataProtectionProvider = new EphemeralDataProtectionProvider(); + var deserializer = CreateDeserializer(dataProtectionProvider); + + var options = Options.Create(new CircuitOptions()); + var components = new[] + { + (RootComponentType: typeof(RootComponent), Parameters: new Dictionary + { + ["Count"] = 42 + }) + }; + + var circuitHost = await CreateCircuitHostAsync( + options, + dataProtectionProvider, + deserializer, + components); + + var circuitPersistenceProvider = new TestCircuitPersistenceProvider(); + var circuitPersistenceManager = new CircuitPersistenceManager( + options, + new ServerComponentSerializer(dataProtectionProvider), + circuitPersistenceProvider); + + // Act + await circuitPersistenceManager.PauseCircuitAsync(circuitHost); + + // Assert + Assert.NotNull(circuitPersistenceProvider.State); + var state = circuitPersistenceProvider.State; + Assert.Equal(2, state.ApplicationState.Count); + + AssertRootComponents( + deserializer, + [ + ( + Id: 1, + Key: new ComponentMarkerKey("1", typeof(RootComponent).FullName!), + ( + typeof(RootComponent), + new Dictionary + { + ["Count"] = 42 + } + ) + ) + ], + state.RootComponents); + } + + [Fact] + public async Task PauseCircuitAsync_CanPersistMultipleComponents_WithTheirParameters() + { + // Arrange + var dataProtectionProvider = new EphemeralDataProtectionProvider(); + var deserializer = CreateDeserializer(dataProtectionProvider); + var options = Options.Create(new CircuitOptions()); + var components = new[] + { + (RootComponentType: typeof(RootComponent), Parameters: new Dictionary + { + ["Count"] = 42 + }), + (RootComponentType: typeof(SecondRootComponent), Parameters: new Dictionary + { + ["Count"] = 100 + }) + }; + var circuitHost = await CreateCircuitHostAsync( + options, + dataProtectionProvider, + deserializer, + components); + var circuitPersistenceProvider = new TestCircuitPersistenceProvider(); + var circuitPersistenceManager = new CircuitPersistenceManager( + options, + new ServerComponentSerializer(dataProtectionProvider), + circuitPersistenceProvider); + // Act + await circuitPersistenceManager.PauseCircuitAsync(circuitHost); + // Assert + Assert.NotNull(circuitPersistenceProvider.State); + var state = circuitPersistenceProvider.State; + Assert.Equal(3, state.ApplicationState.Count); + AssertRootComponents( + deserializer, + [ + ( + Id: 1, + Key: new ComponentMarkerKey("1", typeof(RootComponent).FullName!), + ( + typeof(RootComponent), + new Dictionary + { + ["Count"] = 42 + } + ) + ), + ( + Id: 2, + Key: new ComponentMarkerKey("2", typeof(SecondRootComponent).FullName!), + ( + typeof(SecondRootComponent), + new Dictionary + { + ["Count"] = 100 + } + ) + ) + ], + state.RootComponents); + } + + [Fact] + public void ToRootComponentOperationBatch_WorksFor_EmptyBatch() + { + var deserializer = SetupMockDeserializer(); + + var result = CircuitPersistenceManager.ToRootComponentOperationBatch(deserializer.Object, [.. "{}"u8], "ops"); + Assert.NotNull(result); + } + + [Fact] + public void ToRootComponentOperationBatch_Fails_IfDeserializingClientOperations_Fails() + { + var deserializer = SetupMockDeserializer(fail: true); + deserializer + .Setup(d => + d.TryDeserializeRootComponentOperations( + It.IsAny(), + out It.Ref.IsAny, + false)) + .Returns(false); + + var result = CircuitPersistenceManager.ToRootComponentOperationBatch(deserializer.Object, [.. "{}"u8], "ops"); + Assert.Null(result); + } + + [Fact] + public void ToRootComponentOperationBatch_Fails_IfDeserializingPersistedRootComponents_Fails() + { + var deserializer = SetupMockDeserializer(); + + var result = CircuitPersistenceManager.ToRootComponentOperationBatch(deserializer.Object, [.. "invalid-json"u8], "ops"); + Assert.Null(result); + } + + [Fact] + public void Fails_IfDifferentNumberOfRootComponentsAndOperations() + { + var deserializer = SetupMockDeserializer( + new RootComponentOperationBatch + { + BatchId = 1, + Operations = [new RootComponentOperation { Type = RootComponentOperationType.Add, SsrComponentId = 1 }] + }); + var result = CircuitPersistenceManager.ToRootComponentOperationBatch(deserializer.Object, [.. "{}"u8], "ops"); + Assert.Null(result); + } + + [Fact] + public void Fails_IfMarkerForOperationNotFound() + { + var deserializer = SetupMockDeserializer( + new RootComponentOperationBatch + { + BatchId = 1, + Operations = [new RootComponentOperation { Type = RootComponentOperationType.Add, SsrComponentId = 1 }] + }); + var result = CircuitPersistenceManager.ToRootComponentOperationBatch(deserializer.Object, [.. """{ "2": {} }"""u8], "ops"); + Assert.Null(result); + } + + [Fact] + public void Fails_IfUnableToDeserialize_PersistedComponentStateMarker() + { + var deserializer = SetupMockDeserializer( + new RootComponentOperationBatch + { + BatchId = 1, + Operations = [new RootComponentOperation { Type = RootComponentOperationType.Add, SsrComponentId = 1 }] + }, fail: false, deserializeMarker: false); + var result = CircuitPersistenceManager.ToRootComponentOperationBatch(deserializer.Object, [.. """{ "1": {} }"""u8], "ops"); + Assert.Null(result); + } + + [Fact] + public void Fails_WorksWhen_RootComponentsAndOperations_MatchAndCanBeDeserialized() + { + var deserializer = SetupMockDeserializer( + new RootComponentOperationBatch + { + BatchId = 1, + Operations = [new RootComponentOperation { Type = RootComponentOperationType.Add, SsrComponentId = 1 }] + }, fail: false, deserializeMarker: true); + var result = CircuitPersistenceManager.ToRootComponentOperationBatch(deserializer.Object, [.. """{ "1": {} }"""u8], "ops"); + Assert.NotNull(result); + } + + private void AssertRootComponents( + ServerComponentDeserializer deserializer, + (int Id, ComponentMarkerKey Key, (Type ComponentType, Dictionary Parameters))[] expected, byte[] rootComponents) + { + var actual = JsonSerializer.Deserialize>(rootComponents, SerializerOptions); + Assert.NotNull(actual); + Assert.Equal(expected.Length, actual.Count); + foreach (var (id, key, (componentType, parameters)) in expected) + { + Assert.True(actual.TryGetValue(id, out var marker), $"Expected marker with ID {id} not found."); + Assert.Equal(key.LocationHash, marker.Key.Value.LocationHash); + Assert.Equal(key.FormattedComponentKey, marker.Key.Value.FormattedComponentKey); + Assert.True(deserializer.TryDeserializeWebRootComponentDescriptor(marker, out var descriptor), $"Failed to deserialize marker with ID {id}."); + Assert.NotNull(descriptor); + Assert.Equal(componentType, descriptor.ComponentType); + var actualParameters = descriptor.Parameters.Parameters.ToDictionary(); + Assert.NotNull(actualParameters); + Assert.Equal(parameters.Count, actualParameters.Count); + foreach (var (paramKey, paramValue) in parameters) + { + Assert.True(actualParameters.TryGetValue(paramKey, out var actualValue), $"Expected parameter '{paramKey}' not found."); + Assert.Equal(paramValue, actualValue); + } + } + } + + private async Task CreateCircuitHostAsync( + IOptions options, + EphemeralDataProtectionProvider dataProtectionProvider, + ServerComponentDeserializer deserializer, + (Type RootComponentType, Dictionary Parameters)[] components = null) + { + components ??= []; + var circuitId = new CircuitIdFactory(dataProtectionProvider).CreateCircuitId(); + + var jsRuntime = new RemoteJSRuntime( + options, + Options.Create(new HubOptions()), + NullLoggerFactory.Instance.CreateLogger()); + + var serviceProvider = new ServiceCollection() + .AddSingleton(dataProtectionProvider) + .AddSingleton() + .AddSupplyValueFromPersistentComponentStateProvider() + .AddSingleton( + sp => new ComponentStatePersistenceManager( + NullLoggerFactory.Instance.CreateLogger(), + sp)) + .AddSingleton(sp => sp.GetRequiredService().State) + .AddSingleton(jsRuntime) + .BuildServiceProvider(); + + var scope = serviceProvider.CreateAsyncScope(); + + var client = new CircuitClientProxy(Mock.Of(), Guid.NewGuid().ToString()); + + var renderer = new RemoteRenderer( + scope.ServiceProvider, + NullLoggerFactory.Instance, + options.Value, + client, + deserializer, + NullLoggerFactory.Instance.CreateLogger(), + jsRuntime, + new CircuitJSComponentInterop(options.Value)); + + var navigationManager = new RemoteNavigationManager( + NullLoggerFactory.Instance.CreateLogger()); + var circuitHandlers = Array.Empty(); + var circuitMetrics = new CircuitMetrics(new TestMeterFactory()); + var componentsActivitySource = new ComponentsActivitySource(); + var logger = NullLoggerFactory.Instance.CreateLogger(); + + var circuitHost = new CircuitHost( + circuitId, + scope, + options.Value, + client, + renderer, + [], + jsRuntime, + navigationManager, + circuitHandlers, + circuitMetrics, + componentsActivitySource, + logger); + + await circuitHost.InitializeAsync( + null, + default, + default); + + var store = new ProtectedPrerenderComponentApplicationStore(dataProtectionProvider); + await circuitHost.UpdateRootComponents( + CreateBatch(components, deserializer, dataProtectionProvider), + store, + default); + + return circuitHost; + } + + private static ServerComponentDeserializer CreateDeserializer(EphemeralDataProtectionProvider dataProtectionProvider) => new ServerComponentDeserializer( + dataProtectionProvider, + NullLoggerFactory.Instance.CreateLogger(), + new RootTypeCache(), + new ComponentParameterDeserializer( + NullLoggerFactory.Instance.CreateLogger(), + new ComponentParametersTypeCache())); + + private static Mock SetupMockDeserializer( + RootComponentOperationBatch batchResult = default, + bool fail = false, + bool deserializeMarker = false) + { + var deserializer = new Mock(); + batchResult = fail ? + default : + batchResult == default ? + new RootComponentOperationBatch + { + Operations = [], + BatchId = 1 + } : + batchResult; + + deserializer + .Setup(d => + d.TryDeserializeRootComponentOperations( + It.IsAny(), + out It.Ref.IsAny, + false)) + .Callback((string serializedOps, out RootComponentOperationBatch batch, bool value) => + { + batch = batchResult; + }) + .Returns(!fail); + + if (deserializeMarker) + { + deserializer.Setup(deserializer => + deserializer.TryDeserializeWebRootComponentDescriptor( + It.IsAny(), + out It.Ref.IsAny)) + .Callback((ComponentMarker marker, out WebRootComponentDescriptor descriptor) => + { + descriptor = new WebRootComponentDescriptor(typeof(RootComponent), new WebRootComponentParameters()); + }) + .Returns(true); + } + return deserializer; + } + + private static RootComponentOperationBatch CreateBatch( + (Type RootComponentType, Dictionary Parameters)[] components, + ServerComponentDeserializer deserializer, + EphemeralDataProtectionProvider dataProtectionProvider) + { + var invocation = new ServerComponentInvocationSequence(); + var serializer = new ServerComponentSerializer(dataProtectionProvider); + var markers = new List(); + for (var i = 0; i < components.Length; i++) + { + var (rootComponentType, parameters) = components[i]; + var key = new ComponentMarkerKey((i + 1).ToString(CultureInfo.InvariantCulture), rootComponentType.FullName!); + var marker = ComponentMarker.Create(ComponentMarker.ServerMarkerType, false, key); + serializer.SerializeInvocation( + ref marker, + invocation, + rootComponentType, + ParameterView.FromDictionary(parameters), + TimeSpan.FromDays(365)); + markers.Add(marker); + } + + var batch = new RootComponentOperationBatch + { + BatchId = 1, + Operations = [.. markers.Select((c, i) => + new RootComponentOperation + { + Type = RootComponentOperationType.Add, + SsrComponentId = i + 1, + Marker = c + })] + }; + + var batchJson = JsonSerializer.Serialize(batch, ServerComponentSerializationSettings.JsonSerializationOptions); + + deserializer.TryDeserializeRootComponentOperations(batchJson, out batch); + + return batch; + } + + private class TestCircuitPersistenceProvider : ICircuitPersistenceProvider + { + public PersistedCircuitState State { get; private set; } + + public Task RestoreCircuitAsync(CircuitId circuitId, CancellationToken cancellation = default) + { + return Task.FromResult(new PersistedCircuitState()); + } + + public Task PersistCircuitAsync(CircuitId circuitId, PersistedCircuitState state, CancellationToken cancellation = default) + { + State = state; + return Task.CompletedTask; + } + } + + public class RootComponent : IComponent + { + private RenderHandle _renderHandle; + + public void Attach(RenderHandle renderHandle) + { + _renderHandle = renderHandle; + } + + [SupplyParameterFromPersistentComponentState] + public string Persisted { get; set; } + + [Parameter] + public int Count { get; set; } + + public Task SetParametersAsync(ParameterView parameters) + { + parameters.SetParameterProperties(this); + Persisted ??= Guid.NewGuid().ToString("N", CultureInfo.InvariantCulture); + + _renderHandle.Render(rtb => + { + rtb.OpenElement(0, "div"); + rtb.AddContent(1, $"Persisted: {Persisted}, Count: {Count}"); + rtb.CloseElement(); + }); + + return Task.CompletedTask; + } + } + + public class SecondRootComponent : IComponent + { + private RenderHandle _renderHandle; + + public void Attach(RenderHandle renderHandle) + { + _renderHandle = renderHandle; + } + + [SupplyParameterFromPersistentComponentState] + public string Persisted { get; set; } + + [Parameter] + public int Count { get; set; } + + public Task SetParametersAsync(ParameterView parameters) + { + parameters.SetParameterProperties(this); + Persisted ??= Guid.NewGuid().ToString("N", CultureInfo.InvariantCulture); + + _renderHandle.Render(rtb => + { + rtb.OpenElement(0, "div"); + rtb.AddContent(1, $"Persisted: {Persisted}, Count: {Count}"); + rtb.CloseElement(); + }); + + return Task.CompletedTask; + } + } + + private static readonly JsonSerializerOptions SerializerOptions = new() + { + PropertyNamingPolicy = JsonNamingPolicy.CamelCase, + PropertyNameCaseInsensitive = true, + }; +} diff --git a/src/Components/Server/test/Circuits/CircuitRegistryTest.cs b/src/Components/Server/test/Circuits/CircuitRegistryTest.cs index 9a30f3eca81b..7c08dd069b4a 100644 --- a/src/Components/Server/test/Circuits/CircuitRegistryTest.cs +++ b/src/Components/Server/test/Circuits/CircuitRegistryTest.cs @@ -1,11 +1,16 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using Microsoft.AspNetCore.SignalR; +using Microsoft.AspNetCore.Components.Infrastructure; +using Microsoft.AspNetCore.DataProtection; using Microsoft.AspNetCore.InternalTesting; +using Microsoft.AspNetCore.SignalR; using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.Options; +using Microsoft.JSInterop; using Moq; namespace Microsoft.AspNetCore.Components.Server.Circuits; @@ -263,6 +268,115 @@ public async Task Connect_WhileDisconnectIsInProgress() Assert.False(registry.DisconnectedCircuits.TryGetValue(circuitHost.CircuitId.Secret, out _)); } + [Fact] + public async Task Connect_WhilePersistingEvictedCircuit_IsInProgress() + { + // Arrange + var circuitIdFactory = TestCircuitIdFactory.CreateTestFactory(); + var options = new CircuitOptions + { + DisconnectedCircuitMaxRetained = 0, // This will automatically trigger eviction. + }; + + var persistenceCompletionSource = new TaskCompletionSource(); + var circuitPersistenceProvider = new TestCircuitPersistenceProvider() + { + Persisting = persistenceCompletionSource.Task, + }; + + var registry = new TestCircuitRegistry(circuitIdFactory, options, circuitPersistenceProvider); + registry.BeforeDisconnect = new ManualResetEventSlim(); + + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(sp => new ComponentStatePersistenceManager( + NullLoggerFactory.Instance.CreateLogger(), + sp)); + serviceCollection.AddSingleton(sp => sp.GetRequiredService().State); + var serviceProvider = serviceCollection.BuildServiceProvider(); + + var circuitHost = TestCircuitHost.Create(circuitIdFactory.CreateCircuitId(), serviceProvider.CreateAsyncScope()); + registry.Register(circuitHost); + var client = Mock.Of(); + var newId = "new-connection"; + + // Act + var disconnect = Task.Run(() => + { + var task = registry.DisconnectAsync(circuitHost, circuitHost.Client.ConnectionId); + return task; + }); + + var connect = Task.Run(async () => + { + var connectCore = registry.ConnectAsync(circuitHost.CircuitId, client, newId, default); + await connectCore; + }); + + registry.BeforeDisconnect.Set(); + + await Task.WhenAll(disconnect, connect); + persistenceCompletionSource.SetResult(); + circuitPersistenceProvider.AfterPersist.Wait(TimeSpan.FromSeconds(10)); + // Assert + // We expect the reconnect to fail since the circuit has already been evicted and persisted. + Assert.Empty(registry.ConnectedCircuits.Values); + Assert.True(circuitPersistenceProvider.PersistCalled); + Assert.False(registry.DisconnectedCircuits.TryGetValue(circuitHost.CircuitId.Secret, out _)); + } + + [Fact] + public async Task Disconnect_DoesNotPersistCircuits_WithPendingState() + { + // Arrange + var circuitIdFactory = TestCircuitIdFactory.CreateTestFactory(); + var options = new CircuitOptions + { + DisconnectedCircuitMaxRetained = 0, // This will automatically trigger eviction. + }; + + var circuitPersistenceProvider = new TestCircuitPersistenceProvider(); + + var registry = new TestCircuitRegistry(circuitIdFactory, options, circuitPersistenceProvider); + registry.BeforeDisconnect = new ManualResetEventSlim(); + + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(sp => new ComponentStatePersistenceManager( + NullLoggerFactory.Instance.CreateLogger(), + sp)); + serviceCollection.AddSingleton(sp => sp.GetRequiredService().State); + var serviceProvider = serviceCollection.BuildServiceProvider(); + + var circuitHost = TestCircuitHost.Create(circuitIdFactory.CreateCircuitId(), serviceProvider.CreateAsyncScope()); + registry.Register(circuitHost); + circuitHost.AttachPersistedState(new PersistedCircuitState()); + var client = Mock.Of(); + var newId = "new-connection"; + + // Act + var disconnect = Task.Run(() => + { + var task = registry.DisconnectAsync(circuitHost, circuitHost.Client.ConnectionId); + return task; + }); + + var connect = Task.Run(async () => + { + var connectCore = registry.ConnectAsync(circuitHost.CircuitId, client, newId, default); + await connectCore; + }); + + registry.BeforeDisconnect.Set(); + + await Task.WhenAll(disconnect, connect); + circuitPersistenceProvider.AfterPersist.Wait(TimeSpan.FromSeconds(10)); + + // Assert + // We expect the reconnect to fail since the circuit has already been evicted and persisted. + Assert.Empty(registry.ConnectedCircuits.Values); + Assert.False(circuitPersistenceProvider.PersistCalled); + Assert.False(registry.DisconnectedCircuits.TryGetValue(circuitHost.CircuitId.Secret, out _)); + } + [Fact] public async Task DisconnectWhenAConnectIsInProgress() { @@ -353,8 +467,15 @@ public async Task ReconnectBeforeTimeoutDoesNotGetEntryToBeEvicted() private class TestCircuitRegistry : CircuitRegistry { - public TestCircuitRegistry(CircuitIdFactory factory, CircuitOptions circuitOptions = null) - : base(Options.Create(circuitOptions ?? new CircuitOptions()), NullLogger.Instance, factory) + public TestCircuitRegistry( + CircuitIdFactory factory, + CircuitOptions circuitOptions = null, + TestCircuitPersistenceProvider persistenceProvider = null) + : base( + Options.Create(circuitOptions ?? new CircuitOptions()), + NullLogger.Instance, + factory, + CreatePersistenceManager(circuitOptions ?? new CircuitOptions(), persistenceProvider)) { } @@ -390,11 +511,46 @@ protected override void OnEntryEvicted(object key, object value, EvictionReason } } + private class TestCircuitPersistenceProvider : ICircuitPersistenceProvider + { + public Task Persisting { get; set; } + public ManualResetEventSlim AfterPersist { get; set; } = new ManualResetEventSlim(); + public bool PersistCalled { get; internal set; } + + public async Task PersistCircuitAsync(CircuitId circuitId, PersistedCircuitState persistedCircuitState, CancellationToken cancellation = default) + { + PersistCalled = true; + if (Persisting != null) + { + await Persisting; + } + AfterPersist.Set(); + } + + public Task RestoreCircuitAsync(CircuitId circuitId, CancellationToken cancellation = default) + { + throw new NotImplementedException(); + } + } + + private static CircuitPersistenceManager CreatePersistenceManager( + CircuitOptions circuitOptions, + TestCircuitPersistenceProvider persistenceProvider) + { + var manager = new CircuitPersistenceManager( + Options.Create(circuitOptions), + new Endpoints.ServerComponentSerializer(new EphemeralDataProtectionProvider()), + persistenceProvider ?? new TestCircuitPersistenceProvider()); + + return manager; + } + private static CircuitRegistry CreateRegistry(CircuitIdFactory factory = null) { return new CircuitRegistry( Options.Create(new CircuitOptions()), NullLogger.Instance, - factory ?? TestCircuitIdFactory.CreateTestFactory()); + factory ?? TestCircuitIdFactory.CreateTestFactory(), + CreatePersistenceManager(new CircuitOptions(), new TestCircuitPersistenceProvider())); } } diff --git a/src/Components/Server/test/Circuits/ComponentHubTest.cs b/src/Components/Server/test/Circuits/ComponentHubTest.cs index 9d98dc1bf7c2..35d9af2ce8a4 100644 --- a/src/Components/Server/test/Circuits/ComponentHubTest.cs +++ b/src/Components/Server/test/Circuits/ComponentHubTest.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Collections.ObjectModel; using System.Diagnostics.CodeAnalysis; using System.Security.Claims; using System.Text.RegularExpressions; @@ -100,10 +101,194 @@ public async Task CannotInvokeOnLocationChangingBeforeInitialization() mockClientProxy.Verify(m => m.SendCoreAsync("JS.Error", new[] { errorMessage }, It.IsAny()), Times.Once()); } - private static (Mock, ComponentHub) InitializeComponentHub() + [Fact] + public async Task CannotCallUpdateRootComponentsBeforeInitialization() + { + var (mockClientProxy, hub) = InitializeComponentHub(); + await hub.UpdateRootComponents("""{ batchId: 1, operations: [] }""", ""); + var errorMessage = "Circuit not initialized."; + mockClientProxy.Verify(m => m.SendCoreAsync("JS.Error", new[] { errorMessage }, It.IsAny()), Times.Once()); + } + + [Fact] + public async Task CanCallUpdateRootComponents() + { + var called = false; + var deserializer = new TestServerComponentDeserializer(); + deserializer.OnTryDeserializeTestComponentOperations = + (serializedComponentOperations, out operationsWithDescriptors, deserializeDescriptors) => + { + called = true; + operationsWithDescriptors = new RootComponentOperationBatch + { + BatchId = 1, + Operations = [] + }; + return true; + }; + var (mockClientProxy, hub) = InitializeComponentHub(deserializer); + var circuitSecret = await hub.StartCircuit("https://localhost:5000", "https://localhost:5000/subdir", "[]", null); + Assert.NotNull(circuitSecret); + await hub.UpdateRootComponents("""{ batchId: 1, operations: [] }""", ""); + Assert.True(called); + } + + [Fact] + public async Task CanCallUpdateRootComponentsOnResumedCircuit() + { + var deserializer = new TestServerComponentDeserializer(); + deserializer.OnTryDeserializeTestComponentOperations = + (serializedComponentOperations, out operationsWithDescriptors, deserializeDescriptors) => + { + operationsWithDescriptors = new RootComponentOperationBatch + { + BatchId = 1, + Operations = [] + }; + return true; + }; + + var handleRegistryMock = new Mock(); + CircuitHost lastCircuit = null; + handleRegistryMock.Setup(m => m.SetCircuit(It.IsAny>(), It.IsAny(), It.IsAny())) + .Callback, object, CircuitHost>((circuitHandles, circuitKey, circuitHost) => + { + lastCircuit = circuitHost; + }); + handleRegistryMock.Setup(m => m.GetCircuit(It.IsAny>(), It.IsAny())) + .Returns(() => lastCircuit); + handleRegistryMock.Setup(m => m.GetCircuitHandle(It.IsAny>(), It.IsAny())) + .Returns(() => lastCircuit.Handle); + + var providerMock = new Mock(); + providerMock.Setup(m => m.RestoreCircuitAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new PersistedCircuitState + { + RootComponents = [.. """{}"""u8], + ApplicationState = ReadOnlyDictionary.Empty + }); + + var (mockClientProxy, hub) = InitializeComponentHub(deserializer, handleRegistryMock.Object, providerMock.Object); + var circuitSecret = await hub.StartCircuit("https://localhost:5000", "https://localhost:5000/subdir", "[]", null); + lastCircuit = null; + var result = await hub.ResumeCircuit(circuitSecret, "https://localhost:5000", "https://localhost:5000/subdir", "[]", ""); + await hub.UpdateRootComponents("""{ batchId: 1, operations: [] }""", ""); + Assert.False(lastCircuit.HasPendingPersistedCircuitState); + } + + [Fact] + public async Task CannotCallResumeCircuitWithInvalidId() + { + var (mockClientProxy, hub) = InitializeComponentHub(); + var invalidCircuitId = "invalid-circuit-id"; + var result = await hub.ResumeCircuit(invalidCircuitId, null, null, null, null); + Assert.Null(result); + } + + [Fact] + public async Task CannotResumeConnectedCircuit() + { + var (mockClientProxy, hub) = InitializeComponentHub(); + var circuitSecret = await hub.StartCircuit("https://localhost:5000", "https://localhost:5000/subdir", "{}", null); + Assert.NotNull(circuitSecret); + var result = await hub.ResumeCircuit(circuitSecret, null, null, null, null); + Assert.Null(result); + var errorMessage = "The circuit host '.*?' has already been initialized."; + mockClientProxy.Verify(m => m.SendCoreAsync("JS.Error", It.Is(s => Regex.Match((string)s[0], errorMessage).Success), It.IsAny()), Times.Once()); + } + + [Fact] + public async Task CannotResumeInvalidUrls() { + var handleRegistryMock = new Mock(); + var (mockClientProxy, hub) = InitializeComponentHub(null, handleRegistryMock.Object); + var circuitSecret = await hub.StartCircuit("https://localhost:5000", "https://localhost:5000/subdir", "{}", null); + var result = await hub.ResumeCircuit(circuitSecret, null, null, null, null); + Assert.Null(result); + var errorMessage = "The uris provided are invalid."; + mockClientProxy.Verify(m => m.SendCoreAsync("JS.Error", new[] { errorMessage }, It.IsAny()), Times.Once()); + } + + [Theory] + [InlineData(null)] + [InlineData("")] + public async Task CannotResumeWithRootComponentsButWithoutAppState(string appState) + { + var handleRegistryMock = new Mock(); + var (mockClientProxy, hub) = InitializeComponentHub(null, handleRegistryMock.Object); + var circuitSecret = await hub.StartCircuit("https://localhost:5000", "https://localhost:5000/subdir", "{}", null); + var result = await hub.ResumeCircuit(circuitSecret, "https://localhost:5000", "https://localhost:5000/subdir", "unused", appState); + Assert.Null(result); + var errorMessage = "The application state provided is invalid."; + mockClientProxy.Verify(m => m.SendCoreAsync("JS.Error", new[] { errorMessage }, It.IsAny()), Times.Once()); + } + + [Theory] + [InlineData(null)] + [InlineData("")] + [InlineData("[]")] + public async Task CannotResumeWithAppStateButWithoutRootComponents(string rootComponents) + { + var handleRegistryMock = new Mock(); + var (mockClientProxy, hub) = InitializeComponentHub(null, handleRegistryMock.Object); + var circuitSecret = await hub.StartCircuit("https://localhost:5000", "https://localhost:5000/subdir", "{}", null); + var result = await hub.ResumeCircuit(circuitSecret, "https://localhost:5000", "https://localhost:5000/subdir", rootComponents, "app-state"); + Assert.Null(result); + var errorMessage = "The root components provided are invalid."; + mockClientProxy.Verify(m => m.SendCoreAsync("JS.Error", new[] { errorMessage }, It.IsAny()), Times.Once()); + } + + [Fact] + public async Task CannotResumeAppWhenPersistedComponentStateIsNotAvailable() + { + var handleRegistryMock = new Mock(); + var (mockClientProxy, hub) = InitializeComponentHub(null, handleRegistryMock.Object); + var circuitSecret = await hub.StartCircuit("https://localhost:5000", "https://localhost:5000/subdir", "{}", null); + var result = await hub.ResumeCircuit(circuitSecret, "https://localhost:5000", "https://localhost:5000/subdir", "[]", ""); + Assert.Null(result); + var errorMessage = "The circuit state could not be retrieved. It may have been deleted or expired."; + mockClientProxy.Verify(m => m.SendCoreAsync("JS.Error", new[] { errorMessage }, It.IsAny()), Times.Once()); + } + + [Fact] + public async Task CanResumeAppWhenPersistedComponentStateIsAvailable() + { + var handleRegistryMock = new Mock(); + CircuitHost lastCircuit = null; + handleRegistryMock.Setup(m => m.SetCircuit(It.IsAny>(), It.IsAny(), It.IsAny())) + .Callback, object, CircuitHost>((circuitHandles, circuitKey, circuitHost) => + { + lastCircuit = circuitHost; + }); + var providerMock = new Mock(); + providerMock.Setup(m => m.RestoreCircuitAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new PersistedCircuitState + { + RootComponents = [], + ApplicationState = ReadOnlyDictionary.Empty, + }); + + var (mockClientProxy, hub) = InitializeComponentHub(null, handleRegistryMock.Object, providerMock.Object); + var circuitSecret = await hub.StartCircuit("https://localhost:5000", "https://localhost:5000/subdir", "{}", null); + var result = await hub.ResumeCircuit(circuitSecret, "https://localhost:5000", "https://localhost:5000/subdir", "[]", ""); + Assert.NotNull(result); + Assert.NotEqual(circuitSecret, result); + Assert.True(lastCircuit.HasPendingPersistedCircuitState); + } + + private static (Mock, ComponentHub) InitializeComponentHub( + TestServerComponentDeserializer deserializer = null, + ICircuitHandleRegistry handleRegistry = null, + ICircuitPersistenceProvider provider = null) + { + deserializer ??= new TestServerComponentDeserializer(); var ephemeralDataProtectionProvider = new EphemeralDataProtectionProvider(); - var circuitIdFactory = new CircuitIdFactory(ephemeralDataProtectionProvider); + var circuitPersistenceManager = new CircuitPersistenceManager( + Options.Create(new CircuitOptions()), + new Endpoints.ServerComponentSerializer(ephemeralDataProtectionProvider), + provider ?? Mock.Of()); + + var circuitIdFactory = TestCircuitIdFactory.Instance; var circuitFactory = new TestCircuitFactory( new Mock().Object, NullLoggerFactory.Instance, @@ -112,15 +297,15 @@ private static (Mock, ComponentHub) InitializeComponentHub() var circuitRegistry = new CircuitRegistry( Options.Create(new CircuitOptions()), NullLogger.Instance, - circuitIdFactory); - var serializer = new TestServerComponentDeserializer(); - var circuitHandleRegistry = new TestCircuitHandleRegistry(); + circuitIdFactory, circuitPersistenceManager); + var circuitHandleRegistry = handleRegistry ?? new TestCircuitHandleRegistry(); var hub = new ComponentHub( - serializer: serializer, + serializer: deserializer, dataProtectionProvider: ephemeralDataProtectionProvider, circuitFactory: circuitFactory, circuitIdFactory: circuitIdFactory, circuitRegistry: circuitRegistry, + circuitPersistenceProvider: circuitPersistenceManager, circuitHandleRegistry: circuitHandleRegistry, logger: NullLogger.Instance); @@ -131,6 +316,8 @@ private static (Mock, ComponentHub) InitializeComponentHub() mockCaller.Setup(x => x.Caller).Returns(mockClientProxy.Object); hub.Clients = mockCaller.Object; var mockContext = new Mock(); + var items = new Dictionary(); + mockContext.Setup(x => x.Items).Returns(items); var feature = new FeatureCollection(); var httpContextFeature = new Mock(); httpContextFeature.Setup(x => x.HttpContext).Returns(() => new DefaultHttpContext()); @@ -145,20 +332,19 @@ private static (Mock, ComponentHub) InitializeComponentHub() private class TestCircuitHandleRegistry : ICircuitHandleRegistry { private bool circuitSet = false; + private CircuitHost _circuitHost; + private CircuitHandle _circuitHandle; public CircuitHandle GetCircuitHandle(IDictionary circuitHandles, object circuitKey) { - return null; + return _circuitHandle; } public CircuitHost GetCircuit(IDictionary circuitHandles, object circuitKey) { if (circuitSet) { - var serviceScope = new Mock(); - var circuitHost = TestCircuitHost.Create( - serviceScope: new AsyncServiceScope(serviceScope.Object)); - return circuitHost; + return _circuitHost; } return null; } @@ -166,21 +352,42 @@ public CircuitHost GetCircuit(IDictionary circuitHandles, object public void SetCircuit(IDictionary circuitHandles, object circuitKey, CircuitHost circuitHost) { circuitSet = true; + _circuitHost = circuitHost; + _circuitHandle = new CircuitHandle { CircuitHost = circuitHost }; + return; } } private class TestServerComponentDeserializer : IServerComponentDeserializer { + public delegate bool TestTryDeserializeRootComponentOperations(string serializedComponentOperations, out RootComponentOperationBatch operationsWithDescriptors, bool deserializeDescriptors = true); + public delegate bool TestTryDeserializeWebRootComponentDescriptor(ComponentMarker record, [NotNullWhen(true)] out WebRootComponentDescriptor result); + + public TestTryDeserializeRootComponentOperations OnTryDeserializeTestComponentOperations { get; set; } + public bool TryDeserializeComponentDescriptorCollection(string serializedComponentRecords, out List descriptors) { descriptors = default; return true; } - public bool TryDeserializeRootComponentOperations(string serializedComponentOperations, out RootComponentOperationBatch operationsWithDescriptors) + public bool TryDeserializeRootComponentOperations(string serializedComponentOperations, out RootComponentOperationBatch operationsWithDescriptors, bool deserializeDescriptors = true) { - operationsWithDescriptors = default; + if (OnTryDeserializeTestComponentOperations != null) + { + return OnTryDeserializeTestComponentOperations(serializedComponentOperations, out operationsWithDescriptors, deserializeDescriptors); + } + else + { + operationsWithDescriptors = default; + return true; + } + } + + public bool TryDeserializeWebRootComponentDescriptor(ComponentMarker record, [NotNullWhen(true)] out WebRootComponentDescriptor result) + { + result = default; return true; } } @@ -205,8 +412,13 @@ public ValueTask CreateCircuitHostAsync( IPersistentComponentStateStore store, ResourceAssetCollection resourceCollection) { + var clientProxy = new CircuitClientProxy(Mock.Of(), "123"); + var serviceScope = new Mock(); - var circuitHost = TestCircuitHost.Create(serviceScope: new AsyncServiceScope(serviceScope.Object)); + var circuitHost = TestCircuitHost.Create( + circuitId: TestCircuitIdFactory.Instance.CreateCircuitId(), + serviceScope: new AsyncServiceScope(serviceScope.Object), + clientProxy: clientProxy); return ValueTask.FromResult(circuitHost); } } diff --git a/src/Components/Server/test/Circuits/DefaultInMemoryCircuitPersistenceProviderTest.cs b/src/Components/Server/test/Circuits/DefaultInMemoryCircuitPersistenceProviderTest.cs new file mode 100644 index 000000000000..26abca17d1fe --- /dev/null +++ b/src/Components/Server/test/Circuits/DefaultInMemoryCircuitPersistenceProviderTest.cs @@ -0,0 +1,184 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.AspNetCore.Components.Server.Circuits; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Internal; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Options; + +namespace Microsoft.AspNetCore.Components.Server.Tests.Circuits; + +public class DefaultInMemoryCircuitPersistenceProviderTest +{ + [Fact] + public async Task PersistCircuitAsync_StoresCircuitState() + { + // Arrange + var clock = new TestSystemClock(); + var circuitId = TestCircuitIdFactory.CreateTestFactory().CreateCircuitId(); + var persistedState = new PersistedCircuitState(); + var provider = CreateProvider(clock); + + // Act + await provider.PersistCircuitAsync(circuitId, persistedState); + + // Assert + var result = await provider.RestoreCircuitAsync(circuitId); + Assert.Same(persistedState, result); + } + + [Fact] + public async Task RestoreCircuitAsync_ReturnsNull_WhenCircuitDoesNotExist() + { + // Arrange + var clock = new TestSystemClock(); + var circuitId = TestCircuitIdFactory.CreateTestFactory().CreateCircuitId(); + var provider = CreateProvider(clock); + + // Act + var result = await provider.RestoreCircuitAsync(circuitId); + + // Assert + Assert.Null(result); + } + + [Fact] + public async Task RestoreCircuitAsync_RemovesCircuitFromCache() + { + // Arrange + var clock = new TestSystemClock(); + var circuitId = TestCircuitIdFactory.CreateTestFactory().CreateCircuitId(); + var persistedState = new PersistedCircuitState(); + var provider = CreateProvider(clock); + + await provider.PersistCircuitAsync(circuitId, persistedState); + + // Act + var firstResult = await provider.RestoreCircuitAsync(circuitId); + var secondResult = await provider.RestoreCircuitAsync(circuitId); + + // Assert + Assert.Same(persistedState, firstResult); + Assert.Null(secondResult); // Second attempt should return null as the entry should be removed + } + + [Fact] + public async Task CircuitStateIsEvictedAfterConfiguredTimeout() + { + // Arrange + var clock = new TestSystemClock(); + var circuitOptions = new CircuitOptions + { + PersistedCircuitInMemoryRetentionPeriod = TimeSpan.FromSeconds(2) + }; + var circuitId = TestCircuitIdFactory.CreateTestFactory().CreateCircuitId(); + var persistedState = new PersistedCircuitState(); + var provider = CreateProvider(clock, circuitOptions); + var postEvictionCallback = provider.PostEvictionCallback.EvictionCallback; + var callbackRan = new TaskCompletionSource(); + provider.PostEvictionCallback = new PostEvictionCallbackRegistration + { + EvictionCallback = (key, value, reason, state) => + { + callbackRan.SetResult(); + postEvictionCallback(key, value, reason, state); + } + }; + + await provider.PersistCircuitAsync(circuitId, persistedState); + + // Act - advance the clock past the retention period + clock.UtcNow = clock.UtcNow.AddSeconds(2); + + // Allow time for the timer to fire and the eviction to occur + await callbackRan.Task; + + // Assert + var result = await provider.RestoreCircuitAsync(circuitId); + Assert.Null(result); + } + + [Fact] + public async Task CircuitStatesAreLimitedByConfiguredCapacity() + { + // Arrange + var clock = new TestSystemClock(); + var circuitOptions = new CircuitOptions + { + PersistedCircuitInMemoryMaxRetained = 2 // Only allow 2 circuits to be stored + }; + var provider = CreateProvider(clock, circuitOptions); + var factory = TestCircuitIdFactory.CreateTestFactory(); + + var evictedKeys = new List(); + var evictionTcs = new TaskCompletionSource(); + var postEvictionCallback = provider.PostEvictionCallback.EvictionCallback; + provider.PostEvictionCallback = new PostEvictionCallbackRegistration + { + EvictionCallback = (key, value, reason, state) => + { + evictedKeys.Add((string)key); + evictionTcs.TrySetResult(); + postEvictionCallback(key, value, reason, state); + } + }; + + var circuitId1 = factory.CreateCircuitId(); + var circuitId2 = factory.CreateCircuitId(); + var circuitId3 = factory.CreateCircuitId(); + var circuitIds = new Dictionary + { + [circuitId1.Secret] = circuitId1, + [circuitId2.Secret] = circuitId2, + [circuitId3.Secret] = circuitId3 + }; + + var state1 = new PersistedCircuitState(); + var state2 = new PersistedCircuitState(); + var state3 = new PersistedCircuitState(); + + // Act - persist 3 circuits when capacity is 2 + await provider.PersistCircuitAsync(circuitId1, state1); + await provider.PersistCircuitAsync(circuitId2, state2); + await provider.PersistCircuitAsync(circuitId3, state3); + + // Wait for eviction to occur + await evictionTcs.Task; + + // Assert + var evicted = Assert.Single(evictedKeys); + var evictedId = circuitIds[evicted]; + + circuitIds.Remove(evicted); + + var evictedResults = await provider.RestoreCircuitAsync(evictedId); + Assert.Null(evictedResults); + + var nonEvictedResults = await Task.WhenAll(circuitIds.Select(ne => provider.RestoreCircuitAsync(ne.Value))); + + Assert.Collection(nonEvictedResults, + Assert.NotNull, + Assert.NotNull); + } + + private static DefaultInMemoryCircuitPersistenceProvider CreateProvider( + ISystemClock clock, + CircuitOptions options = null) + { + return new DefaultInMemoryCircuitPersistenceProvider( + clock, + NullLogger.Instance, + Options.Create(options ?? new CircuitOptions())); + } + + private class TestSystemClock : ISystemClock + { + public TestSystemClock() + { + UtcNow = new DateTimeOffset(2020, 1, 1, 0, 0, 0, TimeSpan.Zero); + } + + public DateTimeOffset UtcNow { get; set; } + } +} diff --git a/src/Components/Server/test/Circuits/TestCircuitIdFactory.cs b/src/Components/Server/test/Circuits/TestCircuitIdFactory.cs index ffe7bc20dc62..32b286b522a0 100644 --- a/src/Components/Server/test/Circuits/TestCircuitIdFactory.cs +++ b/src/Components/Server/test/Circuits/TestCircuitIdFactory.cs @@ -7,6 +7,8 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits; internal class TestCircuitIdFactory { + public static CircuitIdFactory Instance { get; } = CreateTestFactory(); + public static CircuitIdFactory CreateTestFactory() { return new CircuitIdFactory(new EphemeralDataProtectionProvider()); diff --git a/src/Components/Server/test/Microsoft.AspNetCore.Components.Server.Tests.csproj b/src/Components/Server/test/Microsoft.AspNetCore.Components.Server.Tests.csproj index db46700aed1b..341fa72218a0 100644 --- a/src/Components/Server/test/Microsoft.AspNetCore.Components.Server.Tests.csproj +++ b/src/Components/Server/test/Microsoft.AspNetCore.Components.Server.Tests.csproj @@ -22,8 +22,6 @@ - - diff --git a/src/Components/Shared/src/RootComponentOperation.cs b/src/Components/Shared/src/RootComponentOperation.cs index 0248481edae3..7bdfa2b1fa2b 100644 --- a/src/Components/Shared/src/RootComponentOperation.cs +++ b/src/Components/Shared/src/RootComponentOperation.cs @@ -1,6 +1,8 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics; +using System.Linq; using System.Text.Json.Serialization; namespace Microsoft.AspNetCore.Components; @@ -22,6 +24,7 @@ internal sealed class RootComponentOperation public WebRootComponentDescriptor? Descriptor { get; set; } } +[DebuggerDisplay($"{{{nameof(GetDebuggerDisplay)}(),nq}}")] internal sealed class WebRootComponentDescriptor( Type componentType, WebRootComponentParameters parameters) @@ -29,4 +32,10 @@ internal sealed class WebRootComponentDescriptor( public Type ComponentType { get; } = componentType; public WebRootComponentParameters Parameters { get; } = parameters; + + private string GetDebuggerDisplay() + { + var parameters = string.Join(", ", Parameters.Parameters.ToDictionary().Select(p => $"{p.Key}: {p.Value}")); + return $"{ComponentType.FullName}({parameters})"; + } } diff --git a/src/Components/Shared/src/RootComponentOperationBatch.cs b/src/Components/Shared/src/RootComponentOperationBatch.cs index 31d969364338..b5410c17c525 100644 --- a/src/Components/Shared/src/RootComponentOperationBatch.cs +++ b/src/Components/Shared/src/RootComponentOperationBatch.cs @@ -1,11 +1,19 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics; + namespace Microsoft.AspNetCore.Components; +[DebuggerDisplay($"{{{nameof(GetDebuggerDisplay)}(),nq}}")] internal sealed class RootComponentOperationBatch { public long BatchId { get; set; } public required RootComponentOperation[] Operations { get; set; } + + private string GetDebuggerDisplay() + { + return $"{nameof(RootComponentOperationBatch)}: {BatchId}, Operations Count: {Operations.Length}"; + } } diff --git a/src/Components/Shared/src/WebRootComponentManager.cs b/src/Components/Shared/src/WebRootComponentManager.cs index 8720cdd11105..99057d534148 100644 --- a/src/Components/Shared/src/WebRootComponentManager.cs +++ b/src/Components/Shared/src/WebRootComponentManager.cs @@ -83,6 +83,16 @@ private WebRootComponent GetRequiredWebRootComponent(int ssrComponentId) return component; } +#if COMPONENTS_SERVER + internal IEnumerable<(int id, ComponentMarkerKey key, (Type componentType, ParameterView parameters))> GetRootComponents() + { + foreach (var (id, (key, type, parameters)) in _webRootComponents) + { + yield return (id, key, (type, parameters)); + } + } +#endif + private sealed class WebRootComponent { [DynamicallyAccessedMembers(Component)] @@ -125,6 +135,18 @@ private WebRootComponent( _latestParameters = initialParameters; } +#if COMPONENTS_SERVER + public void Deconstruct( + out ComponentMarkerKey key, + out Type componentType, + out ParameterView parameters) + { + key = _key; + componentType = _componentType; + parameters = _latestParameters.Parameters; + } +#endif + public Task UpdateAsync( Renderer renderer, [DynamicallyAccessedMembers(Component)] Type newComponentType, diff --git a/src/Components/Web.JS/src/Boot.Server.Common.ts b/src/Components/Web.JS/src/Boot.Server.Common.ts index f8b60ccb055d..a882808146a7 100644 --- a/src/Components/Web.JS/src/Boot.Server.Common.ts +++ b/src/Components/Web.JS/src/Boot.Server.Common.ts @@ -67,6 +67,20 @@ async function startServerCore(components: RootComponentManager { + if (circuit.didRenderingFail()) { + // We can't resume after a failure, so exit early. + return false; + } + + if (!(await circuit.resume())) { + logger.log(LogLevel.Information, 'Resume attempt to the circuit was rejected by the server. This may indicate that the associated state is no longer available on the server.'); + return false; + } + + return true; + }; + Blazor.defaultReconnectionHandler = new DefaultReconnectionHandler(logger); options.reconnectionHandler = options.reconnectionHandler || Blazor.defaultReconnectionHandler; diff --git a/src/Components/Web.JS/src/GlobalExports.ts b/src/Components/Web.JS/src/GlobalExports.ts index 261ed2c2ce5d..3df06a78bf3a 100644 --- a/src/Components/Web.JS/src/GlobalExports.ts +++ b/src/Components/Web.JS/src/GlobalExports.ts @@ -38,6 +38,7 @@ export interface IBlazor { removeEventListener?: typeof JSEventRegistry.prototype.removeEventListener; disconnect?: () => void; reconnect?: (existingConnection?: HubConnection) => Promise; + resume?: (existingConnection?: HubConnection) => Promise; defaultReconnectionHandler?: DefaultReconnectionHandler; start?: ((userOptions?: Partial) => Promise) | ((options?: Partial) => Promise) | ((options?: Partial) => Promise); platform?: Platform; diff --git a/src/Components/Web.JS/src/Platform/Circuits/CircuitManager.ts b/src/Components/Web.JS/src/Platform/Circuits/CircuitManager.ts index 7788f9d4358f..c6a1c5d42ead 100644 --- a/src/Components/Web.JS/src/Platform/Circuits/CircuitManager.ts +++ b/src/Components/Web.JS/src/Platform/Circuits/CircuitManager.ts @@ -20,6 +20,7 @@ import { attachWebRendererInterop, detachWebRendererInterop } from '../../Render import { sendJSDataStream } from './CircuitStreamingInterop'; export class CircuitManager implements DotNet.DotNetCallDispatcher { + private readonly _componentManager: RootComponentManager; private _applicationState: string; @@ -28,7 +29,7 @@ export class CircuitManager implements DotNet.DotNetCallDispatcher { private readonly _logger: ConsoleLogger; - private readonly _renderQueue: RenderQueue; + private _renderQueue: RenderQueue; private readonly _dispatcher: DotNet.ICallDispatcher; @@ -106,7 +107,7 @@ export class CircuitManager implements DotNet.DotNetCallDispatcher { } for (const handler of this._options.circuitHandlers) { - if (handler.onCircuitOpened){ + if (handler.onCircuitOpened) { handler.onCircuitOpened(); } } @@ -232,6 +233,45 @@ export class CircuitManager implements DotNet.DotNetCallDispatcher { return true; } + public async resume(): Promise { + if (!this._circuitId) { + throw new Error('Method not implemented.'); + } + + // When we get here we know the circuit is gone for good. + // Signal that we are about to start a new circuit so that + // any existing handlers can perform the necessary cleanup. + for (const handler of this._options.circuitHandlers) { + if (handler.onCircuitClosed) { + handler.onCircuitClosed(); + } + } + + const newCircuitId = await this._connection!.invoke( + 'ResumeCircuit', + this._circuitId, + navigationManagerFunctions.getBaseURI(), + navigationManagerFunctions.getLocationHref(), + '[]', + '' + ); + if (!newCircuitId) { + return false; + } + + this._circuitId = newCircuitId; + this._renderQueue = new RenderQueue(this._logger); + for (const handler of this._options.circuitHandlers) { + if (handler.onCircuitOpened) { + handler.onCircuitOpened(); + } + } + + this._options.reconnectionHandler!.onConnectionUp(); + this._componentManager.onComponentReload?.(); + return true; + } + // Implements DotNet.DotNetCallDispatcher public beginInvokeDotNetFromJS(callId: number, assemblyName: string | null, methodIdentifier: string, dotNetObjectId: number | null, argsJson: string): void { this.throwIfDispatchingWhenDisposed(); diff --git a/src/Components/Web.JS/src/Platform/Circuits/DefaultReconnectionHandler.ts b/src/Components/Web.JS/src/Platform/Circuits/DefaultReconnectionHandler.ts index 58750dce46e6..255677ad92a8 100644 --- a/src/Components/Web.JS/src/Platform/Circuits/DefaultReconnectionHandler.ts +++ b/src/Components/Web.JS/src/Platform/Circuits/DefaultReconnectionHandler.ts @@ -13,14 +13,17 @@ export class DefaultReconnectionHandler implements ReconnectionHandler { private readonly _reconnectCallback: () => Promise; + private readonly _resumeCallback: () => Promise; + private _currentReconnectionProcess: ReconnectionProcess | null = null; private _reconnectionDisplay?: ReconnectDisplay; - constructor(logger: Logger, overrideDisplay?: ReconnectDisplay, reconnectCallback?: () => Promise) { + constructor(logger: Logger, overrideDisplay?: ReconnectDisplay, reconnectCallback?: () => Promise, resumeCallback?: () => Promise) { this._logger = logger; this._reconnectionDisplay = overrideDisplay; this._reconnectCallback = reconnectCallback || Blazor.reconnect!; + this._resumeCallback = resumeCallback || Blazor.resume!; } onConnectionDown(options: ReconnectionOptions, _error?: Error): void { @@ -32,7 +35,13 @@ export class DefaultReconnectionHandler implements ReconnectionHandler { } if (!this._currentReconnectionProcess) { - this._currentReconnectionProcess = new ReconnectionProcess(options, this._logger, this._reconnectCallback, this._reconnectionDisplay); + this._currentReconnectionProcess = new ReconnectionProcess( + options, + this._logger, + this._reconnectCallback, + this._resumeCallback, + this._reconnectionDisplay + ); } } @@ -51,7 +60,7 @@ class ReconnectionProcess { isDisposed = false; - constructor(options: ReconnectionOptions, private logger: Logger, private reconnectCallback: () => Promise, display: ReconnectDisplay) { + constructor(options: ReconnectionOptions, private logger: Logger, private reconnectCallback: () => Promise, private resumeCallback: () => Promise, display: ReconnectDisplay) { this.reconnectDisplay = display; this.reconnectDisplay.show(); this.attemptPeriodicReconnection(options); @@ -65,7 +74,7 @@ class ReconnectionProcess { async attemptPeriodicReconnection(options: ReconnectionOptions) { for (let i = 0; options.maxRetries === undefined || i < options.maxRetries; i++) { let retryInterval: number; - if (typeof(options.retryIntervalMilliseconds) === 'function') { + if (typeof (options.retryIntervalMilliseconds) === 'function') { const computedRetryInterval = options.retryIntervalMilliseconds(i); if (computedRetryInterval === null || computedRetryInterval === undefined) { break; @@ -92,6 +101,12 @@ class ReconnectionProcess { // - exception to mean we didn't reach the server (this can be sync or async) const result = await this.reconnectCallback(); if (!result) { + // Try to resume the circuit if the reconnect failed + const resumeResult = await this.resumeCallback(); + if (resumeResult) { + return; + } + // If the server responded and refused to reconnect, stop auto-retrying. this.reconnectDisplay.rejected(); return; diff --git a/src/Components/Web.JS/src/Rendering/BrowserRenderer.ts b/src/Components/Web.JS/src/Rendering/BrowserRenderer.ts index 7496dd552286..9b1a42a3fa97 100644 --- a/src/Components/Web.JS/src/Rendering/BrowserRenderer.ts +++ b/src/Components/Web.JS/src/Rendering/BrowserRenderer.ts @@ -3,7 +3,7 @@ import { RenderBatch, ArrayBuilderSegment, RenderTreeEdit, RenderTreeFrame, EditType, FrameType, ArrayValues } from './RenderBatch/RenderBatch'; import { EventDelegator } from './Events/EventDelegator'; -import { LogicalElement, PermutationListEntry, toLogicalElement, insertLogicalChild, removeLogicalChild, getLogicalParent, getLogicalChild, createAndInsertLogicalContainer, isSvgElement, permuteLogicalChildren, getClosestDomElement, emptyLogicalElement, getLogicalChildrenArray } from './LogicalElements'; +import { LogicalElement, PermutationListEntry, toLogicalElement, insertLogicalChild, removeLogicalChild, getLogicalParent, getLogicalChild, createAndInsertLogicalContainer, isSvgElement, permuteLogicalChildren, getClosestDomElement, emptyLogicalElement, getLogicalChildrenArray, depthFirstNodeTreeTraversal } from './LogicalElements'; import { applyCaptureIdToElement } from './ElementReferenceCapture'; import { attachToEventDelegator as attachNavigationManagerToEventDelegator } from '../Services/NavigationManager'; import { applyAnyDeferredValue, tryApplySpecialProperty } from './DomSpecialPropertyUtil'; @@ -63,6 +63,7 @@ export class BrowserRenderer { // On the first render for each root component, clear any existing content (e.g., prerendered) if (elementsToClearOnRootComponentRender.delete(element)) { + this.detachEventHandlersFromElement(element); emptyLogicalElement(element); if (element instanceof Comment) { @@ -109,6 +110,14 @@ export class BrowserRenderer { this.childComponentLocations[componentId] = element; } + private detachEventHandlersFromElement(element: LogicalElement): void { + for (const childNode of depthFirstNodeTreeTraversal(element)) { + if (childNode instanceof Element) { + this.eventDelegator.removeListenersForElement(childNode as Element); + } + } + } + private applyEdits(batch: RenderBatch, componentId: number, parent: LogicalElement, childIndex: number, edits: ArrayBuilderSegment, referenceFrames: ArrayValues) { let currentDepth = 0; let childIndexAtCurrentDepth = childIndex; @@ -388,6 +397,10 @@ export function setShouldPreserveContentOnInteractiveComponentDisposal(element: element[preserveContentOnDisposalPropname] = shouldPreserve; } +export function setClearContentOnRootComponentRerender(element: LogicalElement): void { + elementsToClearOnRootComponentRender.add(element); +} + function shouldPreserveContentOnInteractiveComponentDisposal(element: LogicalElement): boolean { return element[preserveContentOnDisposalPropname] === true; } diff --git a/src/Components/Web.JS/src/Rendering/Events/EventDelegator.ts b/src/Components/Web.JS/src/Rendering/Events/EventDelegator.ts index c64b38ee952d..0edd38880775 100644 --- a/src/Components/Web.JS/src/Rendering/Events/EventDelegator.ts +++ b/src/Components/Web.JS/src/Rendering/Events/EventDelegator.ts @@ -113,6 +113,18 @@ export class EventDelegator { } } + public removeListenersForElement(element: Element): void { + // This method gets called whenever the .NET-side code reports that a certain element + // has been disposed. We remove all event handlers for that element. + const infosForElement = this.getEventHandlerInfosForElement(element, false); + if (infosForElement) { + for (const handlerInfo of infosForElement.enumerateHandlers()) { + this.eventInfoStore.remove(handlerInfo.eventHandlerId); + } + delete element[this.eventsCollectionKey]; + } + } + public notifyAfterClick(callback: (event: MouseEvent) => void): void { // This is extremely special-case. It's needed so that navigation link click interception // can be sure to run *after* our synthetic bubbling process. If a need arises, we can @@ -326,6 +338,14 @@ class EventHandlerInfosForElement { private stopPropagationFlags: { [eventName: string]: boolean } | null = null; + public *enumerateHandlers() : IterableIterator { + for (const eventName in this.handlers) { + if (Object.prototype.hasOwnProperty.call(this.handlers, eventName)) { + yield this.handlers[eventName]; + } + } + } + public getHandler(eventName: string): EventHandlerInfo | null { return Object.prototype.hasOwnProperty.call(this.handlers, eventName) ? this.handlers[eventName] : null; } diff --git a/src/Components/Web.JS/src/Rendering/LogicalElements.ts b/src/Components/Web.JS/src/Rendering/LogicalElements.ts index a6e904febf04..d963499c5a75 100644 --- a/src/Components/Web.JS/src/Rendering/LogicalElements.ts +++ b/src/Components/Web.JS/src/Rendering/LogicalElements.ts @@ -250,6 +250,16 @@ export function isLogicalElement(element: Node): boolean { return logicalChildrenPropname in element; } +// This function returns all the descendants of the logical element before yielding the element +// itself. +export function *depthFirstNodeTreeTraversal(element: LogicalElement): Iterable { + const children = getLogicalChildrenArray(element); + for (const child of children) { + yield* depthFirstNodeTreeTraversal(child); + } + yield element; +} + export function permuteLogicalChildren(parent: LogicalElement, permutationList: PermutationListEntry[]): void { // The permutationList must represent a valid permutation, i.e., the list of 'from' indices // is distinct, and the list of 'to' indices is a permutation of it. The algorithm here diff --git a/src/Components/Web.JS/src/Services/RootComponentManager.ts b/src/Components/Web.JS/src/Services/RootComponentManager.ts index 17aa2576ac9c..7aef7399a59c 100644 --- a/src/Components/Web.JS/src/Services/RootComponentManager.ts +++ b/src/Components/Web.JS/src/Services/RootComponentManager.ts @@ -7,5 +7,6 @@ export interface RootComponentManager { initialComponents: InitialComponentsDescriptorType[]; onAfterRenderBatch?(browserRendererId: number): void; onAfterUpdateRootComponents?(batchId: number): void; + onComponentReload?(): void; resolveRootComponent(ssrComponentId: number): ComponentDescriptor; } diff --git a/src/Components/Web.JS/src/Services/WebRootComponentManager.ts b/src/Components/Web.JS/src/Services/WebRootComponentManager.ts index 1c1dbb067db0..780099ff8ba4 100644 --- a/src/Components/Web.JS/src/Services/WebRootComponentManager.ts +++ b/src/Components/Web.JS/src/Services/WebRootComponentManager.ts @@ -11,7 +11,7 @@ import { MonoConfig } from '@microsoft/dotnet-runtime'; import { RootComponentManager } from './RootComponentManager'; import { getRendererer } from '../Rendering/Renderer'; import { isPageLoading } from './NavigationEnhancement'; -import { setShouldPreserveContentOnInteractiveComponentDisposal } from '../Rendering/BrowserRenderer'; +import { setClearContentOnRootComponentRerender, setShouldPreserveContentOnInteractiveComponentDisposal } from '../Rendering/BrowserRenderer'; import { LogicalElement } from '../Rendering/LogicalElements'; type RootComponentOperationBatch = { @@ -38,7 +38,7 @@ type RootComponentRemoveOperation = { ssrComponentId: number; }; -type RootComponentInfo = { +export type RootComponentInfo = { descriptor: ComponentDescriptor; ssrComponentId: number; assignedRendererId?: WebRendererId; @@ -206,10 +206,10 @@ export class WebRootComponentManager implements DescriptorHandler, RootComponent // The following timeout allows us to liberally call this function without // taking the small performance hit from requent repeated calls to // refreshRootComponents. - setTimeout(() => { + queueMicrotask(() => { this._isComponentRefreshPending = false; this.refreshRootComponents(this._rootComponentsBySsrComponentId.values()); - }, 0); + }); } private circuitMayHaveNoRootComponents() { @@ -463,6 +463,15 @@ export class WebRootComponentManager implements DescriptorHandler, RootComponent } } } + + public onComponentReload(): void { + for (const [_, value] of this._rootComponentsBySsrComponentId.entries()) { + value.assignedRendererId = undefined; + setClearContentOnRootComponentRerender(value.descriptor.start as unknown as LogicalElement); + } + + this.rootComponentsMayRequireRefresh(); + } } function isDescriptorInDocument(descriptor: ComponentDescriptor): boolean { diff --git a/src/Components/test/E2ETest/ServerExecutionTests/ServerResumeTests.cs b/src/Components/test/E2ETest/ServerExecutionTests/ServerResumeTests.cs new file mode 100644 index 000000000000..11f6107d2ce4 --- /dev/null +++ b/src/Components/test/E2ETest/ServerExecutionTests/ServerResumeTests.cs @@ -0,0 +1,73 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Threading.Tasks; +using BasicTestApp.Reconnection; +using Components.TestServer.RazorComponents; +using Microsoft.AspNetCore.Components.E2ETest.Infrastructure; +using Microsoft.AspNetCore.Components.E2ETest.Infrastructure.ServerFixtures; +using Microsoft.AspNetCore.E2ETesting; +using OpenQA.Selenium; +using OpenQA.Selenium.BiDi; +using TestServer; +using Xunit.Abstractions; + +namespace Microsoft.AspNetCore.Components.E2ETest.ServerExecutionTests; + +public class ServerResumeTestsTest : ServerTestBase>> +{ + public ServerResumeTestsTest( + BrowserFixture browserFixture, + BasicTestAppServerSiteFixture> serverFixture, + ITestOutputHelper output) + : base(browserFixture, serverFixture, output) + { + serverFixture.AdditionalArguments.AddRange("--DisableReconnectionCache", "true"); + } + + protected override void InitializeAsyncCore() + { + Navigate("/subdir/persistent-state/disconnection"); + Browser.Exists(By.Id("render-mode-interactive")); + } + + [Fact] + public void CanResumeCircuitAfterDisconnection() + { + Browser.Exists(By.Id("increment-persistent-counter-count")).Click(); + + Browser.Equal("1", () => Browser.Exists(By.Id("persistent-counter-count")).Text); + var javascript = (IJavaScriptExecutor)Browser; + javascript.ExecuteScript("window.replaceReconnectCallback()"); + + TriggerReconnectAndInteract(javascript); + + // Can dispatch events after reconnect + Browser.Equal("2", () => Browser.Exists(By.Id("persistent-counter-count")).Text); + + javascript.ExecuteScript("resetReconnect()"); + + TriggerReconnectAndInteract(javascript); + + // Ensure that reconnection events are repeatable + Browser.Equal("3", () => Browser.Exists(By.Id("persistent-counter-count")).Text); + } + + private void TriggerReconnectAndInteract(IJavaScriptExecutor javascript) + { + var previousText = Browser.Exists(By.Id("persistent-counter-render")).Text; + + javascript.ExecuteScript("Blazor._internal.forceCloseConnection()"); + Browser.Equal("block", () => Browser.Exists(By.Id("components-reconnect-modal")).GetCssValue("display")); + + javascript.ExecuteScript("triggerReconnect()"); + + // Then it should disappear + Browser.Equal("none", () => Browser.Exists(By.Id("components-reconnect-modal")).GetCssValue("display")); + + var newText = Browser.Exists(By.Id("persistent-counter-render")).Text; + Assert.NotEqual(previousText, newText); + + Browser.Exists(By.Id("increment-persistent-counter-count")).Click(); + } +} diff --git a/src/Components/test/testassets/Components.TestServer/BlazorWebServerStartup.cs b/src/Components/test/testassets/Components.TestServer/BlazorWebServerStartup.cs deleted file mode 100644 index 5deffb0d94b1..000000000000 --- a/src/Components/test/testassets/Components.TestServer/BlazorWebServerStartup.cs +++ /dev/null @@ -1,51 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Globalization; -using Components.TestServer.RazorComponents; -using Microsoft.AspNetCore.Components.Server.Circuits; -using Microsoft.AspNetCore.Components.Web; -using Microsoft.AspNetCore.DataProtection; - -namespace TestServer; - -public class BlazorWebServerStartup -{ - public BlazorWebServerStartup(IConfiguration configuration) - { - Configuration = configuration; - } - - public IConfiguration Configuration { get; } - - // This method gets called by the runtime. Use this method to add services to the container. - public void ConfigureServices(IServiceCollection services) - { - services.AddRazorComponents() - .AddInteractiveServerComponents(); - - // Since tests run in parallel, we use an ephemeral key provider to avoid filesystem - // contention issues. - services.AddSingleton(); - } - - // This method gets called by the runtime. Use this method to configure the HTTP request pipeline. - public virtual void Configure(IApplicationBuilder app, IWebHostEnvironment env, ResourceRequestLog resourceRequestLog) - { - if (env.IsDevelopment()) - { - app.UseDeveloperExceptionPage(); - } - - // Mount the server-side Blazor app on /subdir - app.Map("/subdir", app => - { - app.UseRouting(); - app.UseEndpoints(endpoints => - { - endpoints.MapRazorComponents() - .AddInteractiveServerRenderMode(); - }); - }); - } -} diff --git a/src/Components/test/testassets/Components.TestServer/Program.cs b/src/Components/test/testassets/Components.TestServer/Program.cs index ca2d7b1aeb7f..bbe5562ddf3a 100644 --- a/src/Components/test/testassets/Components.TestServer/Program.cs +++ b/src/Components/test/testassets/Components.TestServer/Program.cs @@ -32,6 +32,7 @@ public static async Task Main(string[] args) ["Globalization + Localization (Server-side)"] = (BuildWebHost(CreateAdditionalArgs(args)), "/subdir"), ["Server-side blazor"] = (BuildWebHost(CreateAdditionalArgs(args)), "/subdir"), ["Blazor web with server-side blazor root component"] = (BuildWebHost>(CreateAdditionalArgs(args)), "/subdir"), + ["Blazor web with server-side reconnection disabled"] = (BuildWebHost>(CreateAdditionalArgs([.. args, "--DisableReconnectionCache", "true"])), "/subdir"), ["Hosted client-side blazor"] = (BuildWebHost(CreateAdditionalArgs(args)), "/subdir"), ["Hot Reload"] = (BuildWebHost(CreateAdditionalArgs(args)), "/subdir"), ["Dev server client-side blazor"] = CreateDevServerHost(CreateAdditionalArgs(args)), @@ -75,7 +76,7 @@ private static (IHost host, string basePath) CreateDevServerHost(string[] args) } private static string[] CreateAdditionalArgs(string[] args) => - args.Concat(new[] { "--urls", $"http://127.0.0.1:{GetNextChildAppPortNumber()}" }).ToArray(); + [.. args, .. new[] { "--urls", $"http://127.0.0.1:{GetNextChildAppPortNumber()}" }]; public static IHost BuildWebHost(string[] args) => BuildWebHost(args); diff --git a/src/Components/test/testassets/Components.TestServer/RazorComponentEndpointsStartup.cs b/src/Components/test/testassets/Components.TestServer/RazorComponentEndpointsStartup.cs index ea4f7f7ad220..eedeacc7ef57 100644 --- a/src/Components/test/testassets/Components.TestServer/RazorComponentEndpointsStartup.cs +++ b/src/Components/test/testassets/Components.TestServer/RazorComponentEndpointsStartup.cs @@ -38,7 +38,14 @@ public void ConfigureServices(IServiceCollection services) .RegisterPersistentService(RenderMode.InteractiveAuto) .RegisterPersistentService(RenderMode.InteractiveWebAssembly) .AddInteractiveWebAssemblyComponents() - .AddInteractiveServerComponents() + .AddInteractiveServerComponents(options => + { + if (Configuration.GetValue("DisableReconnectionCache")) + { + // This disables the reconnection cache, which forces the server to persist the circuit state. + options.DisconnectedCircuitMaxRetained = 0; + } + }) .AddAuthenticationStateSerialization(options => { bool.TryParse(Configuration["SerializeAllClaims"], out var serializeAllClaims); diff --git a/src/Components/test/testassets/Components.TestServer/RazorComponents/Pages/PersistentState/PersistentStateDisconnection.razor b/src/Components/test/testassets/Components.TestServer/RazorComponents/Pages/PersistentState/PersistentStateDisconnection.razor new file mode 100644 index 000000000000..cef4f9a60fec --- /dev/null +++ b/src/Components/test/testassets/Components.TestServer/RazorComponents/Pages/PersistentState/PersistentStateDisconnection.razor @@ -0,0 +1,8 @@ +@page "/persistent-state/disconnection" +@rendermode RenderMode.InteractiveServer + +

+ Validates state persistence across disconnections +

+ + diff --git a/src/Components/test/testassets/Components.TestServer/RazorComponents/Root.razor b/src/Components/test/testassets/Components.TestServer/RazorComponents/Root.razor index a644869eb317..4e499a486453 100644 --- a/src/Components/test/testassets/Components.TestServer/RazorComponents/Root.razor +++ b/src/Components/test/testassets/Components.TestServer/RazorComponents/Root.razor @@ -14,8 +14,32 @@