diff --git a/CMakeLists.txt b/CMakeLists.txt index f2d81d8c6..52aa752e0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -295,6 +295,10 @@ add_executable(souper-check tools/souper-check.cpp ) +add_executable(generalize + tools/generalize.cpp +) + add_executable(souper-interpret tools/souper-interpret.cpp ) @@ -362,7 +366,7 @@ configure_file( ) foreach(target souper internal-solver-test lexer-test parser-test souper-check count-insts - souper2llvm souper-interpret + souper2llvm souper-interpret generalize souperExtractor souperInfer souperInst souperKVStore souperParser souperSMTLIB2 souperTool souperPass souperPassProfileAll kleeExpr souperCodegen) @@ -400,6 +404,7 @@ target_link_libraries(internal-solver-test souperSMTLIB2) target_link_libraries(lexer-test souperParser) target_link_libraries(parser-test souperParser) target_link_libraries(souper-check souperTool souperExtractor souperKVStore souperSMTLIB2 souperParser ${HIREDIS_LIBRARY} ${ALIVE_LIBRARY} ${Z3_LIBRARY}) +target_link_libraries(generalize souperTool souperExtractor souperKVStore souperSMTLIB2 souperParser ${HIREDIS_LIBRARY} ${ALIVE_LIBRARY} ${Z3_LIBRARY}) target_link_libraries(souper-interpret souperTool souperExtractor souperKVStore souperSMTLIB2 souperParser ${HIREDIS_LIBRARY} ${ALIVE_LIBRARY} ${Z3_LIBRARY}) target_link_libraries(clang-souper souperClangTool souperExtractor souperKVStore souperParser souperSMTLIB2 souperTool kleeExpr ${CLANG_LIBS} ${LLVM_LIBS} ${LLVM_LDFLAGS} ${HIREDIS_LIBRARY} ${ALIVE_LIBRARY} ${Z3_LIBRARY}) target_link_libraries(count-insts souperParser) diff --git a/include/souper/Extractor/Solver.h b/include/souper/Extractor/Solver.h index a9648cfe1..11488e186 100644 --- a/include/souper/Extractor/Solver.h +++ b/include/souper/Extractor/Solver.h @@ -44,6 +44,12 @@ class Solver { InstMapping Mapping, bool &IsValid, std::vector> *Model) = 0; + virtual std::error_code + isSatisfiable(llvm::StringRef Query, bool &Result, + unsigned NumModels, + std::vector *Models, + unsigned Timeout = 0) = 0; + virtual std::string getName() = 0; virtual @@ -90,8 +96,8 @@ class Solver { virtual std::error_code abstractPrecondition(const BlockPCs &BPCs, const std::vector &PCs, - InstMapping &Mapping, InstContext &IC, - bool &FoundWeakest) = 0; + InstMapping &Mapping, InstContext &IC, bool &FoundWeakest, + std::vector> &Results) = 0; }; std::unique_ptr createBaseSolver( diff --git a/include/souper/Infer/Preconditions.h b/include/souper/Infer/Preconditions.h index 7aabc735c..c5970d0b7 100644 --- a/include/souper/Infer/Preconditions.h +++ b/include/souper/Infer/Preconditions.h @@ -9,7 +9,7 @@ class SMTLIBSolver; class Solver; std::vector> inferAbstractKBPreconditions(SynthesisContext &SC, Inst *RHS, - SMTLIBSolver *SMTSolver, Solver *S, bool &FoundWeakest); + Solver *S, bool &FoundWeakest); } #endif // SOUPER_PRECONDITIONS_H diff --git a/lib/Extractor/Solver.cpp b/lib/Extractor/Solver.cpp index d4d940799..2fd9b64d5 100644 --- a/lib/Extractor/Solver.cpp +++ b/lib/Extractor/Solver.cpp @@ -265,39 +265,12 @@ class BaseSolver : public Solver { std::error_code abstractPrecondition(const BlockPCs &BPCs, const std::vector &PCs, - InstMapping &Mapping, InstContext &IC, - bool &FoundWeakest) override { + InstMapping &Mapping, InstContext &IC, bool &FoundWeakest, + std::vector> &Results) override { SynthesisContext SC{IC, SMTSolver.get(), Mapping.LHS, /*LHSUB*/nullptr, PCs, BPCs, /*CheckAllGuesses=*/false, Timeout}; - std::vector> Results = - inferAbstractKBPreconditions(SC, Mapping.RHS, SMTSolver.get(), this, FoundWeakest); - - ReplacementContext RC; - auto LHSStr = RC.printInst(Mapping.LHS, llvm::outs(), true); - llvm::outs() << "infer " << LHSStr << "\n"; - auto RHSStr = RC.printInst(Mapping.RHS, llvm::outs(), true); - llvm::outs() << "result " << RHSStr << "\n"; - for (size_t i = 0; i < Results.size(); ++i) { - for (auto It = Results[i].begin(); It != Results[i].end(); ++It) { - auto &&P = *It; - std::string dummy; - llvm::raw_string_ostream str(dummy); - auto VarStr = RC.printInst(P.first, str, false); - llvm::outs() << VarStr << " -> " << Inst::getKnownBitsString(P.second.Zero, P.second.One); - - auto Next = It; - Next++; - if (Next != Results[i].end()) { - llvm::outs() << " (and) "; - } - } - if (i == Results.size() - 1) { - llvm::outs() << "\n"; - } else { - llvm::outs() << "\n(or)\n"; - } - } + Results = inferAbstractKBPreconditions(SC, Mapping.RHS, this, FoundWeakest); return {}; } @@ -461,6 +434,13 @@ class BaseSolver : public Solver { return EC; } + std::error_code isSatisfiable(llvm::StringRef Query, bool &Result, + unsigned NumModels, + std::vector *Models, + unsigned Timeout = 0) override { + return SMTSolver->isSatisfiable(Query, Result, NumModels, Models, Timeout); + } + std::error_code isValid(InstContext &IC, const BlockPCs &BPCs, const std::vector &PCs, InstMapping Mapping, bool &IsValid, @@ -717,6 +697,13 @@ class MemCachingSolver : public Solver { } } + std::error_code isSatisfiable(llvm::StringRef Query, bool &Result, + unsigned NumModels, + std::vector *Models, + unsigned Timeout = 0) override { + return UnderlyingSolver->isSatisfiable(Query, Result, NumModels, Models, Timeout); + } + std::string getName() override { return UnderlyingSolver->getName() + " + internal cache"; } @@ -745,9 +732,9 @@ class MemCachingSolver : public Solver { std::error_code abstractPrecondition(const BlockPCs &BPCs, const std::vector &PCs, - InstMapping &Mapping, InstContext &IC, - bool &FoundWeakest) override { - return UnderlyingSolver->abstractPrecondition(BPCs, PCs, Mapping, IC, FoundWeakest); + InstMapping &Mapping, InstContext &IC, bool &FoundWeakest, + std::vector> &Results) override { + return UnderlyingSolver->abstractPrecondition(BPCs, PCs, Mapping, IC, FoundWeakest, Results); } std::error_code knownBits(const BlockPCs &BPCs, @@ -847,6 +834,13 @@ class ExternalCachingSolver : public Solver { return UnderlyingSolver->constantRange(BPCs, PCs, LHS, IC); } + std::error_code isSatisfiable(llvm::StringRef Query, bool &Result, + unsigned NumModels, + std::vector *Models, + unsigned Timeout = 0) override { + return UnderlyingSolver->isSatisfiable(Query, Result, NumModels, Models, Timeout); + } + std::error_code isValid(InstContext &IC, const BlockPCs &BPCs, const std::vector &PCs, InstMapping Mapping, bool &IsValid, @@ -885,9 +879,9 @@ class ExternalCachingSolver : public Solver { std::error_code abstractPrecondition(const BlockPCs &BPCs, const std::vector &PCs, - InstMapping &Mapping, InstContext &IC, - bool &FoundWeakest) override { - return UnderlyingSolver->abstractPrecondition(BPCs, PCs, Mapping, IC, FoundWeakest); + InstMapping &Mapping, InstContext &IC, bool &FoundWeakest, + std::vector> &Results) override { + return UnderlyingSolver->abstractPrecondition(BPCs, PCs, Mapping, IC, FoundWeakest, Results); } std::error_code knownBits(const BlockPCs &BPCs, diff --git a/lib/Infer/Preconditions.cpp b/lib/Infer/Preconditions.cpp index 3df2bedc2..2fae19e1b 100644 --- a/lib/Infer/Preconditions.cpp +++ b/lib/Infer/Preconditions.cpp @@ -7,7 +7,7 @@ using llvm::APInt; namespace souper { std::vector> inferAbstractKBPreconditions(SynthesisContext &SC, Inst *RHS, - SMTLIBSolver *SMTSolver, Solver *S, bool &FoundWeakest) { + Solver *S, bool &FoundWeakest) { InstMapping Mapping(SC.LHS, RHS); bool Valid; if (DebugLevel >= 3) { @@ -20,7 +20,10 @@ std::vector> } std::vector PCCopy = SC.PCs; if (Valid) { - llvm::outs() << "Already valid.\n"; + FoundWeakest = true; + if (DebugLevel > 1) { + llvm::errs() << "Already valid.\n"; + } return {}; } @@ -97,7 +100,7 @@ std::vector> &ModelInsts, Precondition, true); - SMTSolver->isSatisfiable(Query, FoundWeakest, ModelInsts.size(), + S->isSatisfiable(Query, FoundWeakest, ModelInsts.size(), &ModelVals, SC.Timeout); std::map Known; diff --git a/test/Generalize/fixit.opt b/test/Generalize/fixit.opt new file mode 100644 index 000000000..782c10e56 --- /dev/null +++ b/test/Generalize/fixit.opt @@ -0,0 +1,42 @@ +; REQUIRES: solver, synthesis +; RUN: %generalize -fixit %s | %souper-check > %t +; RUN: %FileCheck %s < %t + +%x:i8 = var +%y:i8 = var +%z = add %x, %y +%t = add %z, 42 +%u = sub %t, %y +infer %u +%v = add %x, 42 +result %v +;CHECK: LGTM + +%x:i8 = var +%y:i8 = var +%t = add %x, 42 +%u = sub %t, %y +infer %u +%v = add %x, 42 +result %v +;CHECK: LGTM + +%x:i8 = var +%y:i8 = var +%t = and %x, 137 +%u = xor %t, %y +infer %u +%v = or %x, %y +result %v +;CHECK: LGTM +;CHECK-NEXT: LGTM + +%x:i8 = var +%y:i8 = var +%t = or %x, 42 +%u = and %t, %y +infer %u +%v = and %x, %y +result %v +;CHECK: LGTM +;CHECK-NEXT: LGTM diff --git a/test/Generalize/leaf.opt b/test/Generalize/leaf.opt new file mode 100644 index 000000000..c3ac27865 --- /dev/null +++ b/test/Generalize/leaf.opt @@ -0,0 +1,22 @@ +; REQUIRES: solver, synthesis +; RUN: %generalize -remove-leaf %s | %souper-check > %t +; RUN: %FileCheck %s < %t + +%x:i8 = var +%y:i8 = var +%masked = and %x, 3 +%and = and %masked, %y +%foo = lshr %and, 2 +infer %and +result 0:i8 +; CHECK: LGTM +; CHECK: LGTM + +%x:i8 = var +%y:i8 = var +%a = and %x, 15 +%b = and %y, 240 +%foo = or %a, %b +infer %foo +result 0:i8 +; CHECK: LGTM diff --git a/tools/generalize.cpp b/tools/generalize.cpp new file mode 100644 index 000000000..865125c6f --- /dev/null +++ b/tools/generalize.cpp @@ -0,0 +1,163 @@ +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/GraphWriter.h" +#include "llvm/Support/KnownBits.h" + +#include "souper/Infer/Preconditions.h" + +#include "souper/Inst/InstGraph.h" +#include "souper/Parser/Parser.h" +#include "souper/Tool/GetSolver.h" +#include "souper/Util/DfaUtils.h" + +using namespace llvm; +using namespace souper; + +unsigned DebugLevel; + +static cl::opt +DebugFlagParser("souper-debug-level", + cl::desc("Control the verbose level of debug output (default=1). " + "The larger the number is, the more fine-grained debug " + "information will be printed."), + cl::location(DebugLevel), cl::init(1)); + +static cl::opt +InputFilename(cl::Positional, cl::desc(""), + cl::init("-")); + +static llvm::cl::opt RemoveLeaf("remove-leaf", + llvm::cl::desc("Try to generalize a valid optimization by replacing" + "the use of a once-used variable with a new variable" + "(default=false)"), + llvm::cl::init(false)); + +static llvm::cl::opt FixIt("fixit", + llvm::cl::desc("Given an invalid optimization, generate a valid one." + "(default=false)"), + llvm::cl::init(false)); + +void Generalize(InstContext &IC, Solver *S, ParsedReplacement Input) { + bool FoundWP = false; + std::vector> Results; + S->abstractPrecondition(Input.BPCs, Input.PCs, Input.Mapping, IC, FoundWP, Results); + + if (FoundWP && Results.empty()) { + Input.print(llvm::outs(), true); + } else { + for (auto &&Result : Results) { // Each result is a disjunction + for (auto Pair: Result) { + Pair.first->KnownOnes = Pair.second.One; + Pair.first->KnownZeros = Pair.second.Zero; + } + Input.print(llvm::outs(), true); + } + } +} + +// TODO: Return modified instructions instead of just printing out +void RemoveLeafAndGeneralize(InstContext &IC, + Solver *S, ParsedReplacement Input) { + + if (DebugLevel > 1) { + llvm::errs() << "Attempting to generalize by removing leaf.\n"; + } + // TODO: Do not generalize by removing leaf if LHS has one inst. + + std::map> Uses; + + std::vector Stack{Input.Mapping.LHS, Input.Mapping.RHS}; + // TODO: Find uses in PCs/BPCs + + std::set Visited; + while (!Stack.empty()) { + auto Current = Stack.back(); + Stack.pop_back(); + Visited.insert(Current); + + for (auto Op : Current->Ops) { + if (Op->K == Inst::Var) { + Uses[Op].insert(Current); + // Intentionally skips root + } + if (Visited.find(Op) == Visited.end()) { + Stack.push_back(Op); + } + } + } + + // Find a variable with one use; + Inst *UsedOnce = nullptr; + for (auto P : Uses) { + if (P.second.size() == 1) { + UsedOnce = P.first; + break; + } + } + + if (!UsedOnce) { + llvm::outs() << "Failed. No var with one use."; + return; + } else { + Inst *User = *Uses[UsedOnce].begin(); + Inst *NewVar = IC.createVar(User->Width, "newvar"); + + std::map ICache; + ICache[User] = NewVar; + + std::map BCache; + std::map CMap; + + Input.Mapping.LHS = getInstCopy(Input.Mapping.LHS, IC, ICache, + BCache, &CMap, false); + + Input.Mapping.RHS = getInstCopy(Input.Mapping.RHS, IC, ICache, + BCache, &CMap, false); + + // TODO: Replace PCs/BPCs + } + + Generalize(IC, S, Input); +} + +int main(int argc, char **argv) { + cl::ParseCommandLineOptions(argc, argv); + KVStore *KV = 0; + + std::unique_ptr S = 0; + S = GetSolver(KV); + + auto MB = MemoryBuffer::getFileOrSTDIN(InputFilename); + if (!MB) { + llvm::errs() << MB.getError().message() << '\n'; + return 1; + } + + InstContext IC; + std::string ErrStr; + + auto &&Data = (*MB)->getMemBufferRef(); + auto Inputs = ParseReplacements(IC, Data.getBufferIdentifier(), + Data.getBuffer(), ErrStr); + + if (!ErrStr.empty()) { + llvm::errs() << ErrStr << '\n'; + return 1; + } + + // TODO: Write default action which chooses what to do based on input structure + + for (auto &&Input: Inputs) { + if (FixIt) { + // TODO: Verify that inputs are valid optimizations + Generalize(IC, S.get(), Input); + } + if (RemoveLeaf) { + RemoveLeafAndGeneralize(IC, S.get(), Input); + } + // if (EviscerateRoot) {...} + // if (SymbolizeConstant) {...} + // if (LiberateWidth) {...} + } + + return 0; +} diff --git a/tools/souper-check.cpp b/tools/souper-check.cpp index f45adffec..9f310d220 100644 --- a/tools/souper-check.cpp +++ b/tools/souper-check.cpp @@ -325,11 +325,36 @@ int SolveInst(const MemoryBufferRef &MB, Solver *S) { } } else if (InferAP) { bool FoundWeakest = false; - S->abstractPrecondition(Rep.BPCs, Rep.PCs, Rep.Mapping, IC, FoundWeakest); + std::vector> Results; + S->abstractPrecondition(Rep.BPCs, Rep.PCs, Rep.Mapping, IC, FoundWeakest, Results); if (!FoundWeakest) { llvm::outs() << "Failed to find WP.\n"; } - + ReplacementContext RC; + auto LHSStr = RC.printInst(Rep.Mapping.LHS, llvm::outs(), true); + llvm::outs() << "infer " << LHSStr << "\n"; + auto RHSStr = RC.printInst(Rep.Mapping.RHS, llvm::outs(), true); + llvm::outs() << "result " << RHSStr << "\n"; + for (size_t i = 0; i < Results.size(); ++i) { + for (auto It = Results[i].begin(); It != Results[i].end(); ++It) { + auto &&P = *It; + std::string dummy; + llvm::raw_string_ostream str(dummy); + auto VarStr = RC.printInst(P.first, str, false); + llvm::outs() << VarStr << " -> " << Inst::getKnownBitsString(P.second.Zero, P.second.One); + + auto Next = It; + Next++; + if (Next != Results[i].end()) { + llvm::outs() << " (and) "; + } + } + if (i == Results.size() - 1) { + llvm::outs() << "\n"; + } else { + llvm::outs() << "\n(or)\n"; + } + } } else { bool Valid; std::vector> Models;