Skip to content

Commit d62a38b

Browse files
committed
Vocab changes1
1 parent 8e5c13a commit d62a38b

File tree

3 files changed

+165
-63
lines changed

3 files changed

+165
-63
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@
3131

3232
#include "llvm/ADT/DenseMap.h"
3333
#include "llvm/IR/PassManager.h"
34+
#include "llvm/Support/CommandLine.h"
3435
#include "llvm/Support/ErrorOr.h"
36+
#include "llvm/Support/JSON.h"
3537
#include <map>
3638

3739
namespace llvm {
@@ -43,6 +45,7 @@ class Function;
4345
class Type;
4446
class Value;
4547
class raw_ostream;
48+
class LLVMContext;
4649

4750
/// IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
4851
/// Symbolic embeddings capture the "syntactic" and "statistical correlation"
@@ -53,6 +56,11 @@ class raw_ostream;
5356
enum class IR2VecKind { Symbolic };
5457

5558
namespace ir2vec {
59+
60+
LLVM_ABI extern cl::opt<float> OpcWeight;
61+
LLVM_ABI extern cl::opt<float> TypeWeight;
62+
LLVM_ABI extern cl::opt<float> ArgWeight;
63+
5664
/// Embedding is a ADT that wraps std::vector<double>. It provides
5765
/// additional functionality for arithmetic and comparison operations.
5866
struct Embedding : public std::vector<double> {
@@ -187,10 +195,12 @@ class IR2VecVocabResult {
187195
class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
188196
ir2vec::Vocab Vocabulary;
189197
Error readVocabulary();
198+
void emitError(Error Err, LLVMContext &Ctx);
190199

191200
public:
192201
static AnalysisKey Key;
193202
IR2VecVocabAnalysis() = default;
203+
explicit IR2VecVocabAnalysis(ir2vec::Vocab &&Vocab);
194204
using Result = IR2VecVocabResult;
195205
Result run(Module &M, ModuleAnalysisManager &MAM);
196206
};

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 53 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,11 @@
1616
#include "llvm/ADT/Statistic.h"
1717
#include "llvm/IR/Module.h"
1818
#include "llvm/IR/PassManager.h"
19-
#include "llvm/Support/CommandLine.h"
2019
#include "llvm/Support/Debug.h"
2120
#include "llvm/Support/Errc.h"
2221
#include "llvm/Support/Error.h"
2322
#include "llvm/Support/ErrorHandling.h"
2423
#include "llvm/Support/Format.h"
25-
#include "llvm/Support/JSON.h"
2624
#include "llvm/Support/MemoryBuffer.h"
2725

2826
using namespace llvm;
@@ -33,25 +31,29 @@ using namespace ir2vec;
3331
STATISTIC(VocabMissCounter,
3432
"Number of lookups to entites not present in the vocabulary");
3533

34+
namespace llvm {
35+
namespace ir2vec {
3636
static cl::OptionCategory IR2VecCategory("IR2Vec Options");
3737

3838
// FIXME: Use a default vocab when not specified
3939
static cl::opt<std::string>
4040
VocabFile("ir2vec-vocab-path", cl::Optional,
4141
cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""),
4242
cl::cat(IR2VecCategory));
43-
static cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional,
44-
cl::init(1.0),
45-
cl::desc("Weight for opcode embeddings"),
46-
cl::cat(IR2VecCategory));
47-
static cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional,
48-
cl::init(0.5),
49-
cl::desc("Weight for type embeddings"),
50-
cl::cat(IR2VecCategory));
51-
static cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional,
52-
cl::init(0.2),
53-
cl::desc("Weight for argument embeddings"),
54-
cl::cat(IR2VecCategory));
43+
LLVM_ABI cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional,
44+
cl::init(1.0),
45+
cl::desc("Weight for opcode embeddings"),
46+
cl::cat(IR2VecCategory));
47+
LLVM_ABI cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional,
48+
cl::init(0.5),
49+
cl::desc("Weight for type embeddings"),
50+
cl::cat(IR2VecCategory));
51+
LLVM_ABI cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional,
52+
cl::init(0.2),
53+
cl::desc("Weight for argument embeddings"),
54+
cl::cat(IR2VecCategory));
55+
} // namespace ir2vec
56+
} // namespace llvm
5557

5658
AnalysisKey IR2VecVocabAnalysis::Key;
5759

