Skip to content

Commit 8685c74

Browse files
committed
Vocab changes1
1 parent 32649e0 commit 8685c74

File tree

3 files changed

+164
-66
lines changed

3 files changed

+164
-66
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@
3131

3232
#include "llvm/ADT/DenseMap.h"
3333
#include "llvm/IR/PassManager.h"
34+
#include "llvm/Support/CommandLine.h"
3435
#include "llvm/Support/ErrorOr.h"
36+
#include "llvm/Support/JSON.h"
3537
#include <map>
3638

3739
namespace llvm {
@@ -43,6 +45,7 @@ class Function;
4345
class Type;
4446
class Value;
4547
class raw_ostream;
48+
class LLVMContext;
4649

4750
/// IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
4851
/// Symbolic embeddings capture the "syntactic" and "statistical correlation"
@@ -53,6 +56,11 @@ class raw_ostream;
5356
enum class IR2VecKind { Symbolic };
5457

5558
namespace ir2vec {
59+
60+
extern cl::opt<float> OpcWeight;
61+
extern cl::opt<float> TypeWeight;
62+
extern cl::opt<float> ArgWeight;
63+
5664
/// Embedding is a datatype that wraps std::vector<double>. It provides
5765
/// additional functionality for arithmetic and comparison operations.
5866
/// It is meant to be used *like* std::vector<double> but is more restrictive
@@ -226,10 +234,13 @@ class IR2VecVocabResult {
226234
class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
227235
ir2vec::Vocab Vocabulary;
228236
Error readVocabulary();
237+
void emitError(Error Err, LLVMContext &Ctx);
229238

230239
public:
231240
static AnalysisKey Key;
232241
IR2VecVocabAnalysis() = default;
242+
explicit IR2VecVocabAnalysis(const ir2vec::Vocab &Vocab);
243+
explicit IR2VecVocabAnalysis(ir2vec::Vocab &&Vocab);
233244
using Result = IR2VecVocabResult;
234245
Result run(Module &M, ModuleAnalysisManager &MAM);
235246
};

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 51 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,11 @@
1616
#include "llvm/ADT/Statistic.h"
1717
#include "llvm/IR/Module.h"
1818
#include "llvm/IR/PassManager.h"
19-
#include "llvm/Support/CommandLine.h"
2019
#include "llvm/Support/Debug.h"
2120
#include "llvm/Support/Errc.h"
2221
#include "llvm/Support/Error.h"
2322
#include "llvm/Support/ErrorHandling.h"
2423
#include "llvm/Support/Format.h"
25-
#include "llvm/Support/JSON.h"
2624
#include "llvm/Support/MemoryBuffer.h"
2725

2826
using namespace llvm;
@@ -33,25 +31,26 @@ using namespace ir2vec;
3331
STATISTIC(VocabMissCounter,
3432
"Number of lookups to entites not present in the vocabulary");
3533

34+
namespace llvm {
35+
namespace ir2vec {
3636
static cl::OptionCategory IR2VecCategory("IR2Vec Options");
3737

3838
// FIXME: Use a default vocab when not specified
3939
static cl::opt<std::string>
4040
VocabFile("ir2vec-vocab-path", cl::Optional,
4141
cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""),
4242
cl::cat(IR2VecCategory));
43-
static cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional,
44-
cl::init(1.0),
45-
cl::desc("Weight for opcode embeddings"),
46-
cl::cat(IR2VecCategory));
47-
static cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional,
48-
cl::init(0.5),
49-
cl::desc("Weight for type embeddings"),
50-
cl::cat(IR2VecCategory));
51-
static cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional,
52-
cl::init(0.2),
53-
cl::desc("Weight for argument embeddings"),
54-
cl::cat(IR2VecCategory));
43+
cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional, cl::init(1.0),
44+
cl::desc("Weight for opcode embeddings"),
45+
cl::cat(IR2VecCategory));
46+
cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, cl::init(0.5),
47+
cl::desc("Weight for type embeddings"),
48+
cl::cat(IR2VecCategory));
49+
cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, cl::init(0.2),
50+
cl::desc("Weight for argument embeddings"),
51+
cl::cat(IR2VecCategory));
52+
} // namespace ir2vec
53+
} // namespace llvm
5554

5655
AnalysisKey IR2VecVocabAnalysis::Key;
5756

