Skip to content

Commit 1a051f1

Browse files
committed
Simplifying creation of Embedder
1 parent 2657262 commit 1a051f1

File tree

6 files changed

+33
-56
lines changed

6 files changed

+33
-56
lines changed

llvm/docs/MLGO.rst

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -479,14 +479,9 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
479479

480480
// Assuming F is an llvm::Function&
481481
// For example, using IR2VecKind::Symbolic:
482-
Expected<std::unique_ptr<ir2vec::Embedder>> EmbOrErr =
482+
std::unique_ptr<ir2vec::Embedder> Emb =
483483
ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
484484

485-
if (auto Err = EmbOrErr.takeError()) {
486-
// Handle error in embedder creation
487-
return;
488-
}
489-
std::unique_ptr<ir2vec::Embedder> Emb = std::move(*EmbOrErr);
490485

491486
3. **Compute and Access Embeddings**:
492487
Call ``getFunctionVector()`` to get the embedding for the function.

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,8 @@ class Embedder {
170170
virtual ~Embedder() = default;
171171

172172
/// Factory method to create an Embedder object.
173-
static Expected<std::unique_ptr<Embedder>>
174-
create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary);
173+
static std::unique_ptr<Embedder> create(IR2VecKind Mode, const Function &F,
174+
const Vocab &Vocabulary);
175175

176176
/// Returns a map containing instructions and the corresponding embeddings for
177177
/// the function F if it has been computed. If not, it computes the embeddings

llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -204,16 +204,12 @@ void FunctionPropertiesInfo::updateForBB(const BasicBlock &BB,
204204
// We instantiate the IR2Vec embedder each time, as having an unique
205205
// pointer to the embedder as member of the class would make it
206206
// non-copyable. Instantiating the embedder in itself is not costly.
207-
auto EmbOrErr = ir2vec::Embedder::create(IR2VecKind::Symbolic,
207+
auto Embedder = ir2vec::Embedder::create(IR2VecKind::Symbolic,
208208
*BB.getParent(), *IR2VecVocab);
209-
if (Error Err = EmbOrErr.takeError()) {
210-
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
211-
BB.getContext().emitError("Error creating IR2Vec embeddings: " +
212-
EI.message());
213-
});
209+
if (!Embedder) {
210+
BB.getContext().emitError("Error creating IR2Vec embeddings");
214211
return;
215212
}
216-
auto Embedder = std::move(*EmbOrErr);
217213
const auto &BBEmbedding = Embedder->getBBVector(BB);
218214
// Subtract BBEmbedding from Function embedding if the direction is -1,
219215
// and add it if the direction is +1.

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -123,13 +123,14 @@ Embedder::Embedder(const Function &F, const Vocab &Vocabulary)
123123
Dimension(Vocabulary.begin()->second.size()), OpcWeight(::OpcWeight),
124124
TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {}
125125

126-
Expected<std::unique_ptr<Embedder>>
127-
Embedder::create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary) {
126+
std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
127+
const Vocab &Vocabulary) {
128128
switch (Mode) {
129129
case IR2VecKind::Symbolic:
130130
return std::make_unique<SymbolicEmbedder>(F, Vocabulary);
131131
}
132-
return make_error<StringError>("Unknown IR2VecKind", errc::invalid_argument);
132+
llvm_unreachable("Unknown IR2Vec kind");
133+
return nullptr;
133134
}
134135

135136
// FIXME: Currently lookups are string based. Use numeric Keys
@@ -388,17 +389,13 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
388389

389390
auto Vocab = IR2VecVocabResult.getVocabulary();
390391
for (Function &F : M) {
391-
Expected<std::unique_ptr<Embedder>> EmbOrErr =
392+
std::unique_ptr<Embedder> Emb =
392393
Embedder::create(IR2VecKind::Symbolic, F, Vocab);
393-
if (auto Err = EmbOrErr.takeError()) {
394-
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
395-
OS << "Error creating IR2Vec embeddings: " << EI.message() << "\n";
396-
});
394+
if (!Emb) {
395+
OS << "Error creating IR2Vec embeddings \n";
397396
continue;
398397
}
399398

400-
std::unique_ptr<Embedder> Emb = std::move(*EmbOrErr);
401-
402399
OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
403400
OS << "Function vector: ";
404401
Emb->getFunctionVector().print(OS);

llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,9 @@ class FunctionPropertiesAnalysisTest : public testing::Test {
127127
}
128128

