Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/PowerShellEditorServices/Server/PsesLanguageServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ public async Task StartAsync()
.WithHandler<ShowHelpHandler>()
.WithHandler<ExpandAliasHandler>()
.WithHandler<PsesSemanticTokensHandler>()
.WithHandler<DidChangeWatchedFilesHandler>()
// NOTE: The OnInitialize delegate gets run when we first receive the
// _Initialize_ request:
// https://microsoft.github.io/language-server-protocol/specifications/specification-current/#initialize
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation.
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Management.Automation;
Expand Down Expand Up @@ -57,6 +58,39 @@ public record struct AliasMap(
internal static readonly ConcurrentDictionary<string, List<string>> s_cmdletToAliasCache = new(System.StringComparer.OrdinalIgnoreCase);
internal static readonly ConcurrentDictionary<string, string> s_aliasToCmdletCache = new(System.StringComparer.OrdinalIgnoreCase);

/// <summary>
/// Gets the actual command behind a fully module qualified command invocation, e.g.
/// <c>Microsoft.PowerShell.Management\Get-ChildItem</c> will return <c>Get-ChilddItem</c>
/// </summary>
/// <param name="invocationName">
/// The potentially module qualified command name at the site of invocation.
/// </param>
/// <param name="moduleName">
/// A reference that will contain the module name if the invocation is module qualified.
/// </param>
/// <returns>The actual command name.</returns>
public static string StripModuleQualification(string invocationName, out ReadOnlyMemory<char> moduleName)
{
int slashIndex = invocationName.IndexOf('\\');
if (slashIndex is -1)
{
moduleName = default;
return invocationName;
}

// If '\' is the last character then it's probably not a module qualified command.
if (slashIndex == invocationName.Length - 1)
{
moduleName = default;
return invocationName;
}

// Storing moduleName as ROMemory safes a string allocation in the common case where it
// is not needed.
moduleName = invocationName.AsMemory().Slice(0, slashIndex);
return invocationName.Substring(slashIndex + 1);
}

/// <summary>
/// Gets the CommandInfo instance for a command with a particular name.
/// </summary>
Expand Down
100 changes: 100 additions & 0 deletions src/PowerShellEditorServices/Services/Symbols/ReferenceTable.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#nullable enable

using System;
using System.Collections.Concurrent;
using System.Management.Automation.Language;
using Microsoft.PowerShell.EditorServices.Services.TextDocument;
using Microsoft.PowerShell.EditorServices.Services.PowerShell.Utility;

namespace Microsoft.PowerShell.EditorServices.Services;

/// <summary>
/// Represents the symbols that are referenced and their locations within a single document.
/// </summary>
internal sealed class ReferenceTable
{
private readonly ScriptFile _parent;

private readonly ConcurrentDictionary<string, ConcurrentBag<IScriptExtent>> _symbolReferences = new(StringComparer.OrdinalIgnoreCase);

private bool _isInited;

public ReferenceTable(ScriptFile parent) => _parent = parent;

/// <summary>
/// Clears the reference table causing it to rescan the source AST when queried.
/// </summary>
public void TagAsChanged()
{
_symbolReferences.Clear();
_isInited = false;
}

// Prefer checking if the dictionary has contents to determine if initialized. The field
// `_isInited` is to guard against rescanning files with no command references, but will
// generally be less reliable of a check.
private bool IsInitialized => !_symbolReferences.IsEmpty || _isInited;

internal bool TryGetReferences(string command, out ConcurrentBag<IScriptExtent>? references)
{
EnsureInitialized();
return _symbolReferences.TryGetValue(command, out references);
}

internal void EnsureInitialized()
{
if (IsInitialized)
{
return;
}

_parent.ScriptAst.Visit(new ReferenceVisitor(this));
}

private void AddReference(string symbol, IScriptExtent extent)
{
_symbolReferences.AddOrUpdate(
symbol,
_ => new ConcurrentBag<IScriptExtent> { extent },
(_, existing) =>
{
existing.Add(extent);
return existing;
});
}

private sealed class ReferenceVisitor : AstVisitor
{
private readonly ReferenceTable _references;

public ReferenceVisitor(ReferenceTable references) => _references = references;

public override AstVisitAction VisitCommand(CommandAst commandAst)
{
string commandName = commandAst.GetCommandName();
if (string.IsNullOrEmpty(commandName))
{
return AstVisitAction.Continue;
}

_references.AddReference(
CommandHelpers.StripModuleQualification(commandName, out _),
commandAst.CommandElements[0].Extent);
return AstVisitAction.Continue;
}

public override AstVisitAction VisitVariableExpression(VariableExpressionAst variableExpressionAst)
{
// TODO: Consider tracking unscoped variable references only when they declared within
// the same function definition.
_references.AddReference(
$"${variableExpressionAst.VariablePath.UserPath}",
variableExpressionAst.Extent);

return AstVisitAction.Continue;
}
}
}
172 changes: 137 additions & 35 deletions src/PowerShellEditorServices/Services/Symbols/SymbolsService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ internal class SymbolsService

private readonly ConcurrentDictionary<string, ICodeLensProvider> _codeLensProviders;
private readonly ConcurrentDictionary<string, IDocumentSymbolProvider> _documentSymbolProviders;
private readonly ConfigurationService _configurationService;
#endregion

#region Constructors
Expand All @@ -65,6 +66,7 @@ public SymbolsService(
_runspaceContext = runspaceContext;
_executionService = executionService;
_workspaceService = workspaceService;
_configurationService = configurationService;

_codeLensProviders = new ConcurrentDictionary<string, ICodeLensProvider>();
if (configurationService.CurrentSettings.EnableReferencesCodeLens)
Expand Down Expand Up @@ -177,8 +179,15 @@ public async Task<List<SymbolReference>> FindReferencesOfSymbol(
_executionService,
cancellationToken).ConfigureAwait(false);

Dictionary<string, List<string>> cmdletToAliases = aliases.CmdletToAliases;
Dictionary<string, string> aliasToCmdlets = aliases.AliasToCmdlets;
string targetName = foundSymbol.SymbolName;
if (foundSymbol.SymbolType is SymbolType.Function)
{
targetName = CommandHelpers.StripModuleQualification(targetName, out _);
if (aliases.AliasToCmdlets.TryGetValue(foundSymbol.SymbolName, out string aliasDefinition))
{
targetName = aliasDefinition;
}
}

// We want to look for references first in referenced files, hence we use ordered dictionary
// TODO: File system case-sensitivity is based on filesystem not OS, but OS is a much cheaper heuristic
Expand All @@ -191,52 +200,63 @@ public async Task<List<SymbolReference>> FindReferencesOfSymbol(
fileMap[scriptFile.FilePath] = scriptFile;
}

foreach (string filePath in workspace.EnumeratePSFiles())
await ScanWorkspacePSFiles(cancellationToken).ConfigureAwait(false);

List<SymbolReference> symbolReferences = new();

// Using a nested method here to get a bit more readability and to avoid roslynator
// asserting we should use a giant nested ternary here.
static string[] GetIdentifiers(string symbolName, SymbolType symbolType, CommandHelpers.AliasMap aliases)
{
if (!fileMap.Contains(filePath))
if (symbolType is not SymbolType.Function)
{
// This async method is pretty dense with synchronous code
// so it's helpful to add some yields.
await Task.Yield();
cancellationToken.ThrowIfCancellationRequested();
if (!workspace.TryGetFile(filePath, out ScriptFile scriptFile))
{
// If we can't access the file for some reason, just ignore it
continue;
}
return new[] { symbolName };
}

fileMap[filePath] = scriptFile;
if (!aliases.CmdletToAliases.TryGetValue(symbolName, out List<string> foundAliasList))
{
return new[] { symbolName };
}

return foundAliasList.Prepend(symbolName)
.Distinct(StringComparer.OrdinalIgnoreCase)
.ToArray();
}

List<SymbolReference> symbolReferences = new();
foreach (object fileName in fileMap.Keys)
{
ScriptFile file = (ScriptFile)fileMap[fileName];
string[] allIdentifiers = GetIdentifiers(targetName, foundSymbol.SymbolType, aliases);

IEnumerable<SymbolReference> references = AstOperations.FindReferencesOfSymbol(
file.ScriptAst,
foundSymbol,
cmdletToAliases,
aliasToCmdlets);

foreach (SymbolReference reference in references)
foreach (ScriptFile file in _workspaceService.GetOpenedFiles())
{
foreach (string targetIdentifier in allIdentifiers)
{
try
if (!file.References.TryGetReferences(targetIdentifier, out ConcurrentBag<IScriptExtent> references))
{
reference.SourceLine = file.GetLine(reference.ScriptRegion.StartLineNumber);
continue;
}
catch (ArgumentOutOfRangeException e)

foreach (IScriptExtent extent in references)
{
reference.SourceLine = string.Empty;
_logger.LogException("Found reference is out of range in script file", e);
SymbolReference reference = new(
SymbolType.Function,
foundSymbol.SymbolName,
extent);

try
{
reference.SourceLine = file.GetLine(reference.ScriptRegion.StartLineNumber);
}
catch (ArgumentOutOfRangeException e)
{
reference.SourceLine = string.Empty;
_logger.LogException("Found reference is out of range in script file", e);
}
reference.FilePath = file.FilePath;
symbolReferences.Add(reference);
}
reference.FilePath = file.FilePath;
symbolReferences.Add(reference);
}

await Task.Yield();
cancellationToken.ThrowIfCancellationRequested();
await Task.Yield();
cancellationToken.ThrowIfCancellationRequested();
}
}

return symbolReferences;
Expand Down Expand Up @@ -495,6 +515,59 @@ await CommandHelpers.GetCommandInfoAsync(
return foundDefinition;
}

private Task _workspaceScanCompleted;

private async Task ScanWorkspacePSFiles(CancellationToken cancellationToken = default)
{
if (_configurationService.CurrentSettings.AnalyzeOpenDocumentsOnly)
{
return;
}

Task scanTask = _workspaceScanCompleted;
// It's not impossible for two scans to start at once but it should be exceedingly
// unlikely, and shouldn't break anything if it happens to. So we can save some
// lock time by accepting that possibility.
if (scanTask is null)
{
scanTask = Task.Run(
() =>
{
foreach (string file in _workspaceService.EnumeratePSFiles())
{
if (_workspaceService.TryGetFile(file, out ScriptFile scriptFile))
{
scriptFile.References.EnsureInitialized();
}
}
},
CancellationToken.None);

// Ignore the analyzer yelling that we're not awaiting this task, we'll get there.
#pragma warning disable CS4014
Interlocked.CompareExchange(ref _workspaceScanCompleted, scanTask, null);
#pragma warning restore CS4014
}

// In the simple case where the task is already completed or the token we're given cannot
// be cancelled, do a simple await.
if (scanTask.IsCompleted || !cancellationToken.CanBeCanceled)
{
await scanTask.ConfigureAwait(false);
return;
}

// If it's not yet done and we can be cancelled, create a new task to represent the
// cancellation. That way we can exit a request that relies on the scan without
// having to actually stop the work (and then request it again in a few seconds).
//
// TODO: There's a new API in net6 that lets you await a task with a cancellation token.
// we should #if that in if feasible.
TaskCompletionSource<bool> cancelled = new();
cancellationToken.Register(() => cancelled.TrySetCanceled());
await Task.WhenAny(scanTask, cancelled.Task).ConfigureAwait(false);
}

/// <summary>
/// Gets a path from a dot-source symbol.
/// </summary>
Expand Down Expand Up @@ -673,6 +746,35 @@ public static FunctionDefinitionAst GetFunctionDefinitionAtLine(

internal void OnConfigurationUpdated(object _, LanguageServerSettings e)
{
if (e.AnalyzeOpenDocumentsOnly)
{
Task scanInProgress = _workspaceScanCompleted;
if (scanInProgress is not null)
{
// Wait until after the scan completes to close unopened files.
_ = scanInProgress.ContinueWith(_ => CloseUnopenedFiles(), TaskScheduler.Default);
}
else
{
CloseUnopenedFiles();
}

_workspaceScanCompleted = null;

void CloseUnopenedFiles()
{
foreach (ScriptFile scriptFile in _workspaceService.GetOpenedFiles())
{
if (scriptFile.IsOpen)
{
continue;
}

_workspaceService.CloseFile(scriptFile);
}
}
}

if (e.EnableReferencesCodeLens)
{
if (_codeLensProviders.ContainsKey(ReferencesCodeLensProvider.Id))
Expand Down
Loading