Skip to content

Commit c690ac8

Browse files
committed
[AutoDiff] Improve invalid stored property projection diagnostics.
Use TangentStoredPropertyRequest in differentiation transform. Improve non-differentiability diagnostics regarding invalid stored property projection instructions: `struct_extract`, `struct_element_addr`, `ref_element_addr`. Diagnose the following cases: - Original property's type does not conform to `Differentiable`. - Base type's `TangentVector` is not a struct. - Tangent property not found: base type's `TangentVector` does not have a stored property with the same name as the original property. - Tangent property's type is not equal to the original property's `TangentVector` type. - Tangent property is not a stored property. Resolves TF-969 and TF-970.
1 parent f163072 commit c690ac8

File tree

10 files changed

+536
-151
lines changed

10 files changed

+536
-151
lines changed

include/swift/AST/DiagnosticsSIL.def

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -504,9 +504,26 @@ NOTE(autodiff_loadable_value_addressonly_tangent_unsupported,none,
504504
"properties", (Type, Type))
505505
NOTE(autodiff_enums_unsupported,none,
506506
"differentiating enum values is not yet supported", ())
507+
NOTE(autodiff_stored_property_parent_not_differentiable,none,
508+
"cannot differentiate access to property '%0.%1' because '%0' does not "
509+
"conform to 'Differentiable'", (StringRef, StringRef))
510+
NOTE(autodiff_stored_property_not_differentiable,none,
511+
"cannot differentiate access to property '%0.%1' because property type %2 "
512+
"does not conform to 'Differentiable'", (StringRef, StringRef, Type))
513+
NOTE(autodiff_stored_property_tangent_not_struct,none,
514+
"cannot differentiate access to property '%0.%1' because "
515+
"'%0.TangentVector' is not a struct", (StringRef, StringRef))
507516
NOTE(autodiff_stored_property_no_corresponding_tangent,none,
508-
"property cannot be differentiated because '%0.TangentVector' does not "
509-
"have a member named '%1'", (StringRef, StringRef))
517+
"cannot differentiate access to property '%0.%1' because "
518+
"'%0.TangentVector' does not have a stored property named '%1'",
519+
(StringRef, StringRef))
520+
NOTE(autodiff_tangent_property_wrong_type,none,
521+
"cannot differentiate access to property '%0.%1' because "
522+
"'%0.TangentVector.%1' does not have expected type %2",
523+
(StringRef, StringRef, /*originalPropertyTanType*/ Type))
524+
NOTE(autodiff_tangent_property_not_stored,none,
525+
"cannot differentiate access to property '%0.%1' because "
526+
"'%0.TangentVector.%1' is not a stored property", (StringRef, StringRef))
510527
NOTE(autodiff_coroutines_not_supported,none,
511528
"differentiation of coroutine calls is not yet supported", ())
512529
NOTE(autodiff_cannot_differentiate_writes_to_global_variables,none,

include/swift/SIL/SILInstruction.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5758,6 +5758,13 @@ class FieldIndexCacheBase : public SingleValueInstruction {
57585758
return s;
57595759
}
57605760

5761+
static bool classof(const SILNode *node) {
5762+
SILNodeKind kind = node->getKind();
5763+
return kind == SILNodeKind::StructExtractInst ||
5764+
kind == SILNodeKind::StructElementAddrInst ||
5765+
kind == SILNodeKind::RefElementAddrInst;
5766+
}
5767+
57615768
private:
57625769
unsigned cacheFieldIndex();
57635770
};

