Skip to content

[AutoDiff] Relax @differentiable requirement for protocol witnesses. #29771

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1672,6 +1672,13 @@ class DifferentiableAttr final
/// attribute's where clause requirements. This is set only if the attribute
/// has a where clause.
GenericSignature DerivativeGenericSignature;
/// The source location of the implicitly inherited protocol requirement
/// `@differentiable` attribute. Used for diagnostics, not serialized.
///
/// This is set during conformance type-checking, only for implicit
/// `@differentiable` attributes created for non-public protocol witnesses of
/// protocol requirements with `@differentiable` attributes.
SourceLoc ImplicitlyInheritedDifferentiableAttrLocation;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: it's possible to generalize this "diagnostic source location" for other kinds of implicit @differentiable attributes too, not just "ones that are implicitly inherited from protocol requirements".

Currently, there are only three kinds of implicit @differentiable attributes:

  1. This case: @differentiable attributes inherited from protocol requirements.
  2. @differentiable attributes synthesized from others on the same declaration with superset parameter indices.
    • @differentiable(wrt: (x, y)) -> @differentiable(wrt: x)
  3. @differentiable attributes synthesized on stored properties of structs/classes that derive a conformance to Differentiable.
enum class ImplicitKind {
  InheritedFromProtocolRequirement, // this case
  SynthesizedSubsetParametersAttribute, // triggered during conformance type-checking
  SynthesizedForStoredProperty, // triggered during `Differentiable` derived conformances
};

I think the first case benefits the most from a "secondary diagnostic source location". Generalization could be done later.


explicit DifferentiableAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange, bool linear,
Expand Down Expand Up @@ -1757,6 +1764,14 @@ class DifferentiableAttr final
FuncDecl *getVJPFunction() const { return VJPFunction; }
void setVJPFunction(FuncDecl *decl);

SourceLoc getImplicitlyInheritedDifferentiableAttrLocation() const {
return ImplicitlyInheritedDifferentiableAttrLocation;
}
void getImplicitlyInheritedDifferentiableAttrLocation(SourceLoc loc) {
assert(isImplicit());
ImplicitlyInheritedDifferentiableAttrLocation = loc;
}