@@ -251,49 +250,70 @@ bool IR2VecVocabResult::invalidate(
251250
// by auto-generating a default vocabulary during the build time.
252251
Error IR2VecVocabAnalysis::readVocabulary() {
253252
auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true);
254-
if (!BufOrError) {
253+
if (!BufOrError)
255254
return createFileError(VocabFile, BufOrError.getError());
256-
}
255+
257256
auto Content = BufOrError.get()->getBuffer();
258257
json::Path::Root Path("");
259258
Expected<json::Value> ParsedVocabValue = json::parse(Content);
260259
if (!ParsedVocabValue)
261260
return ParsedVocabValue.takeError();
262261

263262
bool Res = json::fromJSON(*ParsedVocabValue, Vocabulary, Path);
264-
if (!Res) {
263+
if (!Res)
265264
return createStringError(errc::illegal_byte_sequence,
266265
"Unable to parse the vocabulary");
267-
}
268-
assert(Vocabulary.size() > 0 && "Vocabulary is empty");
266+
267+
if (Vocabulary.empty())
268+
return createStringError(errc::illegal_byte_sequence,
269+
"Vocabulary is empty");
269270

270271
unsigned Dim = Vocabulary.begin()->second.size();
271-
assert(Dim > 0 && "Dimension of vocabulary is zero");
272-
(void)Dim;
273-
assert(std::all_of(Vocabulary.begin(), Vocabulary.end(),
274-
[Dim](const std::pair<StringRef, Embedding> &Entry) {
275-
return Entry.second.size() == Dim;
276-
}) &&
277-
"All vectors in the vocabulary are not of the same dimension");
272+
if (Dim == 0)
273+
return createStringError(errc::illegal_byte_sequence,
274+
"Dimension of vocabulary is zero");
275+
276+
if (!std::all_of(Vocabulary.begin(), Vocabulary.end(),
277+
[Dim](const std::pair<StringRef, Embedding> &Entry) {
278+
return Entry.second.size() == Dim;
279+
}))
280+
return createStringError(
281+
errc::illegal_byte_sequence,
282+
"All vectors in the vocabulary are not of the same dimension");
283+
278284
return Error::success();
279285
}
280286

287+
IR2VecVocabAnalysis::IR2VecVocabAnalysis(const Vocab &Vocabulary)
288+
: Vocabulary(Vocabulary) {}
289+
290+
IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab &&Vocabulary)
291+
: Vocabulary(std::move(Vocabulary)) {}
292+
293+
void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
294+
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
295+
Ctx.emitError("Error reading vocabulary: " + EI.message());
296+
});
297+
}
298+
281299
IR2VecVocabAnalysis::Result
282300
IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
283301
auto Ctx = &M.getContext();
302+
// FIXME: Scale the vocabulary once. This would avoid scaling per use later.
303+
// If vocabulary is already populated by the constructor, use it.
304+
if (!Vocabulary.empty())
305+
return IR2VecVocabResult(std::move(Vocabulary));
306+
307+
// Otherwise, try to read from the vocabulary file.
284308
if (VocabFile.empty()) {
285309
// FIXME: Use default vocabulary
286310
Ctx->emitError("IR2Vec vocabulary file path not specified");
287311
return IR2VecVocabResult(); // Return invalid result
288312
}
289313
if (auto Err = readVocabulary()) {
290-
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
291-
Ctx->emitError("Error reading vocabulary: " + EI.message());
292-
});
314+
emitError(std::move(Err), *Ctx);
293315
return IR2VecVocabResult();
294316
}
295-
// FIXME: Scale the vocabulary here once. This would avoid scaling per use
296-
// later.
297317
return IR2VecVocabResult(std::move(Vocabulary));
298318
}
299319

llvm/unittests/Analysis/IR2VecTest.cpp

Lines changed: 102 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -281,25 +281,30 @@ TEST(IR2VecTest, IR2VecVocabResultValidity) {
281281
EXPECT_EQ(validResult.getDimension(), 2u);
282282
}
283283