include/swift/SILOptimizer/Differentiation/ADContext.h

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -253,11 +253,9 @@ ADContext::emitNondifferentiabilityError(SILValue value,
253253
getADDebugStream() << "For value:\n" << value;
254254
getADDebugStream() << "With invoker:\n" << invoker << '\n';
255255
});
256-
auto valueLoc = value.getLoc().getSourceLoc();
257256
// If instruction does not have a valid location, use the function location
258257
// as a fallback. Improves diagnostics in some cases.
259-
if (valueLoc.isInvalid())
260-
valueLoc = value->getFunction()->getLocation().getSourceLoc();
258+
auto valueLoc = getValidLocation(value).getSourceLoc();
261259
return emitNondifferentiabilityError(valueLoc, invoker, diag,
262260
std::forward<U>(args)...);
263261
}
@@ -272,12 +270,10 @@ ADContext::emitNondifferentiabilityError(SILInstruction *inst,
272270
getADDebugStream() << "For instruction:\n" << *inst;
273271
getADDebugStream() << "With invoker:\n" << invoker << '\n';
274272
});
275-
auto instLoc = inst->getLoc().getSourceLoc();
276273
// If instruction does not have a valid location, use the function location
277274
// as a fallback. Improves diagnostics for `ref_element_addr` generated in
278275
// synthesized stored property getters.
279-
if (instLoc.isInvalid())
280-
instLoc = inst->getFunction()->getLocation().getSourceLoc();
276+
auto instLoc = getValidLocation(inst).getSourceLoc();
281277
return emitNondifferentiabilityError(instLoc, invoker, diag,
282278
std::forward<U>(args)...);
283279
}

include/swift/SILOptimizer/Differentiation/Common.h

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,27 @@
1717
#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_COMMON_H
1818
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_COMMON_H
1919

20+
#include "swift/AST/DiagnosticsSIL.h"
21+
#include "swift/AST/Expr.h"
2022
#include "swift/AST/SemanticAttrs.h"
2123
#include "swift/SIL/SILDifferentiabilityWitness.h"
2224
#include "swift/SIL/SILFunction.h"
2325
#include "swift/SIL/SILModule.h"
2426
#include "swift/SIL/TypeSubstCloner.h"
2527
#include "swift/SILOptimizer/Analysis/ArraySemantic.h"
2628
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
29+
#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h"
2730