/// Get the derivative generic environment for the given `@differentiable`
/// attribute and original function.
GenericEnvironment *
Expand Down
3 changes: 3 additions & 0 deletions include/swift/AST/DiagnosticsSIL.def
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,9 @@ NOTE(autodiff_cannot_differentiate_through_multiple_results,none,
"cannot differentiate through multiple results", ())
NOTE(autodiff_class_member_not_supported,none,
"differentiating class members is not yet supported", ())
NOTE(autodiff_implicitly_inherited_differentiable_attr_here,none,
"differentiability required by the corresponding protocol requirement "
"here", ())
// TODO(TF-642): Remove when `partial_apply` works with `@differentiable`
// functions.
NOTE(autodiff_cannot_param_subset_thunk_partially_applied_orig_fn,none,
Expand Down
6 changes: 6 additions & 0 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -3023,6 +3023,12 @@ ERROR(overriding_decl_missing_differentiable_attr,none,
"overriding declaration is missing attribute '%0'", (StringRef))
NOTE(protocol_witness_missing_differentiable_attr,none,
"candidate is missing attribute '%0'", (StringRef))
NOTE(protocol_witness_missing_differentiable_attr_nonpublic_other_file,none,
"non-public %1 %2 must have explicit '%0' attribute to satisfy "
"requirement %3 %4 (in protocol %6) because it is declared in a different "
"file than the conformance of %5 to %6",
(StringRef, DescriptiveDeclKind, DeclName, DescriptiveDeclKind, DeclName,
Type, Type))

// @derivative
ERROR(derivative_attr_expected_result_tuple,none,
Expand Down
41 changes: 31 additions & 10 deletions include/swift/SILOptimizer/Utils/Differentiation/ADContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -311,20 +311,41 @@ ADContext::emitNondifferentiabilityError(SourceLoc loc,
return diagnose(loc, diag, std::forward<U>(args)...);
}

// For `SILDifferentiabilityWitness`es, try to find an AST function
// declaration and `@differentiable` attribute. If they are found, emit an
// error on the `@differentiable` attribute; otherwise, emit an error on the
// SIL function. Emit a note at the non-differentiable operation.
// For differentiability witnesses: try to find a `@differentiable` or
// `@derivative` attribute. If an attribute is found, emit an error on it;
// otherwise, emit an error on the original function.
case DifferentiationInvoker::Kind::SILDifferentiabilityWitnessInvoker: {
auto *witness = invoker.getSILDifferentiabilityWitnessInvoker();
auto *original = witness->getOriginalFunction();
if (auto *diffAttr = witness->getAttribute()) {
diagnose(diffAttr->getLocation(),
// If the witness has an associated attribute, emit an error at its
// location.
if (auto *attr = witness->getAttribute()) {
diagnose(attr->getLocation(),
diag::autodiff_function_not_differentiable_error)
.highlight(diffAttr->getRangeWithAt());
diagnose(original->getLocation().getSourceLoc(),
diag::autodiff_when_differentiating_function_definition);
} else {
.highlight(attr->getRangeWithAt());
// Emit informative note.
bool emittedNote = false;
// If the witness comes from an implicit `@differentiable` attribute
// inherited from a protocol requirement's `@differentiable` attribute,
// emit a note on the inherited attribute.
if (auto *diffAttr = dyn_cast<DifferentiableAttr>(attr)) {
auto inheritedAttrLoc =
diffAttr->getImplicitlyInheritedDifferentiableAttrLocation();
if (inheritedAttrLoc.isValid()) {
diagnose(inheritedAttrLoc,
diag::autodiff_implicitly_inherited_differentiable_attr_here)
.highlight(inheritedAttrLoc);
emittedNote = true;
}
}
// Otherwise, emit a note on the original function.
if (!emittedNote) {
diagnose(original->getLocation().getSourceLoc(),
diag::autodiff_when_differentiating_function_definition);
}
}
// Otherwise, emit an error on the original function.
else {
diagnose(original->getLocation().getSourceLoc(),
diag::autodiff_function_not_differentiable_error);
}
Expand Down
91 changes: 73 additions & 18 deletions lib/Sema/TypeCheckProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,8 @@ static ValueDecl *getStandinForAccessor(AbstractStorageDecl *witness,
/// witness.
/// - If requirement's `@differentiable` attributes are met, or if `result` is
/// not viable, returns `result`.
/// - Otherwise, returns a `DifferentiableConflict` `RequirementMatch`.
/// - Otherwise, returns a "missing `@differentiable` attribute"
/// `RequirementMatch`.
// Note: the `result` argument is only necessary for using
// `RequirementMatch::WitnessSubstitutions`.
static RequirementMatch
Expand Down Expand Up @@ -384,15 +385,50 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
}
if (!foundExactConfig) {
bool success = false;
if (supersetConfig) {
// If the witness has a "superset" derivative configuration, create an
// implicit `@differentiable` attribute with the exact requirement
// `@differentiable` attribute parameter indices.
// If no exact witness derivative configuration was found, check
// conditions for creating an implicit witness `@differentiable` attribute
// with the exact derivative configuration:
// - If the witness has a "superset" derivative configuration.
// - If the witness is less than public and is declared in the same file
// as the conformance.
// - `@differentiable` attributes are really only significant for public
// declarations: it improves usability to not require explicit
// `@differentiable` attributes for less-visible declarations.
bool createImplicitWitnessAttribute =
supersetConfig || witness->getFormalAccess() < AccessLevel::Public;
// If the witness has less-than-public visibility and is declared in a
// different file than the conformance, produce an error.
if (!supersetConfig && witness->getFormalAccess() < AccessLevel::Public &&
dc->getModuleScopeContext() !=
witness->getDeclContext()->getModuleScopeContext()) {
// FIXME(TF-1014): `@differentiable` attribute diagnostic does not
// appear if associated type inference is involved.
if (auto *vdWitness = dyn_cast<VarDecl>(witness)) {
return RequirementMatch(
getStandinForAccessor(vdWitness, AccessorKind::Get),
MatchKind::MissingDifferentiableAttr, reqDiffAttr);
} else {
return RequirementMatch(witness, MatchKind::MissingDifferentiableAttr,
reqDiffAttr);
}
}
if (createImplicitWitnessAttribute) {
auto derivativeGenSig = witnessAFD->getGenericSignature();
if (supersetConfig)
derivativeGenSig = supersetConfig->derivativeGenericSignature;
// Use source location of the witness declaration as the source location
// of the implicit `@differentiable` attribute.
auto *newAttr = DifferentiableAttr::create(
witnessAFD, /*implicit*/ true, reqDiffAttr->AtLoc,
reqDiffAttr->getRange(), reqDiffAttr->isLinear(),
reqDiffAttr->getParameterIndices(), /*jvp*/ None,
/*vjp*/ None, supersetConfig->derivativeGenericSignature);
witnessAFD, /*implicit*/ true, witness->getLoc(), witness->getLoc(),
reqDiffAttr->isLinear(), reqDiffAttr->getParameterIndices(),
/*jvp*/ None, /*vjp*/ None, derivativeGenSig);
// If the implicit attribute is inherited from a protocol requirement's
// attribute, store the protocol requirement attribute's location for
// use in diagnostics.
if (witness->getFormalAccess() < AccessLevel::Public) {
newAttr->getImplicitlyInheritedDifferentiableAttrLocation(
reqDiffAttr->getLocation());
}
auto insertion = ctx.DifferentiableAttrs.try_emplace(
{witnessAFD, newAttr->getParameterIndices()}, newAttr);
// Valid `@differentiable` attributes are uniqued by original function
Expand All @@ -418,9 +454,9 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
if (auto *vdWitness = dyn_cast<VarDecl>(witness)) {
return RequirementMatch(
getStandinForAccessor(vdWitness, AccessorKind::Get),
MatchKind::DifferentiableConflict, reqDiffAttr);
MatchKind::MissingDifferentiableAttr, reqDiffAttr);
} else {
return RequirementMatch(witness, MatchKind::DifferentiableConflict,
return RequirementMatch(witness, MatchKind::MissingDifferentiableAttr,
reqDiffAttr);
}
}
Expand Down Expand Up @@ -2318,14 +2354,15 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance,
case MatchKind::NonObjC:
diags.diagnose(match.Witness, diag::protocol_witness_not_objc);
break;
case MatchKind::DifferentiableConflict: {
case MatchKind::MissingDifferentiableAttr: {
auto witness = match.Witness;
// Emit a note and fix-it showing the missing requirement `@differentiable`
// attribute.
auto *reqAttr = cast<DifferentiableAttr>(match.UnmetAttribute);
assert(reqAttr);
// Omit printing `wrt:` clause if attribute's differentiability
// parameters match inferred differentiability parameters.
auto *original = cast<AbstractFunctionDecl>(match.Witness);
auto *original = cast<AbstractFunctionDecl>(witness);
auto *whereClauseGenEnv =
reqAttr->getDerivativeGenericEnvironment(original);
auto *inferredParameters = TypeChecker::inferDifferentiabilityParameters(
Expand All @@ -2336,11 +2373,29 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance,
llvm::raw_string_ostream os(reqDiffAttrString);
reqAttr->print(os, req, omitWrtClause, /*omitDerivativeFunctions*/ true);
os.flush();
diags
.diagnose(match.Witness,
diag::protocol_witness_missing_differentiable_attr,
reqDiffAttrString)
.fixItInsert(match.Witness->getStartLoc(), reqDiffAttrString + ' ');
// If the witness has less-than-public visibility and is declared in a
// different file than the conformance, emit a specialized diagnostic.
if (witness->getFormalAccess() < AccessLevel::Public &&
conformance->getDeclContext()->getModuleScopeContext() !=
witness->getDeclContext()->getModuleScopeContext()) {
diags
.diagnose(
witness,
diag::
protocol_witness_missing_differentiable_attr_nonpublic_other_file,
reqDiffAttrString, witness->getDescriptiveKind(),
witness->getFullName(), req->getDescriptiveKind(),
req->getFullName(), conformance->getType(),
conformance->getProtocol()->getDeclaredInterfaceType())
.fixItInsert(match.Witness->getStartLoc(), reqDiffAttrString + ' ');
}
// Otherwise, emit a general "missing attribute" diagnostic.
else {
diags
.diagnose(witness, diag::protocol_witness_missing_differentiable_attr,
reqDiffAttrString)
.fixItInsert(witness->getStartLoc(), reqDiffAttrString + ' ');
}
break;
}
}
Expand Down
15 changes: 8 additions & 7 deletions lib/Sema/TypeCheckProtocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,8 @@ enum class MatchKind : uint8_t {
/// The witness is explicitly @nonobjc but the requirement is @objc.
NonObjC,

/// The witness does not have a `@differentiable` attribute satisfying one
/// from the requirement.
DifferentiableConflict,
/// The witness is missing a `@differentiable` attribute from the requirement.
MissingDifferentiableAttr,
};

/// Describes the kind of optional adjustment performed when
Expand Down Expand Up @@ -363,7 +362,7 @@ struct RequirementMatch {
: Witness(witness), Kind(kind), WitnessType(), UnmetAttribute(attr),
ReqEnv(None) {
assert(!hasWitnessType() && "Should have witness type");
assert(UnmetAttribute);
assert(hasUnmetAttribute() && "Should have unmet attribute");
}

RequirementMatch(ValueDecl *witness, MatchKind kind,
Expand Down Expand Up @@ -438,7 +437,7 @@ struct RequirementMatch {
case MatchKind::RethrowsConflict:
case MatchKind::ThrowsConflict:
case MatchKind::NonObjC:
case MatchKind::DifferentiableConflict:
case MatchKind::MissingDifferentiableAttr:
return false;
}

Expand Down Expand Up @@ -468,7 +467,7 @@ struct RequirementMatch {
case MatchKind::RethrowsConflict:
case MatchKind::ThrowsConflict:
case MatchKind::NonObjC:
case MatchKind::DifferentiableConflict:
case MatchKind::MissingDifferentiableAttr:
return false;
}

Expand All @@ -479,7 +478,9 @@ struct RequirementMatch {
bool hasRequirement() { return Kind == MatchKind::MissingRequirement; }

/// Determine whether this requirement match has an unmet attribute.
bool hasUnmetAttribute() { return Kind == MatchKind::DifferentiableConflict; }
bool hasUnmetAttribute() {
return Kind == MatchKind::MissingDifferentiableAttr;
}

swift::Witness getWitness(ASTContext &ctx) const;
};
Expand Down
Loading