284-
// Helper to create a minimal function and embedder for getter tests
285-
struct GetterTestEnv {
286-
Vocab V = {};
284+
// Fixture for IR2Vec tests requiring IR setup and weight management.
285+
class IR2VecTestFixture : public ::testing::Test {
286+
protected:
287+
Vocab V;
287288
LLVMContext Ctx;
288-
std::unique_ptr<Module> M = nullptr;
289+
std::unique_ptr<Module> M;
289290
Function *F = nullptr;
290291
BasicBlock *BB = nullptr;
291-
Instruction *Add = nullptr;
292-
Instruction *Ret = nullptr;
293-
std::unique_ptr<Embedder> Emb = nullptr;
292+
Instruction *AddInst = nullptr;
293+
Instruction *RetInst = nullptr;
294294

295-
GetterTestEnv() {
295+
float OriginalOpcWeight = ::OpcWeight;
296+
float OriginalTypeWeight = ::TypeWeight;
297+
float OriginalArgWeight = ::ArgWeight;
298+
299+
void SetUp() override {
296300
V = {{"add", {1.0, 2.0}},
297301
{"integerTy", {0.5, 0.5}},
298302
{"constant", {0.2, 0.3}},
299303
{"variable", {0.0, 0.0}},
300304
{"unknownTy", {0.0, 0.0}}};
301305

302-
M = std::make_unique<Module>("M", Ctx);
306+
// Setup IR
307+
M = std::make_unique<Module>("TestM", Ctx);
303308
FunctionType *FTy = FunctionType::get(
304309
Type::getInt32Ty(Ctx), {Type::getInt32Ty(Ctx), Type::getInt32Ty(Ctx)},
305310
false);
@@ -308,61 +313,82 @@ struct GetterTestEnv {
308313
Argument *Arg = F->getArg(0);
309314
llvm::Value *Const = ConstantInt::get(Type::getInt32Ty(Ctx), 42);
310315

311-
Add = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
312-
Ret = ReturnInst::Create(Ctx, Add, BB);
316+
AddInst = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
317+
RetInst = ReturnInst::Create(Ctx, AddInst, BB);
318+
}
319+
320+
void setWeights(float OpcWeight, float TypeWeight, float ArgWeight) {
321+
::OpcWeight = OpcWeight;
322+
::TypeWeight = TypeWeight;
323+
::ArgWeight = ArgWeight;
324+
}
313325

314-
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
315-
EXPECT_TRUE(static_cast<bool>(Result));
316-
Emb = std::move(*Result);
326+
void TearDown() override {
327+
// Restore original global weights
328+
::OpcWeight = OriginalOpcWeight;
329+
::TypeWeight = OriginalTypeWeight;
330+
::ArgWeight = OriginalArgWeight;
317331
}
318332
};
319333

320-
TEST(IR2VecTest, GetInstVecMap) {
321-
GetterTestEnv Env;
322-
const auto &InstMap = Env.Emb->getInstVecMap();
334+
TEST_F(IR2VecTestFixture, GetInstVecMap) {
335+
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
336+
ASSERT_TRUE(static_cast<bool>(Result));
337+
auto Emb = std::move(*Result);
338+
339+
const auto &InstMap = Emb->getInstVecMap();
323340

324341
EXPECT_EQ(InstMap.size(), 2u);
325-
EXPECT_TRUE(InstMap.count(Env.Add));
326-
EXPECT_TRUE(InstMap.count(Env.Ret));
342+
EXPECT_TRUE(InstMap.count(AddInst));
343+
EXPECT_TRUE(InstMap.count(RetInst));
327344

328-
EXPECT_EQ(InstMap.at(Env.Add).size(), 2u);
329-
EXPECT_EQ(InstMap.at(Env.Ret).size(), 2u);
345+
EXPECT_EQ(InstMap.at(AddInst).size(), 2u);
346+
EXPECT_EQ(InstMap.at(RetInst).size(), 2u);
330347

331348
// Check values for add: {1.29, 2.31}
332-
EXPECT_THAT(InstMap.at(Env.Add),
349+
EXPECT_THAT(InstMap.at(AddInst),
333350
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
334351

335352
// Check values for ret: {0.0, 0.0}; neither ret nor voidTy is present in
336353
// vocab
337-
EXPECT_THAT(InstMap.at(Env.Ret), ElementsAre(0.0, 0.0));
354+
EXPECT_THAT(InstMap.at(RetInst), ElementsAre(0.0, 0.0));
338355
}
339356

340-
TEST(IR2VecTest, GetBBVecMap) {
341-
GetterTestEnv Env;
342-
const auto &BBMap = Env.Emb->getBBVecMap();
357+
TEST_F(IR2VecTestFixture, GetBBVecMap) {
358+
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
359+
ASSERT_TRUE(static_cast<bool>(Result));
360+
auto Emb = std::move(*Result);
361+
362+
const auto &BBMap = Emb->getBBVecMap();
343363

344364
EXPECT_EQ(BBMap.size(), 1u);
345-
EXPECT_TRUE(BBMap.count(Env.BB));
346-
EXPECT_EQ(BBMap.at(Env.BB).size(), 2u);
365+
EXPECT_TRUE(BBMap.count(BB));
366+
EXPECT_EQ(BBMap.at(BB).size(), 2u);
347367

348368
// BB vector should be sum of add and ret: {1.29, 2.31} + {0.0, 0.0} =
349369
// {1.29, 2.31}
350-
EXPECT_THAT(BBMap.at(Env.BB),
370+
EXPECT_THAT(BBMap.at(BB),
351371
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
352372
}
353373

354-
TEST(IR2VecTest, GetBBVector) {
355-
GetterTestEnv Env;
356-
const auto &BBVec = Env.Emb->getBBVector(*Env.BB);
374+
TEST_F(IR2VecTestFixture, GetBBVector) {
375+
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
376+
ASSERT_TRUE(static_cast<bool>(Result));
377+
auto Emb = std::move(*Result);
378+
379+
const auto &BBVec = Emb->getBBVector(*BB);
357380

358381
EXPECT_EQ(BBVec.size(), 2u);
359382
EXPECT_THAT(BBVec,
360383
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
361384
}
362385

363-
TEST(IR2VecTest, GetFunctionVector) {
364-
GetterTestEnv Env;
365-
const auto &FuncVec = Env.Emb->getFunctionVector();
386+
TEST_F(IR2VecTestFixture, GetFunctionVector) {
387+
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
388+
ASSERT_TRUE(static_cast<bool>(Result));
389+
auto Emb = std::move(*Result);
390+
391+
const auto &FuncVec = Emb->getFunctionVector();
366392

367393
EXPECT_EQ(FuncVec.size(), 2u);
368394

@@ -371,4 +397,45 @@ TEST(IR2VecTest, GetFunctionVector) {
371397
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
372398
}
373399

400+
TEST_F(IR2VecTestFixture, GetFunctionVectorWithCustomWeights) {
401+
setWeights(1.0, 1.0, 1.0);
402+
403+
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
404+
ASSERT_TRUE(static_cast<bool>(Result));
405+
auto Emb = std::move(*Result);
406+
407+
const auto &FuncVec = Emb->getFunctionVector();
408+
409+
EXPECT_EQ(FuncVec.size(), 2u);
410+
411+
// Expected: 1*([1.0 2.0] + [0.0 0.0]) + 1*([0.5 0.5] + [0.0 0.0]) + 1*([0.2
412+
// 0.3] + [0.0 0.0])
413+
EXPECT_THAT(FuncVec,
414+
ElementsAre(DoubleNear(1.7, 1e-6), DoubleNear(2.8, 1e-6)));
415+
}
416+
417+
TEST(IR2VecTest, IR2VecVocabAnalysisWithPrepopulatedVocab) {
418+
Vocab InitialVocab = {{"key1", {1.1, 2.2}}, {"key2", {3.3, 4.4}}};
419+
Vocab ExpectedVocab = InitialVocab;
420+
unsigned ExpectedDim = InitialVocab.begin()->second.size();
421+
422+
IR2VecVocabAnalysis VocabAnalysis(std::move(InitialVocab));
423+
424+
LLVMContext TestCtx;
425+
Module TestMod("TestModuleForVocabAnalysis", TestCtx);
426+
ModuleAnalysisManager MAM;
427+
IR2VecVocabResult Result = VocabAnalysis.run(TestMod, MAM);
428+
429+
EXPECT_TRUE(Result.isValid());
430+
ASSERT_FALSE(Result.getVocabulary().empty());
431+
EXPECT_EQ(Result.getDimension(), ExpectedDim);
432+
433+
const auto &ResultVocab = Result.getVocabulary();
434+
EXPECT_EQ(ResultVocab.size(), ExpectedVocab.size());
435+
for (const auto &pair : ExpectedVocab) {
436+
EXPECT_TRUE(ResultVocab.count(pair.first));
437+
EXPECT_THAT(ResultVocab.at(pair.first), ElementsAreArray(pair.second));
438+
}
439+
}
440+
374441
} // end anonymous namespace

0 commit comments

Comments
 (0)