129129
std::unique_ptr<ir2vec::Embedder> createEmbedder(const Function &F) {
130-
auto EmbResult =
131-
ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
132-
EXPECT_TRUE(static_cast<bool>(EmbResult));
133-
return std::move(*EmbResult);
130+
auto Emb = ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
131+
EXPECT_TRUE(static_cast<bool>(Emb));
132+
return std::move(Emb);
134133
}
135134
};
136135

llvm/unittests/Analysis/IR2VecTest.cpp

Lines changed: 17 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -216,10 +216,7 @@ TEST(IR2VecTest, CreateSymbolicEmbedder) {
216216
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
217217
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
218218

219-
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
220-
EXPECT_TRUE(static_cast<bool>(Result));
221-
222-
auto *Emb = Result->get();
219+
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
223220
EXPECT_NE(Emb, nullptr);
224221
}
225222

@@ -231,15 +228,16 @@ TEST(IR2VecTest, CreateInvalidMode) {
231228
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
232229
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
233230

234-
// static_cast an invalid int to IR2VecKind
231+
// static_cast an invalid int to IR2VecKind
232+
#ifndef NDEBUG
233+
#if GTEST_HAS_DEATH_TEST
234+
EXPECT_DEATH(Embedder::create(static_cast<IR2VecKind>(-1), *F, V),
235+
"Unknown IR2Vec kind");
236+
#endif // GTEST_HAS_DEATH_TEST
237+
#else
235238
auto Result = Embedder::create(static_cast<IR2VecKind>(-1), *F, V);
236239
EXPECT_FALSE(static_cast<bool>(Result));
237-
238-
std::string ErrMsg;
239-
llvm::handleAllErrors(
240-
Result.takeError(),
241-
[&](const llvm::ErrorInfoBase &EIB) { ErrMsg = EIB.message(); });
242-
EXPECT_NE(ErrMsg.find("Unknown IR2VecKind"), std::string::npos);
240+
#endif // NDEBUG
243241
}
244242

245243
TEST(IR2VecTest, LookupVocab) {
@@ -298,10 +296,6 @@ class IR2VecTestFixture : public ::testing::Test {
298296
Instruction *AddInst = nullptr;
299297
Instruction *RetInst = nullptr;
300298

301-
float OriginalOpcWeight = ::OpcWeight;
302-
float OriginalTypeWeight = ::TypeWeight;
303-
float OriginalArgWeight = ::ArgWeight;
304-
305299
void SetUp() override {
306300
V = {{"add", {1.0, 2.0}},
307301
{"integerTy", {0.25, 0.25}},
@@ -325,9 +319,8 @@ class IR2VecTestFixture : public ::testing::Test {
325319
};
326320

327321
TEST_F(IR2VecTestFixture, GetInstVecMap) {
328-
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
329-
ASSERT_TRUE(static_cast<bool>(Result));
330-
auto Emb = std::move(*Result);
322+
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
323+
ASSERT_TRUE(static_cast<bool>(Emb));
331324

332325
const auto &InstMap = Emb->getInstVecMap();
333326

@@ -348,9 +341,8 @@ TEST_F(IR2VecTestFixture, GetInstVecMap) {
348341
}
349342

350343
TEST_F(IR2VecTestFixture, GetBBVecMap) {
351-
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
352-
ASSERT_TRUE(static_cast<bool>(Result));
353-
auto Emb = std::move(*Result);
344+
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
345+
ASSERT_TRUE(static_cast<bool>(Emb));
354346

355347
const auto &BBMap = Emb->getBBVecMap();
356348

@@ -365,9 +357,8 @@ TEST_F(IR2VecTestFixture, GetBBVecMap) {
365357
}
366358

367359
TEST_F(IR2VecTestFixture, GetBBVector) {
368-
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
369-
ASSERT_TRUE(static_cast<bool>(Result));
370-
auto Emb = std::move(*Result);
360+
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
361+
ASSERT_TRUE(static_cast<bool>(Emb));
371362

372363
const auto &BBVec = Emb->getBBVector(*BB);
373364

@@ -377,9 +368,8 @@ TEST_F(IR2VecTestFixture, GetBBVector) {
377368
}
378369

379370
TEST_F(IR2VecTestFixture, GetFunctionVector) {
380-
auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
381-
ASSERT_TRUE(static_cast<bool>(Result));
382-
auto Emb = std::move(*Result);
371+
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
372+
ASSERT_TRUE(static_cast<bool>(Emb));
383373

384374
const auto &FuncVec = Emb->getFunctionVector();
385375

0 commit comments

Comments
 (0)