@@ -83,9 +85,10 @@ void Embedding::scaleAndAdd(const Embedding &Src, float Factor) {
8385
bool Embedding::approximatelyEquals(const Embedding &RHS,
8486
double Tolerance) const {
8587
assert(this->size() == RHS.size() && "Vectors must have the same dimension");
86-
for (size_t i = 0; i < this->size(); ++i)
88+
for (size_t i = 0; i < this->size(); ++i) {
8789
if (std::abs((*this)[i] - RHS[i]) > Tolerance)
8890
return false;
91+
}
8992
return true;
9093
}
9194

@@ -254,35 +257,57 @@ Error IR2VecVocabAnalysis::readVocabulary() {
254257
return createStringError(errc::illegal_byte_sequence,
255258
"Unable to parse the vocabulary");
256259
}
257-
assert(Vocabulary.size() > 0 && "Vocabulary is empty");
260+
261+
if (Vocabulary.empty()) {
262+
return createStringError(errc::illegal_byte_sequence,
263+
"Vocabulary is empty");
264+
}
258265

259266
unsigned Dim = Vocabulary.begin()->second.size();
260-
assert(Dim > 0 && "Dimension of vocabulary is zero");
261-
(void)Dim;
262-
assert(std::all_of(Vocabulary.begin(), Vocabulary.end(),
263-
[Dim](const std::pair<StringRef, Embedding> &Entry) {
264-
return Entry.second.size() == Dim;
265-
}) &&
266-
"All vectors in the vocabulary are not of the same dimension");
267+
if (Dim == 0) {
268+
return createStringError(errc::illegal_byte_sequence,
269+
"Dimension of vocabulary is zero");
270+
}
271+
272+
if (!std::all_of(Vocabulary.begin(), Vocabulary.end(),
273+
[Dim](const std::pair<StringRef, Embedding> &Entry) {
274+
return Entry.second.size() == Dim;
275+
})) {
276+
return createStringError(
277+
errc::illegal_byte_sequence,
278+
"All vectors in the vocabulary are not of the same dimension");
279+
}
280+
267281
return Error::success();
268282
}
269283

284+
IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab &&Vocabulary)
285+
: Vocabulary(std::move(Vocabulary)) {}
286+
287+
void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
288+
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
289+
Ctx.emitError("Error reading vocabulary: " + EI.message());
290+
});
291+
}
292+
270293
IR2VecVocabAnalysis::Result
271294
IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
272295
auto Ctx = &M.getContext();
296+
// FIXME: Scale the vocabulary once. This would avoid scaling per use later.
297+
// If vocabulary is already populated by the constructor, use it.
298+
if (!Vocabulary.empty())
299+
return IR2VecVocabResult(std::move(Vocabulary));
300+
301+
// Otherwise, try to read from the vocabulary file.
273302
if (VocabFile.empty()) {
274303
// FIXME: Use default vocabulary
275304
Ctx->emitError("IR2Vec vocabulary file path not specified");
276305
return IR2VecVocabResult(); // Return invalid result
277306
}
278307
if (auto Err = readVocabulary()) {
279-
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
280-
Ctx->emitError("Error reading vocabulary: " + EI.message());
281-
});
308+
emitError(std::move(Err), *Ctx);
282309
return IR2VecVocabResult();
283310
}
284-
// FIXME: Scale the vocabulary here once. This would avoid scaling per use
285-
// later.
286311
return IR2VecVocabResult(std::move(Vocabulary));
287312
}
288313

llvm/unittests/Analysis/IR2VecTest.cpp

Lines changed: 102 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -199,25 +199,30 @@ TEST(IR2VecTest, IR2VecVocabResultValidity) {
199199
EXPECT_EQ(validResult.getDimension(), 2u);
200200
}
201201