2831
namespace swift {
2932

33+
namespace autodiff {
34+
35+
class ADContext;
36+
3037
//===----------------------------------------------------------------------===//
3138
// Helpers
3239
//===----------------------------------------------------------------------===//
3340

34-
namespace autodiff {
35-
3641
/// Prints an "[AD] " prefix to `llvm::dbgs()` and returns the debug stream.
3742
/// This is being used to print short debug messages within the AD pass.
3843
raw_ostream &getADDebugStream();
@@ -136,6 +141,34 @@ template <class Inst> Inst *peerThroughFunctionConversions(SILValue value) {
136141
return nullptr;
137142
}
138143

144+
//===----------------------------------------------------------------------===//
145+
// Diagnostic utilities
146+
//===----------------------------------------------------------------------===//
147+
148+
// Returns `v`'s location if it is valid. Otherwise, returns `v`'s function's
149+
// location as as a fallback. Used for diagnostics.
150+
SILLocation getValidLocation(SILValue v);
151+
152+
// Returns `inst`'s location if it is valid. Otherwise, returns `inst`'s
153+
// function's location as as a fallback. Used for diagnostics.
154+
SILLocation getValidLocation(SILInstruction *inst);
155+
156+
//===----------------------------------------------------------------------===//
157+
// Tangent property lookup utilities
158+
//===----------------------------------------------------------------------===//
159+
160+
/// Returns the tangent stored property of `originalField`. On error, emits
161+
/// diagnostic and returns nullptr.
162+
VarDecl *getTangentStoredProperty(ADContext &context, VarDecl *originalField,
163+
SILLocation loc,
164+
DifferentiationInvoker invoker);
165+
166+
/// Returns the tangent stored property of the original stored property
167+
/// referenced by `inst`. On error, emits diagnostic and returns nullptr.
168+
VarDecl *getTangentStoredProperty(ADContext &context,
169+
FieldIndexCacheBase *projectionInst,
170+
DifferentiationInvoker invoker);
171+
139172
//===----------------------------------------------------------------------===//
140173
// Code emission utilities
141174
//===----------------------------------------------------------------------===//

lib/SILOptimizer/Differentiation/Common.cpp

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
#define DEBUG_TYPE "differentiation"
1818

1919
#include "swift/SILOptimizer/Differentiation/Common.h"
20+
#include "swift/AST/TypeCheckRequests.h"
21+
#include "swift/SILOptimizer/Differentiation/ADContext.h"
2022

2123
namespace swift {
2224
namespace autodiff {
@@ -244,6 +246,97 @@ void collectMinimalIndicesForFunctionCall(
244246
}));
245247
}
246248

249+
//===----------------------------------------------------------------------===//
250+
// Diagnostic utilities
251+
//===----------------------------------------------------------------------===//
252+
253+
SILLocation getValidLocation(SILValue v) {
254+
auto loc = v.getLoc();
255+
if (loc.isNull() || loc.getSourceLoc().isInvalid())
256+
loc = v->getFunction()->getLocation();
257+
return loc;
258+
}
259+
260+
SILLocation getValidLocation(SILInstruction *inst) {
261+
auto loc = inst->getLoc();
262+
if (loc.isNull() || loc.getSourceLoc().isInvalid())
263+
loc = inst->getFunction()->getLocation();
264+
return loc;
265+
}
266+
267+
//===----------------------------------------------------------------------===//
268+
// Tangent property lookup utilities
269+
//===----------------------------------------------------------------------===//
270+
271+
VarDecl *getTangentStoredProperty(ADContext &context, VarDecl *originalField,
272+
SILLocation loc,
273+
DifferentiationInvoker invoker) {
274+
auto &astCtx = context.getASTContext();
275+
auto tanFieldInfo = evaluateOrDefault(
276+
astCtx.evaluator, TangentStoredPropertyRequest{originalField},
277+
TangentPropertyInfo(nullptr));
278+
// If no error, return the tangent property.
279+
if (tanFieldInfo)
280+
return tanFieldInfo.tangentProperty;
281+
// Otherwise, diagnose error and return nullptr.
282+
assert(tanFieldInfo.error);
283+
auto *parentDC = originalField->getDeclContext();
284+
assert(parentDC->isTypeContext());
285+
auto parentDeclName = parentDC->getSelfNominalTypeDecl()->getNameStr();
286+
auto fieldName = originalField->getNameStr();
287+
auto sourceLoc = loc.getSourceLoc();
288+
switch (tanFieldInfo.error->kind) {
289+
case TangentPropertyInfo::Error::Kind::NoDerivativeOriginalProperty:
290+
llvm_unreachable(
291+
"`@noDerivative` stored property accesses should not be "
292+
"differentiated; activity analysis should not mark as varied");
293+
case TangentPropertyInfo::Error::Kind::NominalParentNotDifferentiable:
294+
context.emitNondifferentiabilityError(
295+
sourceLoc, invoker,
296+
diag::autodiff_stored_property_parent_not_differentiable,
297+
parentDeclName, fieldName);
298+
break;
299+
case TangentPropertyInfo::Error::Kind::OriginalPropertyNotDifferentiable:
300+
context.emitNondifferentiabilityError(
301+
sourceLoc, invoker, diag::autodiff_stored_property_not_differentiable,
302+
parentDeclName, fieldName, originalField->getInterfaceType());
303+
break;
304+
case TangentPropertyInfo::Error::Kind::ParentTangentVectorNotStruct:
305+
context.emitNondifferentiabilityError(
306+
sourceLoc, invoker, diag::autodiff_stored_property_tangent_not_struct,
307+
parentDeclName, fieldName);
308+
break;
309+
case TangentPropertyInfo::Error::Kind::TangentPropertyNotFound:
310+
context.emitNondifferentiabilityError(
311+
sourceLoc, invoker,
312+
diag::autodiff_stored_property_no_corresponding_tangent, parentDeclName,
313+
fieldName);
314+
break;
315+
case TangentPropertyInfo::Error::Kind::TangentPropertyWrongType:
316+
context.emitNondifferentiabilityError(
317+
sourceLoc, invoker, diag::autodiff_tangent_property_wrong_type,
318+
parentDeclName, fieldName, tanFieldInfo.error->getType());
319+
break;
320+
case TangentPropertyInfo::Error::Kind::TangentPropertyNotStored:
321+
context.emitNondifferentiabilityError(
322+
sourceLoc, invoker, diag::autodiff_tangent_property_not_stored,
323+
parentDeclName, fieldName);
324+
break;
325+
}
326+
return nullptr;
327+
}
328+
329+
VarDecl *getTangentStoredProperty(ADContext &context,
330+
FieldIndexCacheBase *projectionInst,
331+
DifferentiationInvoker invoker) {
332+
assert(isa<StructExtractInst>(projectionInst) ||
333+
isa<StructElementAddrInst>(projectionInst) ||
334+
isa<RefElementAddrInst>(projectionInst));
335+
auto loc = getValidLocation(projectionInst);
336+
return getTangentStoredProperty(context, projectionInst->getField(), loc,
337+
invoker);
338+
}
339+
247340
//===----------------------------------------------------------------------===//
248341
// Code emission utilities
249342
//===----------------------------------------------------------------------===//

lib/SILOptimizer/Differentiation/JVPEmitter.cpp

Lines changed: 8 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -547,29 +547,12 @@ CLONE_AND_EMIT_TANGENT(StructExtract, sei) {
547547
assert(!sei->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() &&
548548
"`struct_extract` with `@noDerivative` field should not be "
549549
"differentiated; activity analysis should not marked as varied.");
550-
551550
auto diffBuilder = getDifferentialBuilder();
552-
;
553-
auto tangentVectorTy = getRemappedTangentType(sei->getOperand()->getType());
554-
auto *tangentVectorDecl = tangentVectorTy.getStructOrBoundGenericStruct();
555-
556551
// Find the corresponding field in the tangent space.
557-
VarDecl *tanField = nullptr;
558-
// If the tangent space is the original struct, then field is the same.
559-
if (tangentVectorDecl == sei->getStructDecl())
560-
tanField = sei->getField();
561-
// Otherwise, look up the field by name.
562-
else {
563-
auto tanFieldLookup =
564-
tangentVectorDecl->lookupDirect(sei->getField()->getName());
565-
if (tanFieldLookup.empty()) {
566-
context.emitNondifferentiabilityError(
567-
sei, invoker, diag::autodiff_stored_property_no_corresponding_tangent,
568-
sei->getStructDecl()->getNameStr(), sei->getField()->getNameStr());
569-
errorOccurred = true;
570-
return;
571-
}
572-
tanField = cast<VarDecl>(tanFieldLookup.front());
552+
auto *tanField = getTangentStoredProperty(context, sei, invoker);
553+
if (!tanField) {
554+
errorOccurred = true;
555+
return;
573556
}
574557
// Emit tangent `struct_extract`.
575558
auto tanStruct =
@@ -590,32 +573,14 @@ CLONE_AND_EMIT_TANGENT(StructElementAddr, seai) {
590573
assert(!seai->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() &&
591574
"`struct_element_addr` with `@noDerivative` field should not be "
592575
"differentiated; activity analysis should not marked as varied.");
593-
594576
auto diffBuilder = getDifferentialBuilder();
595577
auto *bb = seai->getParent();
596-
auto tangentVectorTy = getRemappedTangentType(seai->getOperand()->getType());
597-
auto *tangentVectorDecl = tangentVectorTy.getStructOrBoundGenericStruct();
598-
599578
// Find the corresponding field in the tangent space.
600-
VarDecl *tanField = nullptr;
601-
// If the tangent space is the original struct, then field is the same.
602-
if (tangentVectorDecl == seai->getStructDecl())
603-
tanField = seai->getField();
604-
// Otherwise, look up the field by name.
605-
else {
606-
auto tanFieldLookup =
607-
tangentVectorDecl->lookupDirect(seai->getField()->getName());
608-
if (tanFieldLookup.empty()) {
609-
context.emitNondifferentiabilityError(
610-
seai, invoker,
611-
diag::autodiff_stored_property_no_corresponding_tangent,
612-
seai->getStructDecl()->getNameStr(), seai->getField()->getNameStr());
613-
errorOccurred = true;
614-
return;
615-
}
616-
tanField = cast<VarDecl>(tanFieldLookup.front());
579+
auto *tanField = getTangentStoredProperty(context, seai, invoker);
580+
if (!tanField) {
581+
errorOccurred = true;
582+
return;
617583
}
618-
619584
// Emit tangent `struct_element_addr`.
620585
auto tanOperand = getTangentBuffer(bb, seai->getOperand());
621586
auto tangentInst =

0 commit comments

Comments
 (0)