Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion source/slang/slang-ast-stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ class UnparsedStmt : public Stmt
Scope* currentScope = nullptr;
Scope* outerScope = nullptr;
SourceLanguage sourceLanguage;
bool isInVariadicGenerics = false;
};

FIDDLE()
Expand Down
44 changes: 41 additions & 3 deletions source/slang/slang-check-conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1191,22 +1191,60 @@ bool SemanticsVisitor::_coerce(
// then we should start by trying to resolve the ambiguous reference
// based on prioritization of the different candidates.
//
// TODO: A more powerful model would be to try to coerce each
// If `fromExpr` is overloaded, we will try to coerce each
// of the constituent overload candidates, filtering down to
// those that are coercible, and then disambiguating the result.
// Such an approach would let us disambiguate between overloaded
// Such an approach lets us disambiguate between overloaded
// symbols based on their type (e.g., by casting the name of
// an overloaded function to the type of the overload we mean
// to reference).
//
if (auto fromOverloadedExpr = as<OverloadedExpr>(fromExpr))
{
auto resolvedExpr =
maybeResolveOverloadedExpr(fromOverloadedExpr, LookupMask::Default, nullptr);
maybeResolveOverloadedExpr(fromOverloadedExpr, LookupMask::Default, toType, nullptr);

fromExpr = resolvedExpr;
fromType = resolvedExpr->type;
}
else if (auto overloadedExpr2 = as<OverloadedExpr2>(fromExpr))
{
ShortList<Expr*> coercibleCandidates;
for (auto candidate : overloadedExpr2->candidateExprs)
{
if (canCoerce(toType, candidate->type, candidate))
coercibleCandidates.add(candidate);
}
if (coercibleCandidates.getCount() == 1)
{
return _coerce(
site,
toType,
outToExpr,
coercibleCandidates[0]->type,
coercibleCandidates[0],
sink,
outCost);
}
if (sink)
{
auto firstCandidate = overloadedExpr2->candidateExprs.getCount() > 0
? overloadedExpr2->candidateExprs[0]
: nullptr;
if (auto declCandidate = as<DeclRefExpr>(firstCandidate))
{
sink->diagnose(
fromExpr->loc,
Diagnostics::ambiguousReference,
declCandidate->declRef);
}
else
{
sink->diagnose(fromExpr->loc, Diagnostics::ambiguousExpression);
}
}
return false;
}

// An important and easy case is when the "to" and "from" types are equal.
//
Expand Down
1 change: 0 additions & 1 deletion source/slang/slang-check-decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10424,7 +10424,6 @@ Stmt* SemanticsVisitor::maybeParseStmt(Stmt* stmt, const SemanticsContext& conte
&subVisitor,
getShared()->getTranslationUnitRequest(),
unparsedStmt->sourceLanguage,
unparsedStmt->isInVariadicGenerics,
tokenList,
getShared()->getSink(),
unparsedStmt->currentScope,
Expand Down
30 changes: 23 additions & 7 deletions source/slang/slang-check-expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1128,7 +1128,9 @@ LookupResult SemanticsVisitor::filterLookupResultByCheckedOptionalAndDiagnose(
return result;
}

LookupResult SemanticsVisitor::resolveOverloadedLookup(LookupResult const& inResult)
LookupResult SemanticsVisitor::resolveOverloadedLookup(
LookupResult const& inResult,
Type* targetType)
{
// If the result isn't actually overloaded, it is fine as-is
if (!inResult.isValid())
Expand All @@ -1140,6 +1142,15 @@ LookupResult SemanticsVisitor::resolveOverloadedLookup(LookupResult const& inRes
List<LookupResultItem> items;
for (auto item : inResult.items)
{
// First we check if the item is coercible to targetType.
// And skip if it doesn't.
if (targetType)
{
auto declType = GetTypeForDeclRef(item.declRef, SourceLoc());
if (!canCoerce(targetType, declType, nullptr, nullptr))
continue;
}

// For each item we consider adding, we will compare it
// to those items we've already added.
//
Expand Down Expand Up @@ -1232,6 +1243,7 @@ void SemanticsVisitor::diagnoseAmbiguousReference(Expr* expr)
Expr* SemanticsVisitor::_resolveOverloadedExprImpl(
OverloadedExpr* overloadedExpr,
LookupMask mask,
Type* targetType,
DiagnosticSink* diagSink)
{
auto lookupResult = overloadedExpr->lookupResult2;
Expand All @@ -1246,7 +1258,7 @@ Expr* SemanticsVisitor::_resolveOverloadedExprImpl(
lookupResult = refineLookup(lookupResult, mask);

// Try to filter out overload candidates based on which ones are "better" than one another.
lookupResult = resolveOverloadedLookup(lookupResult);
lookupResult = resolveOverloadedLookup(lookupResult, targetType);

if (!lookupResult.isValid())
{
Expand Down Expand Up @@ -1296,24 +1308,28 @@ Expr* SemanticsVisitor::_resolveOverloadedExprImpl(
Expr* SemanticsVisitor::maybeResolveOverloadedExpr(
Expr* expr,
LookupMask mask,
Type* targetType,
DiagnosticSink* diagSink)
{
if (IsErrorExpr(expr))
return expr;

if (auto overloadedExpr = as<OverloadedExpr>(expr))
{
return _resolveOverloadedExprImpl(overloadedExpr, mask, diagSink);
return _resolveOverloadedExprImpl(overloadedExpr, mask, targetType, diagSink);
}
else
{
return expr;
}
}

Expr* SemanticsVisitor::resolveOverloadedExpr(OverloadedExpr* overloadedExpr, LookupMask mask)
Expr* SemanticsVisitor::resolveOverloadedExpr(
OverloadedExpr* overloadedExpr,
Type* targetType,
LookupMask mask)
{
return _resolveOverloadedExprImpl(overloadedExpr, mask, getSink());
return _resolveOverloadedExprImpl(overloadedExpr, mask, targetType, getSink());
}

Type* SemanticsVisitor::tryGetDifferentialType(ASTBuilder* builder, Type* type)
Expand Down Expand Up @@ -1364,7 +1380,7 @@ Type* SemanticsVisitor::tryGetDifferentialType(ASTBuilder* builder, Type* type)
Slang::LookupMask::type,
Slang::LookupOptions::None);

diffTypeLookupResult = resolveOverloadedLookup(diffTypeLookupResult);
diffTypeLookupResult = resolveOverloadedLookup(diffTypeLookupResult, nullptr);

if (!diffTypeLookupResult.isValid())
{
Expand Down Expand Up @@ -4416,7 +4432,7 @@ Expr* SemanticsExprVisitor::visitEachExpr(EachExpr* expr)
{
goto error;
}
if (!declRefType->getDeclRef().as<GenericTypePackParamDecl>())
if (!declRefType->getDeclRef().as<GenericTypePackParamDecl>() && !as<TupleType>(baseType))
{
goto error;
}
Expand Down
28 changes: 25 additions & 3 deletions source/slang/slang-check-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1405,7 +1405,11 @@ struct SemanticsVisitor : public SemanticsContext
DeclRef<Decl> resolveDeclRef(DeclRef<Decl> declRef);

/// Attempt to "resolve" an overloaded `LookupResult` to only include the "best" results
LookupResult resolveOverloadedLookup(LookupResult const& lookupResult);
LookupResult resolveOverloadedLookup(LookupResult const& lookupResult, Type* targetType);
inline LookupResult resolveOverloadedLookup(LookupResult const& lookupResult)
{
return resolveOverloadedLookup(lookupResult, nullptr);
}

/// Attempt to resolve `expr` into an expression that refers to a single declaration/value.
/// If `expr` isn't overloaded, then it will be returned as-is.
Expand All @@ -1417,19 +1421,33 @@ struct SemanticsVisitor : public SemanticsContext
/// appropriate "ambiguous reference" error will be reported, and an error expression will be
/// returned. Otherwise, the original expression is returned if resolution fails.
///
Expr* maybeResolveOverloadedExpr(Expr* expr, LookupMask mask, DiagnosticSink* diagSink);
Expr* maybeResolveOverloadedExpr(
Expr* expr,
LookupMask mask,
Type* targetType,
DiagnosticSink* diagSink);

inline Expr* maybeResolveOverloadedExpr(Expr* expr, LookupMask mask, DiagnosticSink* diagSink)
{
return maybeResolveOverloadedExpr(expr, mask, nullptr, diagSink);
}

/// Attempt to resolve `overloadedExpr` into an expression that refers to a single
/// declaration/value.
///
/// Equivalent to `maybeResolveOverloadedExpr` with `diagSink` bound to the sink for the
/// `SemanticsVisitor`.
Expr* resolveOverloadedExpr(OverloadedExpr* overloadedExpr, LookupMask mask);
Expr* resolveOverloadedExpr(OverloadedExpr* overloadedExpr, Type* targetType, LookupMask mask);
inline Expr* resolveOverloadedExpr(OverloadedExpr* overloadedExpr, LookupMask mask)
{
return resolveOverloadedExpr(overloadedExpr, nullptr, mask);
}

/// Worker reoutine for `maybeResolveOverloadedExpr` and `resolveOverloadedExpr`.
Expr* _resolveOverloadedExprImpl(
OverloadedExpr* overloadedExpr,
LookupMask mask,
Type* targetType,
DiagnosticSink* diagSink);

void diagnoseAmbiguousReference(
Expand Down Expand Up @@ -2840,6 +2858,10 @@ struct SemanticsVisitor : public SemanticsContext

void AddGenericOverloadCandidates(Expr* baseExpr, OverloadResolveContext& context);

// Given an argument list, expand all `expand` expressions, if the type/value pack being
// expanded is already specialized.
void maybeExpandArgList(List<Expr*>& args);

template<class T>
void trySetGenericToRayTracingWithParamAttribute(
LookupResultItem genericItem,
Expand Down
116 changes: 97 additions & 19 deletions source/slang/slang-check-overload.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2065,6 +2065,94 @@ void SemanticsVisitor::AddCtorOverloadCandidate(
AddOverloadCandidate(context, candidate, baseCost);
}

void SemanticsVisitor::maybeExpandArgList(List<Expr*>& args)
{
bool needExpansion = false;
for (auto expr : args)
{
while (auto paren = as<ParenExpr>(expr))
expr = paren->base;

if (auto expand = as<ExpandExpr>(expr))
{
auto exprType = expand->type.type;
if (auto typeType = as<TypeType>(exprType))
exprType = typeType->getType();
if (as<ConcreteTypePack>(exprType))
{
needExpansion = true;
}
}
}
// Fast path without creating list copies.
if (!needExpansion)
return;
List<Expr*> result;
for (auto expr : args)
{
while (auto paren = as<ParenExpr>(expr))
expr = paren->base;
auto processExpr = [&]()
{
auto expand = as<ExpandExpr>(expr);
if (!expand)
return false;
auto type = expand->type.type;
if (auto typeType = as<TypeType>(type))
{
auto typePack = as<ConcreteTypePack>(typeType->getType());
if (!typePack)
return false;
for (Index i = 0; i < typePack->getTypeCount(); i++)
{
auto expandArg = m_astBuilder->create<SharedTypeExpr>();
expandArg->loc = expr->loc;
expandArg->type = m_astBuilder->getTypeType(typePack->getElementType(i));
result.add(expandArg);
}
return true;
}
else if (auto typePack = as<ConcreteTypePack>(type))
{
auto localScope = getExprLocalScope();
SLANG_ASSERT(localScope);

VarDecl* varDecl = m_astBuilder->create<VarDecl>();
varDecl->parentDecl = nullptr;
if (m_outerScope && m_outerScope->containerDecl)
m_outerScope->containerDecl->addMember(varDecl);
addModifier(varDecl, m_astBuilder->create<LocalTempVarModifier>());
varDecl->checkState = DeclCheckState::DefinitionChecked;
varDecl->nameAndLoc.loc = expr->loc;
varDecl->initExpr = expr;
varDecl->type.type = expr->type.type;
LetExpr* letExpr = m_astBuilder->create<LetExpr>();
letExpr->decl = varDecl;
localScope->addBinding(letExpr);
auto varExpr = m_astBuilder->create<VarExpr>();
varExpr->declRef = varDecl;
varExpr->type = expr->type.type;
varExpr->type.isLeftValue = false;
for (Index i = 0; i < typePack->getTypeCount(); i++)
{
auto expandedArg = m_astBuilder->create<SwizzleExpr>();
expandedArg->base = varExpr;
expandedArg->type = typePack->getElementType(i);
expandedArg->type.isLeftValue = false;
expandedArg->elementIndices.add((uint32_t)i);
result.add(expandedArg);
}
return true;
}
return false;
};

if (!processExpr())
result.add(expr);
}
args.swapWith(result);
}

bool SemanticsVisitor::OverloadResolveContext::matchArgumentsToParams(
SemanticsVisitor* semantics,
const List<QualType>& params,
Expand Down Expand Up @@ -2102,7 +2190,8 @@ bool SemanticsVisitor::OverloadResolveContext::matchArgumentsToParams(
}

// Try to match the variadic part.
// Is the corresponding argument a expand expr? If so it will map 1:1 to the type pack param.
// Is the corresponding argument a expand expr? If so it will map 1:1 to the type pack
// param.
auto astBuilder = semantics->getASTBuilder();

if (remainingArgCount <= 0)
Expand Down Expand Up @@ -2655,22 +2744,8 @@ Expr* SemanticsVisitor::ResolveInvoke(InvokeExpr* expr)
{
return CreateErrorExpr(expr);
}
// If any of the arguments is an error, then we should bail out, to avoid
// cascading errors where we successfully pick an overload, but not the one
// the user meant.
for (auto arg : expr->arguments)
{
if (IsErrorExpr(arg))
return CreateErrorExpr(expr);

// If this argument is itself an overloaded value without a type
// then we can't sensibly continue
if (!arg->type && (as<OverloadedExpr>(arg) || as<OverloadedExpr2>(arg)))
{
getSink()->diagnose(expr->loc, Diagnostics::overloadedParameterToHigherOrderFunction);
return CreateErrorExpr(expr);
}
}
maybeExpandArgList(expr->arguments);

for (auto& arg : expr->arguments)
{
Expand Down Expand Up @@ -2700,7 +2775,8 @@ Expr* SemanticsVisitor::ResolveInvoke(InvokeExpr* expr)
if (typeCheckingCache->resolvedOperatorOverloadCache.tryGetValue(key, candidate))
{
// We should only use the cached candidate if it is persistent direct declref
// created from GlobalSession's ASTBuilder, or it is created in the current Linkage.
// created from GlobalSession's ASTBuilder, or it is created in the current
// Linkage.
if (candidate.cacheVersion == typeCheckingCache->version ||
findNextOuterGeneric(candidate.decl) == nullptr)
{
Expand Down Expand Up @@ -2910,8 +2986,8 @@ Expr* SemanticsVisitor::ResolveInvoke(InvokeExpr* expr)
}
}

// Now that we have resolved the overload candidate, we need to undo an `openExistential`
// operation that was applied to `out` arguments.
// Now that we have resolved the overload candidate, we need to undo an
// `openExistential` operation that was applied to `out` arguments.
//
auto funcType = context.bestCandidate->funcType;
ShortList<ParameterDirection> paramDirections;
Expand Down Expand Up @@ -3087,6 +3163,7 @@ Expr* SemanticsVisitor::checkGenericAppWithCheckedArgs(GenericAppExpr* genericAp

auto& baseExpr = genericAppExpr->functionExpr;
auto& args = genericAppExpr->arguments;
maybeExpandArgList(args);

// If there was an error in the base expression, or in any of
// the arguments, then just bail.
Expand Down Expand Up @@ -3138,6 +3215,7 @@ Expr* SemanticsVisitor::checkGenericAppWithCheckedArgs(GenericAppExpr* genericAp
// to complete all of them and create an overloaded expression as a result.

auto overloadedExpr = m_astBuilder->create<OverloadedExpr2>();
overloadedExpr->type = m_astBuilder->getOverloadedType();
overloadedExpr->base = context.baseExpr;
for (auto candidate : context.bestCandidates)
{
Expand Down
Loading
Loading