-
Notifications
You must be signed in to change notification settings - Fork 14.1k
[IR2Vec] Minor vocab changes and exposing weights #143200
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This stack of pull requests is managed by Graphite. Learn more about stacking. |
@llvm/pr-subscribers-llvm-analysis @llvm/pr-subscribers-mlgo Author: S. VenkataKeerthy (svkeerthy) Changes: This PR changes some asserts in Vocab to hard checks that emit error and exposes flags and constructor to help in unit tests. (Tracking issue - #141817) Full diff: https://github.com/llvm/llvm-project/pull/143200.diff 3 Files Affected:
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 930b13f079796..1bd80ed65d434 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -31,7 +31,9 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/IR/PassManager.h"
+#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorOr.h"
+#include "llvm/Support/JSON.h"
#include <map>
namespace llvm {
@@ -43,6 +45,7 @@ class Function;
class Type;
class Value;
class raw_ostream;
+class LLVMContext;
/// IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
/// Symbolic embeddings capture the "syntactic" and "statistical correlation"
@@ -53,6 +56,11 @@ class raw_ostream;
enum class IR2VecKind { Symbolic };
namespace ir2vec {
+
+LLVM_ABI extern cl::opt<float> OpcWeight;
+LLVM_ABI extern cl::opt<float> TypeWeight;
+LLVM_ABI extern cl::opt<float> ArgWeight;
+
/// Embedding is a ADT that wraps std::vector<double>. It provides
/// additional functionality for arithmetic and comparison operations.
struct Embedding : public std::vector<double> {
@@ -187,10 +195,12 @@ class IR2VecVocabResult {
class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
ir2vec::Vocab Vocabulary;
Error readVocabulary();
+ void emitError(Error Err, LLVMContext &Ctx);
public:
static AnalysisKey Key;
IR2VecVocabAnalysis() = default;
+ explicit IR2VecVocabAnalysis(ir2vec::Vocab &&Vocab);
using Result = IR2VecVocabResult;
Result run(Module &M, ModuleAnalysisManager &MAM);
};
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 8ee8e5b0ff74e..3333e751f10b8 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -16,13 +16,11 @@
#include "llvm/ADT/Statistic.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
-#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Errc.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/Format.h"
-#include "llvm/Support/JSON.h"
#include "llvm/Support/MemoryBuffer.h"
using namespace llvm;
@@ -33,6 +31,8 @@ using namespace ir2vec;
STATISTIC(VocabMissCounter,
"Number of lookups to entites not present in the vocabulary");
+namespace llvm {
+namespace ir2vec {
static cl::OptionCategory IR2VecCategory("IR2Vec Options");
// FIXME: Use a default vocab when not specified
@@ -40,18 +40,20 @@ static cl::opt<std::string>
VocabFile("ir2vec-vocab-path", cl::Optional,
cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""),
cl::cat(IR2VecCategory));
-static cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional,
- cl::init(1.0),
- cl::desc("Weight for opcode embeddings"),
- cl::cat(IR2VecCategory));
-static cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional,
- cl::init(0.5),
- cl::desc("Weight for type embeddings"),
- cl::cat(IR2VecCategory));
-static cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional,
- cl::init(0.2),
- cl::desc("Weight for argument embeddings"),
- cl::cat(IR2VecCategory));
+LLVM_ABI cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional,
+ cl::init(1.0),
+ cl::desc("Weight for opcode embeddings"),
+ cl::cat(IR2VecCategory));
+LLVM_ABI cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional,
+ cl::init(0.5),
+ cl::desc("Weight for type embeddings"),
+ cl::cat(IR2VecCategory));
+LLVM_ABI cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional,
+ cl::init(0.2),
+ cl::desc("Weight for argument embeddings"),
+ cl::cat(IR2VecCategory));
+} // namespace ir2vec
+} // namespace llvm
AnalysisKey IR2VecVocabAnalysis::Key;
@@ -83,9 +85,10 @@ void Embedding::scaleAndAdd(const Embedding &Src, float Factor) {
bool Embedding::approximatelyEquals(const Embedding &RHS,
double Tolerance) const {
assert(this->size() == RHS.size() && "Vectors must have the same dimension");
- for (size_t i = 0; i < this->size(); ++i)
+ for (size_t i = 0; i < this->size(); ++i) {
if (std::abs((*this)[i] - RHS[i]) > Tolerance)
return false;
+ }
return true;
}
@@ -254,35 +257,57 @@ Error IR2VecVocabAnalysis::readVocabulary() {
return createStringError(errc::illegal_byte_sequence,
"Unable to parse the vocabulary");
}
- assert(Vocabulary.size() > 0 && "Vocabulary is empty");
+
+ if (Vocabulary.empty()) {
+ return createStringError(errc::illegal_byte_sequence,
+ "Vocabulary is empty");
+ }
unsigned Dim = Vocabulary.begin()->second.size();
- assert(Dim > 0 && "Dimension of vocabulary is zero");
- (void)Dim;
- assert(std::all_of(Vocabulary.begin(), Vocabulary.end(),
- [Dim](const std::pair<StringRef, Embedding> &Entry) {
- return Entry.second.size() == Dim;
- }) &&
- "All vectors in the vocabulary are not of the same dimension");
+ if (Dim == 0) {
+ return createStringError(errc::illegal_byte_sequence,
+ "Dimension of vocabulary is zero");
+ }
+
+ if (!std::all_of(Vocabulary.begin(), Vocabulary.end(),
+ [Dim](const std::pair<StringRef, Embedding> &Entry) {
+ return Entry.second.size() == Dim;
+ })) {
+ return createStringError(
+ errc::illegal_byte_sequence,
+ "All vectors in the vocabulary are not of the same dimension");
+ }
+
return Error::success();
}
+IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab &&Vocabulary)
+ : Vocabulary(std::move(Vocabulary)) {}
+
+void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
+ handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
+ Ctx.emitError("Error reading vocabulary: " + EI.message());
+ });
+}
+
IR2VecVocabAnalysis::Result
IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
auto Ctx = &M.getContext();
+ // FIXME: Scale the vocabulary once. This would avoid scaling per use later.
+ // If vocabulary is already populated by the constructor, use it.
+ if (!Vocabulary.empty())
+ return IR2VecVocabResult(std::move(Vocabulary));
+
+ // Otherwise, try to read from the vocabulary file.
if (VocabFile.empty()) {
// FIXME: Use default vocabulary
Ctx->emitError("IR2Vec vocabulary file path not specified");
return IR2VecVocabResult(); // Return invalid result
}
if (auto Err = readVocabulary()) {
- handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
- Ctx->emitError("Error reading vocabulary: " + EI.message());
- });
+ emitError(std::move(Err), *Ctx);
return IR2VecVocabResult();
}
- // FIXME: Scale the vocabulary here once. This would avoid scaling per use
- // later.
return IR2VecVocabResult(std::move(Vocabulary));
}
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 7259a8a2fe20a..7d5787243d12d 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -199,25 +199,30 @@ TEST(IR2VecTest, IR2VecVocabResultValidity) {
EXPECT_EQ(validResult.getDimension(), 2u);
}
-// Helper to create a minimal function and embedder for getter tests
-struct GetterTestEnv {
- Vocab V = {};
+// Fixture for IR2Vec tests requiring IR setup and weight management.
+class IR2VecTestFixture : public ::testing::Test {
+protected:
+ Vocab V;
LLVMContext Ctx;
- std::unique_ptr<Module> M = nullptr;
+ std::unique_ptr<Module> M;
Function *F = nullptr;
BasicBlock *BB = nullptr;
- Instruction *Add = nullptr;
- Instruction *Ret = nullptr;
- std::unique_ptr<Embedder> Emb = nullptr;
+ Instruction *AddInst = nullptr;
+ Instruction *RetInst = nullptr;
- GetterTestEnv() {
+ float OriginalOpcWeight = ::OpcWeight;
+ float OriginalTypeWeight = ::TypeWeight;
+ float OriginalArgWeight = ::ArgWeight;
+
+ void SetUp() override {
V = {{"add", {1.0, 2.0}},
{"integerTy", {0.5, 0.5}},
{"constant", {0.2, 0.3}},
{"variable", {0.0, 0.0}},
{"unknownTy", {0.0, 0.0}}};
- M = std::make_unique<Module>("M", Ctx);
+ // Setup IR
+ M = std::make_unique<Module>("TestM", Ctx);
FunctionType *FTy = FunctionType::get(
Type::getInt32Ty(Ctx), {Type::getInt32Ty(Ctx), Type::getInt32Ty(Ctx)},
false);
@@ -226,61 +231,82 @@ struct GetterTestEnv {
Argument *Arg = F->getArg(0);
llvm::Value *Const = ConstantInt::get(Type::getInt32Ty(Ctx), 42);
- Add = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
- Ret = ReturnInst::Create(Ctx, Add, BB);
+ AddInst = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
+ RetInst = ReturnInst::Create(Ctx, AddInst, BB);
+ }
+
+ void setWeights(float OpcWeight, float TypeWeight, float ArgWeight) {
+ ::OpcWeight = OpcWeight;
+ ::TypeWeight = TypeWeight;
+ ::ArgWeight = ArgWeight;
+ }
- auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
- EXPECT_TRUE(static_cast<bool>(Result));
- Emb = std::move(*Result);
+ void TearDown() override {
+ // Restore original global weights
+ ::OpcWeight = OriginalOpcWeight;
+ ::TypeWeight = OriginalTypeWeight;
+ ::ArgWeight = OriginalArgWeight;
}
};
-TEST(IR2VecTest, GetInstVecMap) {
- GetterTestEnv Env;
- const auto &InstMap = Env.Emb->getInstVecMap();
+TEST_F(IR2VecTestFixture, GetInstVecMap) {
+ auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Result));
+ auto Emb = std::move(*Result);
+
+ const auto &InstMap = Emb->getInstVecMap();
EXPECT_EQ(InstMap.size(), 2u);
- EXPECT_TRUE(InstMap.count(Env.Add));
- EXPECT_TRUE(InstMap.count(Env.Ret));
+ EXPECT_TRUE(InstMap.count(AddInst));
+ EXPECT_TRUE(InstMap.count(RetInst));
- EXPECT_EQ(InstMap.at(Env.Add).size(), 2u);
- EXPECT_EQ(InstMap.at(Env.Ret).size(), 2u);
+ EXPECT_EQ(InstMap.at(AddInst).size(), 2u);
+ EXPECT_EQ(InstMap.at(RetInst).size(), 2u);
// Check values for add: {1.29, 2.31}
- EXPECT_THAT(InstMap.at(Env.Add),
+ EXPECT_THAT(InstMap.at(AddInst),
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
// Check values for ret: {0.0, 0.}; Neither ret nor voidTy are present in
// vocab
- EXPECT_THAT(InstMap.at(Env.Ret), ElementsAre(0.0, 0.0));
+ EXPECT_THAT(InstMap.at(RetInst), ElementsAre(0.0, 0.0));
}
-TEST(IR2VecTest, GetBBVecMap) {
- GetterTestEnv Env;
- const auto &BBMap = Env.Emb->getBBVecMap();
+TEST_F(IR2VecTestFixture, GetBBVecMap) {
+ auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Result));
+ auto Emb = std::move(*Result);
+
+ const auto &BBMap = Emb->getBBVecMap();
EXPECT_EQ(BBMap.size(), 1u);
- EXPECT_TRUE(BBMap.count(Env.BB));
- EXPECT_EQ(BBMap.at(Env.BB).size(), 2u);
+ EXPECT_TRUE(BBMap.count(BB));
+ EXPECT_EQ(BBMap.at(BB).size(), 2u);
// BB vector should be sum of add and ret: {1.29, 2.31} + {0.0, 0.0} =
// {1.29, 2.31}
- EXPECT_THAT(BBMap.at(Env.BB),
+ EXPECT_THAT(BBMap.at(BB),
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
}
-TEST(IR2VecTest, GetBBVector) {
- GetterTestEnv Env;
- const auto &BBVec = Env.Emb->getBBVector(*Env.BB);
+TEST_F(IR2VecTestFixture, GetBBVector) {
+ auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Result));
+ auto Emb = std::move(*Result);
+
+ const auto &BBVec = Emb->getBBVector(*BB);
EXPECT_EQ(BBVec.size(), 2u);
EXPECT_THAT(BBVec,
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
}
-TEST(IR2VecTest, GetFunctionVector) {
- GetterTestEnv Env;
- const auto &FuncVec = Env.Emb->getFunctionVector();
+TEST_F(IR2VecTestFixture, GetFunctionVector) {
+ auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Result));
+ auto Emb = std::move(*Result);
+
+ const auto &FuncVec = Emb->getFunctionVector();
EXPECT_EQ(FuncVec.size(), 2u);
@@ -289,4 +315,45 @@ TEST(IR2VecTest, GetFunctionVector) {
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
}
+TEST_F(IR2VecTestFixture, GetFunctionVectorWithCustomWeights) {
+ setWeights(1.0, 1.0, 1.0);
+
+ auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Result));
+ auto Emb = std::move(*Result);
+
+ const auto &FuncVec = Emb->getFunctionVector();
+
+ EXPECT_EQ(FuncVec.size(), 2u);
+
+ // Expected: 1*([1.0 2.0] + [0.0 0.0]) + 1*([0.5 0.5] + [0.0 0.0]) + 1*([0.2
+ // 0.3] + [0.0 0.0])
+ EXPECT_THAT(FuncVec,
+ ElementsAre(DoubleNear(1.7, 1e-6), DoubleNear(2.8, 1e-6)));
+}
+
+TEST(IR2VecTest, IR2VecVocabAnalysisWithPrepopulatedVocab) {
+ Vocab InitialVocab = {{"key1", {1.1, 2.2}}, {"key2", {3.3, 4.4}}};
+ Vocab ExpectedVocab = InitialVocab;
+ unsigned ExpectedDim = InitialVocab.begin()->second.size();
+
+ IR2VecVocabAnalysis VocabAnalysis(std::move(InitialVocab));
+
+ LLVMContext TestCtx;
+ Module TestMod("TestModuleForVocabAnalysis", TestCtx);
+ ModuleAnalysisManager MAM;
+ IR2VecVocabResult Result = VocabAnalysis.run(TestMod, MAM);
+
+ EXPECT_TRUE(Result.isValid());
+ ASSERT_FALSE(Result.getVocabulary().empty());
+ EXPECT_EQ(Result.getDimension(), ExpectedDim);
+
+ const auto &ResultVocab = Result.getVocabulary();
+ EXPECT_EQ(ResultVocab.size(), ExpectedVocab.size());
+ for (const auto &pair : ExpectedVocab) {
+ EXPECT_TRUE(ResultVocab.count(pair.first));
+ EXPECT_THAT(ResultVocab.at(pair.first), ElementsAreArray(pair.second));
+ }
+}
+
} // end anonymous namespace
|
d62a38b
to
750cb2f
Compare
8e5c13a
to
602c2d3
Compare
c1842ec
to
7f2012c
Compare
602c2d3
to
6817aa9
Compare
7f2012c
to
d3468ab
Compare
6817aa9
to
9c05884
Compare
d3468ab
to
96e4a8b
Compare
9c05884
to
dfe59f2
Compare
llvm/include/llvm/Analysis/IR2Vec.h
Outdated
@@ -53,7 +56,12 @@ class raw_ostream; | |||
enum class IR2VecKind { Symbolic }; | |||
|
|||
namespace ir2vec { | |||
/// Embedding is a datavtype that wraps std::vector<double>. It provides | |||
|
|||
LLVM_ABI extern cl::opt<float> OpcWeight; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do these need an `LLVM_ABI` tag?
If this is actually needed, this seems like something that should be split out of this patch.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#136623 adds LLVM_ABI tags to some globals and externs in Analysis. Removing it for now. Can be added later if necessary.
dfe59f2
to
7a33ef8
Compare
d639ce4
to
c200f67
Compare
c200f67
to
8685c74
Compare
8685c74
to
5a7dd50
Compare
Merge activity
|
This PR changes some asserts in Vocab to hard checks that emit error and exposes flags and constructor to help in unit tests. (Tracking issue - llvm#141817)
This PR changes some asserts in Vocab to hard checks that emit error and exposes flags and constructor to help in unit tests.
(Tracking issue - #141817)