202-
// Helper to create a minimal function and embedder for getter tests
203-
struct GetterTestEnv {
204-
Vocab V = {};
202+
// Fixture for IR2Vec tests requiring IR setup and weight management.
203+
class IR2VecTestFixture : public ::testing::Test {
204+
protected:
205+
Vocab V;
205206
LLVMContext Ctx;
206-
std::unique_ptr<Module> M = nullptr;
207+
std::unique_ptr<Module> M;
207208
Function *F = nullptr;
208209
BasicBlock *BB = nullptr;
209-
Instruction *Add = nullptr;
210-
Instruction *Ret = nullptr;
211-
std::unique_ptr<Embedder> Emb = nullptr;
210+
Instruction *AddInst = nullptr;
211+
Instruction *RetInst = nullptr;
212212

213-
GetterTestEnv() {
213+
float OriginalOpcWeight = ::OpcWeight;
214+
float OriginalTypeWeight = ::TypeWeight;
215+
float OriginalArgWeight = ::ArgWeight;
216+
217+
void SetUp() override {
214218
V = {{"add", {1.0, 2.0}},
215219
{"integerTy", {0.5, 0.5}},
216220
{"constant", {0.2, 0.3}},
217221
{"variable", {0.0, 0.0}},
218222
{"unknownTy", {0.0, 0.0}}};
219223

220-
M = std::make_unique<Module>("M", Ctx);
224+
// Setup IR
225+
M = std::make_unique<Module>("TestM", Ctx);
221226
FunctionType *FTy = FunctionType::get(
222227
Type::getInt32Ty(Ctx), {Type::getInt32Ty(Ctx), Type::getInt32Ty(Ctx)},
223228
false);
@@ -226,61 +231,82 @@ struct GetterTestEnv {
226231
Argument *Arg = F->getArg(0);
227232
llvm::Value *Const = ConstantInt::get(Type::getInt32Ty(Ctx), 42);
228233

229-
Add = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
230-
Ret = ReturnInst::Create(Ctx, Add, BB);
234+
AddInst = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
235+
RetInst = ReturnInst::Create(Ctx, AddInst, BB);
236+
}
237+
238+
void setWeights(float OpcWeight, float TypeWeight, float ArgWeight) {
239+
::OpcWeight = OpcWeight;
240+
::TypeWeight = TypeWeight;
241+
::ArgWeight = ArgWeight;
242+
}
231243

232-
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
233-
EXPECT_TRUE(static_cast<bool>(Result));
234-
Emb = std::move(*Result);
244+
void TearDown() override {
245+
// Restore original global weights
246+
::OpcWeight = OriginalOpcWeight;
247+
::TypeWeight = OriginalTypeWeight;
248+
::ArgWeight = OriginalArgWeight;
235249
}
236250
};
237251

238-
TEST(IR2VecTest, GetInstVecMap) {
239-
GetterTestEnv Env;
240-
const auto &InstMap = Env.Emb->getInstVecMap();
252+
TEST_F(IR2VecTestFixture, GetInstVecMap) {
253+
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
254+
ASSERT_TRUE(static_cast<bool>(Result));
255+
auto Emb = std::move(*Result);
256+
257+
const auto &InstMap = Emb->getInstVecMap();
241258

242259
EXPECT_EQ(InstMap.size(), 2u);
243-
EXPECT_TRUE(InstMap.count(Env.Add));
244-
EXPECT_TRUE(InstMap.count(Env.Ret));
260+
EXPECT_TRUE(InstMap.count(AddInst));
261+
EXPECT_TRUE(InstMap.count(RetInst));
245262

246-
EXPECT_EQ(InstMap.at(Env.Add).size(), 2u);
247-
EXPECT_EQ(InstMap.at(Env.Ret).size(), 2u);
263+
EXPECT_EQ(InstMap.at(AddInst).size(), 2u);
264+
EXPECT_EQ(InstMap.at(RetInst).size(), 2u);
248265

249266
// Check values for add: {1.29, 2.31}
250-
EXPECT_THAT(InstMap.at(Env.Add),
267+
EXPECT_THAT(InstMap.at(AddInst),
251268
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
252269

253270
// Check values for ret: {0.0, 0.}; Neither ret nor voidTy are present in
254271
// vocab
255-
EXPECT_THAT(InstMap.at(Env.Ret), ElementsAre(0.0, 0.0));
272+
EXPECT_THAT(InstMap.at(RetInst), ElementsAre(0.0, 0.0));
256273
}
257274

258-
TEST(IR2VecTest, GetBBVecMap) {
259-
GetterTestEnv Env;
260-
const auto &BBMap = Env.Emb->getBBVecMap();
275+
TEST_F(IR2VecTestFixture, GetBBVecMap) {
276+
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
277+
ASSERT_TRUE(static_cast<bool>(Result));
278+
auto Emb = std::move(*Result);
279+
280+
const auto &BBMap = Emb->getBBVecMap();
261281

262282
EXPECT_EQ(BBMap.size(), 1u);
263-
EXPECT_TRUE(BBMap.count(Env.BB));
264-
EXPECT_EQ(BBMap.at(Env.BB).size(), 2u);
283+
EXPECT_TRUE(BBMap.count(BB));
284+
EXPECT_EQ(BBMap.at(BB).size(), 2u);
265285

266286
// BB vector should be sum of add and ret: {1.29, 2.31} + {0.0, 0.0} =
267287
// {1.29, 2.31}
268-
EXPECT_THAT(BBMap.at(Env.BB),
288+
EXPECT_THAT(BBMap.at(BB),
269289
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
270290
}
271291

272-
TEST(IR2VecTest, GetBBVector) {
273-
GetterTestEnv Env;
274-
const auto &BBVec = Env.Emb->getBBVector(*Env.BB);
292+
TEST_F(IR2VecTestFixture, GetBBVector) {
293+
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
294+
ASSERT_TRUE(static_cast<bool>(Result));
295+
auto Emb = std::move(*Result);
296+
297+
const auto &BBVec = Emb->getBBVector(*BB);
275298

276299
EXPECT_EQ(BBVec.size(), 2u);
277300
EXPECT_THAT(BBVec,
278301
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
279302
}
280303

281-
TEST(IR2VecTest, GetFunctionVector) {
282-
GetterTestEnv Env;
283-
const auto &FuncVec = Env.Emb->getFunctionVector();
304+
TEST_F(IR2VecTestFixture, GetFunctionVector) {
305+
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
306+
ASSERT_TRUE(static_cast<bool>(Result));
307+
auto Emb = std::move(*Result);
308+
309+
const auto &FuncVec = Emb->getFunctionVector();
284310

285311
EXPECT_EQ(FuncVec.size(), 2u);
286312

@@ -289,4 +315,45 @@ TEST(IR2VecTest, GetFunctionVector) {
289315
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
290316
}
291317

318+
TEST_F(IR2VecTestFixture, GetFunctionVectorWithCustomWeights) {
319+
setWeights(1.0, 1.0, 1.0);
320+
321+
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
322+
ASSERT_TRUE(static_cast<bool>(Result));
323+
auto Emb = std::move(*Result);
324+
325+
const auto &FuncVec = Emb->getFunctionVector();
326+
327+
EXPECT_EQ(FuncVec.size(), 2u);
328+
329+
// Expected: 1*([1.0 2.0] + [0.0 0.0]) + 1*([0.5 0.5] + [0.0 0.0]) + 1*([0.2
330+
// 0.3] + [0.0 0.0])
331+
EXPECT_THAT(FuncVec,
332+
ElementsAre(DoubleNear(1.7, 1e-6), DoubleNear(2.8, 1e-6)));
333+
}
334+
335+
TEST(IR2VecTest, IR2VecVocabAnalysisWithPrepopulatedVocab) {
336+
Vocab InitialVocab = {{"key1", {1.1, 2.2}}, {"key2", {3.3, 4.4}}};
337+
Vocab ExpectedVocab = InitialVocab;
338+
unsigned ExpectedDim = InitialVocab.begin()->second.size();
339+
340+
IR2VecVocabAnalysis VocabAnalysis(std::move(InitialVocab));
341+
342+
LLVMContext TestCtx;
343+
Module TestMod("TestModuleForVocabAnalysis", TestCtx);
344+
ModuleAnalysisManager MAM;
345+
IR2VecVocabResult Result = VocabAnalysis.run(TestMod, MAM);
346+
347+
EXPECT_TRUE(Result.isValid());
348+
ASSERT_FALSE(Result.getVocabulary().empty());
349+
EXPECT_EQ(Result.getDimension(), ExpectedDim);
350+
351+
const auto &ResultVocab = Result.getVocabulary();
352+
EXPECT_EQ(ResultVocab.size(), ExpectedVocab.size());
353+
for (const auto &pair : ExpectedVocab) {
354+
EXPECT_TRUE(ResultVocab.count(pair.first));
355+
EXPECT_THAT(ResultVocab.at(pair.first), ElementsAreArray(pair.second));
356+
}
357+
}
358+
292359
} // end anonymous namespace

0 commit comments

Comments
 (0)