Skip to content

Commit e8056f5

Browse files
committed
Fixed support for coroutine templates and C++20s struct as NTTP.
1 parent cb44b7d commit e8056f5

6 files changed

+429
-11
lines changed

CodeGenerator.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ class CoroutinesCodeGenerator final : public CodeGenerator
591591
std::string mFSMName{};
592592
CoroutineASTData mASTData{};
593593
llvm::DenseMap<const Stmt*, bool> mBinaryExprs{};
594-
static inline llvm::DenseMap<const Expr*, std::string>
594+
static inline llvm::DenseMap<const Expr*, std::pair<const DeclRefExpr*, std::string>>
595595
mOpaqueValues{}; ///! Keeps track of the current set of opaque value
596596

597597
QualType GetFrameType() const { return QualType(mASTData.mFrameType->getTypeForDecl(), 0); }

CoroutinesCodeGenerator.cpp

+14-9
Original file line numberDiff line numberDiff line change
@@ -431,12 +431,12 @@ void CoroutinesCodeGenerator::InsertCoroutine(const FunctionDecl& fd, const Coro
431431
}
432432
}
433433

434-
auto& str = ofm.GetString();
434+
auto str = std::move(ofm.GetString());
435435
ReplaceAll(str, "<"sv, ""sv);
436436
ReplaceAll(str, ":"sv, ""sv);
437437
ReplaceAll(str, ">"sv, ""sv);
438-
ReplaceAll(str, ","sv, ""sv);
439-
ReplaceAll(str, " "sv, ""sv);
438+
439+
str = BuildTemplateParamObjectName(str);
440440

441441
if(fd.isOverloadedOperator()) {
442442
return StrCat(MakeLineColumnName(ctx.getSourceManager(), stmt->getBeginLoc(), "operator_"sv), str);
@@ -737,7 +737,6 @@ void CoroutinesCodeGenerator::InsertArg(const CoroutineBodyStmt* stmt)
737737
}
738738

739739
if(const auto* coReturnVoid = dyn_cast_or_null<CoreturnStmt>(stmt->getFallthroughHandler())) {
740-
coReturnVoid->dump();
741740
funcBodyStmts.Add(coReturnVoid);
742741
}
743742

@@ -817,10 +816,11 @@ void CoroutinesCodeGenerator::InsertArg(const CallExpr* stmt)
817816
}
818817
//-----------------------------------------------------------------------------
819818

820-
static std::optional<std::string> FindValue(llvm::DenseMap<const Expr*, std::string>& map, const Expr* key)
819+
static std::optional<std::string>
820+
FindValue(llvm::DenseMap<const Expr*, std::pair<const DeclRefExpr*, std::string>>& map, const Expr* key)
821821
{
822822
if(const auto& s = map.find(key); s != map.end()) {
823-
return s->second;
823+
return s->second.second;
824824
}
825825

826826
return {};
@@ -838,18 +838,23 @@ void CoroutinesCodeGenerator::InsertArg(const OpaqueValueExpr* stmt)
838838
// Needs to be internal because a user can create the same type and it gets put into the stack frame
839839
std::string name{BuildSuspendVarName(stmt)};
840840

841+
// In case of a coroutine-template the same suspension point can occur multiple times. But to know when to add
842+
// the _1 we must match the one from each instantiation. The DeclRefExpr is what distinguishes the same
843+
// OpaqueValueExpr between multiple instantiations.
844+
const auto* dref = FindDeclRef(sourceExpr);
845+
841846
// The initial_suspend and final_suspend expressions carry the same location info. If we hit such a case,
842847
// make up another name.
843848
// Below is a std::find_if. However, the same code looks unreadable with std::find_if
844-
for(const auto lookupName{StrCat(CORO_FRAME_ACCESS, name)}; const auto& [k, v] : mOpaqueValues) {
845-
if(v == lookupName) {
849+
for(const auto lookupName{StrCat(CORO_FRAME_ACCESS, name)}; const auto& [k, value] : mOpaqueValues) {
850+
if(auto [thisDeref, v] = value; (thisDeref == dref) and (v == lookupName)) {
846851
name += "_1"sv;
847852
break;
848853
}
849854
}
850855

851856
const auto accessName{StrCat(CORO_FRAME_ACCESS, name)};
852-
mOpaqueValues.insert(std::make_pair(sourceExpr, accessName));
857+
mOpaqueValues.insert(std::make_pair(sourceExpr, std::make_pair(dref, accessName)));
853858

854859
OutputFormatHelper ofm{};
855860
CoroutinesCodeGenerator codeGenerator{ofm, mPosBeforeFunc, mFSMName, mSuspendsCount, mASTData};

InsightsHelpers.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -1139,6 +1139,7 @@ std::string BuildTemplateParamObjectName(std::string name)
11391139
ReplaceAll(name, ","sv, "_"sv);
11401140
ReplaceAll(name, "."sv, "_"sv);
11411141
ReplaceAll(name, "+"sv, "_"sv);
1142+
ReplaceAll(name, "-"sv, "n"sv);
11421143

11431144
return name;
11441145
}

InsightsStrCat.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,11 @@ inline std::string Normalize(const APValue& arg)
6060
::llvm::raw_string_ostream stream{str};
6161

6262
arg.getFloat().print(stream);
63+
str.pop_back();
6364

6465
if(std::string::npos == str.find('.')) {
6566
/* in case it is a number like 10.0 toString() seems to leave out the .0. However, as this distinguished
6667
* between an integer and a floating point literal we need that dot. */
67-
str.pop_back();
6868
str.append(".0");
6969
}
7070

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// cmdline:-std=c++20
2+
// cmdlineinsights:-edu-show-coroutine-transformation
3+
4+
#include <coroutine>
5+
#include <exception> // std::terminate
6+
#include <new>
7+
#include <utility>
8+
9+
#define INSIGHTS_USE_TEMPLATE 1
10+
11+
struct ClsAsNTTPToCoro
12+
{
13+
int x;
14+
double d;
15+
};
16+
17+
template <typename T> struct generator {
18+
struct promise_type {
19+
T current_value{};
20+
21+
std::suspend_always yield_value(T value) {
22+
current_value = value;
23+
return {};
24+
}
25+
std::suspend_always initial_suspend() { return {}; }
26+
std::suspend_always final_suspend() noexcept { return {}; }
27+
generator get_return_object() { return generator{this}; };
28+
void unhandled_exception() { std::terminate(); }
29+
void return_value(T v) { current_value = v; }
30+
};
31+
32+
generator(generator &&rhs) : p{std::exchange(rhs.p, nullptr)} {}
33+
~generator() { if (p) { p.destroy(); } }
34+
35+
private:
36+
explicit generator(promise_type* _p)
37+
: p{std::coroutine_handle<promise_type>::from_promise(*_p)} {}
38+
39+
std::coroutine_handle<promise_type> p;
40+
};
41+
42+
template <typename T, typename U, typename V, auto>
43+
generator<T> fun() {
44+
co_return 2;
45+
}
46+
47+
int main() {
48+
auto dbl = fun<int, char, unsigned int, 3.14>();
49+
auto stct = fun<int, char, unsigned int, ClsAsNTTPToCoro{4, -3.0}>();
50+
}

0 commit comments

Comments
 (0)