[IR2Vec] Minor vocab changes and exposing weights #143200


Merged
svkeerthy merged 1 commit into main from users/svkeerthy/06-06-vocab_changes1 on Jun 13, 2025

Conversation

svkeerthy
Contributor

@svkeerthy svkeerthy commented Jun 6, 2025

This PR changes some asserts in Vocab into hard checks that emit errors, and exposes the weight flags and a constructor to help with unit tests.

(Tracking issue - #141817)
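
To illustrate the assert-to-hard-check change described above, here is a minimal standalone sketch. It does not use LLVM's `Error` or `createStringError`; `Embedding`, `Vocab`, and `validateVocab` are hypothetical stand-ins built on the standard library, with the error strings copied from the patch. The point is only that these checks still run in release builds, where `assert` compiles away.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-ins for ir2vec::Embedding and ir2vec::Vocab.
using Embedding = std::vector<double>;
using Vocab = std::map<std::string, Embedding>;

// Returns an error message when the vocabulary is malformed, or std::nullopt
// when it is usable. Unlike assert(), these checks are not removed by NDEBUG.
std::optional<std::string> validateVocab(const Vocab &V) {
  if (V.empty())
    return std::string("Vocabulary is empty");

  const std::size_t Dim = V.begin()->second.size();
  if (Dim == 0)
    return std::string("Dimension of vocabulary is zero");

  // Every entry must have the same dimension as the first one.
  const bool SameDim =
      std::all_of(V.begin(), V.end(), [Dim](const Vocab::value_type &Entry) {
        return Entry.second.size() == Dim;
      });
  if (!SameDim)
    return std::string(
        "All vectors in the vocabulary are not of the same dimension");

  return std::nullopt;
}
```

In the patch itself, the equivalent checks return an `llvm::Error` from `readVocabulary()`, and `run()` reports it through `LLVMContext::emitError`.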

Contributor Author

svkeerthy commented Jun 6, 2025

@svkeerthy svkeerthy changed the title from "Vocab changes1" to "[IR2Vec] Minor vocab changes and exposing weights" on Jun 6, 2025
@svkeerthy svkeerthy marked this pull request as ready for review June 6, 2025 21:08
@llvmbot
Member

llvmbot commented Jun 6, 2025

@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-mlgo

Author: S. VenkataKeerthy (svkeerthy)

Changes

This PR changes some asserts in Vocab into hard checks that emit errors, and exposes the weight flags and a constructor to help with unit tests.

(Tracking issue - #141817)


Full diff: https://github.com/llvm/llvm-project/pull/143200.diff

3 Files Affected:

  • (modified) llvm/include/llvm/Analysis/IR2Vec.h (+10)
  • (modified) llvm/lib/Analysis/IR2Vec.cpp (+53-28)
  • (modified) llvm/unittests/Analysis/IR2VecTest.cpp (+102-35)
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 930b13f079796..1bd80ed65d434 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -31,7 +31,9 @@
 
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/ErrorOr.h"
+#include "llvm/Support/JSON.h"
 #include <map>
 
 namespace llvm {
@@ -43,6 +45,7 @@ class Function;
 class Type;
 class Value;
 class raw_ostream;
+class LLVMContext;
 
 /// IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
 /// Symbolic embeddings capture the "syntactic" and "statistical correlation"
@@ -53,6 +56,11 @@ class raw_ostream;
 enum class IR2VecKind { Symbolic };
 
 namespace ir2vec {
+
+LLVM_ABI extern cl::opt<float> OpcWeight;
+LLVM_ABI extern cl::opt<float> TypeWeight;
+LLVM_ABI extern cl::opt<float> ArgWeight;
+
 /// Embedding is a ADT that wraps std::vector<double>. It provides
 /// additional functionality for arithmetic and comparison operations.
 struct Embedding : public std::vector<double> {
@@ -187,10 +195,12 @@ class IR2VecVocabResult {
 class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
   ir2vec::Vocab Vocabulary;
   Error readVocabulary();
+  void emitError(Error Err, LLVMContext &Ctx);
 
 public:
   static AnalysisKey Key;
   IR2VecVocabAnalysis() = default;
+  explicit IR2VecVocabAnalysis(ir2vec::Vocab &&Vocab);
   using Result = IR2VecVocabResult;
   Result run(Module &M, ModuleAnalysisManager &MAM);
 };
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 8ee8e5b0ff74e..3333e751f10b8 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -16,13 +16,11 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
-#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/Errc.h"
 #include "llvm/Support/Error.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/Format.h"
-#include "llvm/Support/JSON.h"
 #include "llvm/Support/MemoryBuffer.h"
 
 using namespace llvm;
@@ -33,6 +31,8 @@ using namespace ir2vec;
 STATISTIC(VocabMissCounter,
           "Number of lookups to entites not present in the vocabulary");
 
+namespace llvm {
+namespace ir2vec {
 static cl::OptionCategory IR2VecCategory("IR2Vec Options");
 
 // FIXME: Use a default vocab when not specified
@@ -40,18 +40,20 @@ static cl::opt<std::string>
     VocabFile("ir2vec-vocab-path", cl::Optional,
               cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""),
               cl::cat(IR2VecCategory));
-static cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional,
-                                cl::init(1.0),
-                                cl::desc("Weight for opcode embeddings"),
-                                cl::cat(IR2VecCategory));
-static cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional,
-                                 cl::init(0.5),
-                                 cl::desc("Weight for type embeddings"),
-                                 cl::cat(IR2VecCategory));
-static cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional,
-                                cl::init(0.2),
-                                cl::desc("Weight for argument embeddings"),
-                                cl::cat(IR2VecCategory));
+LLVM_ABI cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional,
+                                  cl::init(1.0),
+                                  cl::desc("Weight for opcode embeddings"),
+                                  cl::cat(IR2VecCategory));
+LLVM_ABI cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional,
+                                   cl::init(0.5),
+                                   cl::desc("Weight for type embeddings"),
+                                   cl::cat(IR2VecCategory));
+LLVM_ABI cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional,
+                                  cl::init(0.2),
+                                  cl::desc("Weight for argument embeddings"),
+                                  cl::cat(IR2VecCategory));
+} // namespace ir2vec
+} // namespace llvm
 
 AnalysisKey IR2VecVocabAnalysis::Key;
 
@@ -83,9 +85,10 @@ void Embedding::scaleAndAdd(const Embedding &Src, float Factor) {
 bool Embedding::approximatelyEquals(const Embedding &RHS,
                                     double Tolerance) const {
   assert(this->size() == RHS.size() && "Vectors must have the same dimension");
-  for (size_t i = 0; i < this->size(); ++i)
+  for (size_t i = 0; i < this->size(); ++i) {
     if (std::abs((*this)[i] - RHS[i]) > Tolerance)
       return false;
+  }
   return true;
 }
 
@@ -254,35 +257,57 @@ Error IR2VecVocabAnalysis::readVocabulary() {
     return createStringError(errc::illegal_byte_sequence,
                              "Unable to parse the vocabulary");
   }
-  assert(Vocabulary.size() > 0 && "Vocabulary is empty");
+
+  if (Vocabulary.empty()) {
+    return createStringError(errc::illegal_byte_sequence,
+                             "Vocabulary is empty");
+  }
 
   unsigned Dim = Vocabulary.begin()->second.size();
-  assert(Dim > 0 && "Dimension of vocabulary is zero");
-  (void)Dim;
-  assert(std::all_of(Vocabulary.begin(), Vocabulary.end(),
-                     [Dim](const std::pair<StringRef, Embedding> &Entry) {
-                       return Entry.second.size() == Dim;
-                     }) &&
-         "All vectors in the vocabulary are not of the same dimension");
+  if (Dim == 0) {
+    return createStringError(errc::illegal_byte_sequence,
+                             "Dimension of vocabulary is zero");
+  }
+
+  if (!std::all_of(Vocabulary.begin(), Vocabulary.end(),
+                   [Dim](const std::pair<StringRef, Embedding> &Entry) {
+                     return Entry.second.size() == Dim;
+                   })) {
+    return createStringError(
+        errc::illegal_byte_sequence,
+        "All vectors in the vocabulary are not of the same dimension");
+  }
+
   return Error::success();
 }
 
+IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab &&Vocabulary)
+    : Vocabulary(std::move(Vocabulary)) {}
+
+void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
+  handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
+    Ctx.emitError("Error reading vocabulary: " + EI.message());
+  });
+}
+
 IR2VecVocabAnalysis::Result
 IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
   auto Ctx = &M.getContext();
+  // FIXME: Scale the vocabulary once. This would avoid scaling per use later.
+  // If vocabulary is already populated by the constructor, use it.
+  if (!Vocabulary.empty())
+    return IR2VecVocabResult(std::move(Vocabulary));
+
+  // Otherwise, try to read from the vocabulary file.
   if (VocabFile.empty()) {
     // FIXME: Use default vocabulary
     Ctx->emitError("IR2Vec vocabulary file path not specified");
     return IR2VecVocabResult(); // Return invalid result
   }
   if (auto Err = readVocabulary()) {
-    handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
-      Ctx->emitError("Error reading vocabulary: " + EI.message());
-    });
+    emitError(std::move(Err), *Ctx);
     return IR2VecVocabResult();
   }
-  // FIXME: Scale the vocabulary here once. This would avoid scaling per use
-  // later.
   return IR2VecVocabResult(std::move(Vocabulary));
 }
 
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 7259a8a2fe20a..7d5787243d12d 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -199,25 +199,30 @@ TEST(IR2VecTest, IR2VecVocabResultValidity) {
   EXPECT_EQ(validResult.getDimension(), 2u);
 }
 
-// Helper to create a minimal function and embedder for getter tests
-struct GetterTestEnv {
-  Vocab V = {};
+// Fixture for IR2Vec tests requiring IR setup and weight management.
+class IR2VecTestFixture : public ::testing::Test {
+protected:
+  Vocab V;
   LLVMContext Ctx;
-  std::unique_ptr<Module> M = nullptr;
+  std::unique_ptr<Module> M;
   Function *F = nullptr;
   BasicBlock *BB = nullptr;
-  Instruction *Add = nullptr;
-  Instruction *Ret = nullptr;
-  std::unique_ptr<Embedder> Emb = nullptr;
+  Instruction *AddInst = nullptr;
+  Instruction *RetInst = nullptr;
 
-  GetterTestEnv() {
+  float OriginalOpcWeight = ::OpcWeight;
+  float OriginalTypeWeight = ::TypeWeight;
+  float OriginalArgWeight = ::ArgWeight;
+
+  void SetUp() override {
     V = {{"add", {1.0, 2.0}},
          {"integerTy", {0.5, 0.5}},
          {"constant", {0.2, 0.3}},
          {"variable", {0.0, 0.0}},
          {"unknownTy", {0.0, 0.0}}};
 
-    M = std::make_unique<Module>("M", Ctx);
+    // Setup IR
+    M = std::make_unique<Module>("TestM", Ctx);
     FunctionType *FTy = FunctionType::get(
         Type::getInt32Ty(Ctx), {Type::getInt32Ty(Ctx), Type::getInt32Ty(Ctx)},
         false);
@@ -226,61 +231,82 @@ struct GetterTestEnv {
     Argument *Arg = F->getArg(0);
     llvm::Value *Const = ConstantInt::get(Type::getInt32Ty(Ctx), 42);
 
-    Add = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
-    Ret = ReturnInst::Create(Ctx, Add, BB);
+    AddInst = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
+    RetInst = ReturnInst::Create(Ctx, AddInst, BB);
+  }
+
+  void setWeights(float OpcWeight, float TypeWeight, float ArgWeight) {
+    ::OpcWeight = OpcWeight;
+    ::TypeWeight = TypeWeight;
+    ::ArgWeight = ArgWeight;
+  }
 
-    auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
-    EXPECT_TRUE(static_cast<bool>(Result));
-    Emb = std::move(*Result);
+  void TearDown() override {
+    // Restore original global weights
+    ::OpcWeight = OriginalOpcWeight;
+    ::TypeWeight = OriginalTypeWeight;
+    ::ArgWeight = OriginalArgWeight;
   }
 };
 
-TEST(IR2VecTest, GetInstVecMap) {
-  GetterTestEnv Env;
-  const auto &InstMap = Env.Emb->getInstVecMap();
+TEST_F(IR2VecTestFixture, GetInstVecMap) {
+  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
+  ASSERT_TRUE(static_cast<bool>(Result));
+  auto Emb = std::move(*Result);
+
+  const auto &InstMap = Emb->getInstVecMap();
 
   EXPECT_EQ(InstMap.size(), 2u);
-  EXPECT_TRUE(InstMap.count(Env.Add));
-  EXPECT_TRUE(InstMap.count(Env.Ret));
+  EXPECT_TRUE(InstMap.count(AddInst));
+  EXPECT_TRUE(InstMap.count(RetInst));
 
-  EXPECT_EQ(InstMap.at(Env.Add).size(), 2u);
-  EXPECT_EQ(InstMap.at(Env.Ret).size(), 2u);
+  EXPECT_EQ(InstMap.at(AddInst).size(), 2u);
+  EXPECT_EQ(InstMap.at(RetInst).size(), 2u);
 
   // Check values for add: {1.29, 2.31}
-  EXPECT_THAT(InstMap.at(Env.Add),
+  EXPECT_THAT(InstMap.at(AddInst),
               ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
 
   // Check values for ret: {0.0, 0.}; Neither ret nor voidTy are present in
   // vocab
-  EXPECT_THAT(InstMap.at(Env.Ret), ElementsAre(0.0, 0.0));
+  EXPECT_THAT(InstMap.at(RetInst), ElementsAre(0.0, 0.0));
 }
 
-TEST(IR2VecTest, GetBBVecMap) {
-  GetterTestEnv Env;
-  const auto &BBMap = Env.Emb->getBBVecMap();
+TEST_F(IR2VecTestFixture, GetBBVecMap) {
+  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
+  ASSERT_TRUE(static_cast<bool>(Result));
+  auto Emb = std::move(*Result);
+
+  const auto &BBMap = Emb->getBBVecMap();
 
   EXPECT_EQ(BBMap.size(), 1u);
-  EXPECT_TRUE(BBMap.count(Env.BB));
-  EXPECT_EQ(BBMap.at(Env.BB).size(), 2u);
+  EXPECT_TRUE(BBMap.count(BB));
+  EXPECT_EQ(BBMap.at(BB).size(), 2u);
 
   // BB vector should be sum of add and ret: {1.29, 2.31} + {0.0, 0.0} =
   // {1.29, 2.31}
-  EXPECT_THAT(BBMap.at(Env.BB),
+  EXPECT_THAT(BBMap.at(BB),
               ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
 }
 
-TEST(IR2VecTest, GetBBVector) {
-  GetterTestEnv Env;
-  const auto &BBVec = Env.Emb->getBBVector(*Env.BB);
+TEST_F(IR2VecTestFixture, GetBBVector) {
+  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
+  ASSERT_TRUE(static_cast<bool>(Result));
+  auto Emb = std::move(*Result);
+
+  const auto &BBVec = Emb->getBBVector(*BB);
 
   EXPECT_EQ(BBVec.size(), 2u);
   EXPECT_THAT(BBVec,
               ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
 }
 
-TEST(IR2VecTest, GetFunctionVector) {
-  GetterTestEnv Env;
-  const auto &FuncVec = Env.Emb->getFunctionVector();
+TEST_F(IR2VecTestFixture, GetFunctionVector) {
+  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
+  ASSERT_TRUE(static_cast<bool>(Result));
+  auto Emb = std::move(*Result);
+
+  const auto &FuncVec = Emb->getFunctionVector();
 
   EXPECT_EQ(FuncVec.size(), 2u);
 
@@ -289,4 +315,45 @@ TEST(IR2VecTest, GetFunctionVector) {
               ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
 }
 
+TEST_F(IR2VecTestFixture, GetFunctionVectorWithCustomWeights) {
+  setWeights(1.0, 1.0, 1.0);
+
+  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
+  ASSERT_TRUE(static_cast<bool>(Result));
+  auto Emb = std::move(*Result);
+
+  const auto &FuncVec = Emb->getFunctionVector();
+
+  EXPECT_EQ(FuncVec.size(), 2u);
+
+  // Expected: 1*([1.0 2.0] + [0.0 0.0]) + 1*([0.5 0.5] + [0.0 0.0]) + 1*([0.2
+  // 0.3] + [0.0 0.0])
+  EXPECT_THAT(FuncVec,
+              ElementsAre(DoubleNear(1.7, 1e-6), DoubleNear(2.8, 1e-6)));
+}
+
+TEST(IR2VecTest, IR2VecVocabAnalysisWithPrepopulatedVocab) {
+  Vocab InitialVocab = {{"key1", {1.1, 2.2}}, {"key2", {3.3, 4.4}}};
+  Vocab ExpectedVocab = InitialVocab;
+  unsigned ExpectedDim = InitialVocab.begin()->second.size();
+
+  IR2VecVocabAnalysis VocabAnalysis(std::move(InitialVocab));
+
+  LLVMContext TestCtx;
+  Module TestMod("TestModuleForVocabAnalysis", TestCtx);
+  ModuleAnalysisManager MAM;
+  IR2VecVocabResult Result = VocabAnalysis.run(TestMod, MAM);
+
+  EXPECT_TRUE(Result.isValid());
+  ASSERT_FALSE(Result.getVocabulary().empty());
+  EXPECT_EQ(Result.getDimension(), ExpectedDim);
+
+  const auto &ResultVocab = Result.getVocabulary();
+  EXPECT_EQ(ResultVocab.size(), ExpectedVocab.size());
+  for (const auto &pair : ExpectedVocab) {
+    EXPECT_TRUE(ResultVocab.count(pair.first));
+    EXPECT_THAT(ResultVocab.at(pair.first), ElementsAreArray(pair.second));
+  }
+}
+
 } // end anonymous namespace
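
The expected values in the tests above ({1.29, 2.31} with the default weights, {1.7, 2.8} with all weights set to 1.0) can be reproduced by hand. The sketch below assumes, based on the test comments, that an instruction vector is the weighted sum of its opcode, type, and operand entries. `weightedInstVector` is a hypothetical helper and the free-function `scaleAndAdd` is a standalone stand-in, neither being part of IR2Vec's API.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for ir2vec::Embedding.
using Embedding = std::vector<double>;

// Dst += Factor * Src, element-wise.
void scaleAndAdd(Embedding &Dst, const Embedding &Src, double Factor) {
  assert(Dst.size() == Src.size() && "Vectors must have the same dimension");
  for (std::size_t I = 0; I < Dst.size(); ++I)
    Dst[I] += Src[I] * Factor;
}

// Element-wise comparison within an absolute tolerance.
bool approximatelyEquals(const Embedding &A, const Embedding &B, double Tol) {
  if (A.size() != B.size())
    return false;
  for (std::size_t I = 0; I < A.size(); ++I) {
    if (std::abs(A[I] - B[I]) > Tol)
      return false;
  }
  return true;
}

// Assumed formula: OpcW * opcode + TypeW * type + ArgW * (sum of operands).
// ArgSum is the already-summed operand entries; in the tests the "variable"
// entry is all zeros, so ArgSum equals the "constant" entry {0.2, 0.3}.
Embedding weightedInstVector(const Embedding &Opc, const Embedding &Ty,
                             const Embedding &ArgSum, double OpcW,
                             double TypeW, double ArgW) {
  Embedding Result(Opc.size(), 0.0);
  scaleAndAdd(Result, Opc, OpcW);
  scaleAndAdd(Result, Ty, TypeW);
  scaleAndAdd(Result, ArgSum, ArgW);
  return Result;
}
```

With the test vocabulary (`add` = {1.0, 2.0}, `integerTy` = {0.5, 0.5}, `constant` = {0.2, 0.3}) and the default weights 1.0, 0.5, and 0.2, this gives {1.0 + 0.25 + 0.04, 2.0 + 0.25 + 0.06} = {1.29, 2.31}. The `ret` instruction contributes zeros because neither `ret` nor `voidTy` is in the vocabulary, so the basic-block and function vectors have the same value.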

@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-06-vocab_changes1 branch from d62a38b to 750cb2f Compare June 9, 2025 18:10
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-06-embedding branch from 8e5c13a to 602c2d3 Compare June 9, 2025 18:10
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-06-vocab_changes1 branch 2 times, most recently from c1842ec to 7f2012c Compare June 9, 2025 20:42
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-06-embedding branch from 602c2d3 to 6817aa9 Compare June 9, 2025 20:42
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-06-vocab_changes1 branch from 7f2012c to d3468ab Compare June 10, 2025 04:23
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-06-embedding branch from 6817aa9 to 9c05884 Compare June 10, 2025 04:23
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-06-vocab_changes1 branch from d3468ab to 96e4a8b Compare June 10, 2025 04:30
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-06-embedding branch from 9c05884 to dfe59f2 Compare June 10, 2025 04:30
@@ -53,7 +56,12 @@ class raw_ostream;
enum class IR2VecKind { Symbolic };

namespace ir2vec {
/// Embedding is a datatype that wraps std::vector<double>. It provides

LLVM_ABI extern cl::opt<float> OpcWeight;
Contributor

Why do these need an LLVM_ABI tag?

If this is actually needed, this seems like something that should be split out of this patch.

Contributor Author

#136623 adds LLVM_ABI tags to some globals and externs in Analysis. Removing it for now. Can be added later if necessary.
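
For context on this exchange: what makes the weights reachable from the unit tests is dropping `static` and adding an `extern` declaration in the header, while an export annotation such as `LLVM_ABI` is a separate concern about symbol visibility. The sketch below is a hedged, single-file illustration in which a plain `float` stands in for `cl::opt<float>`; the `demo` namespace and the `ScopedWeight` guard are hypothetical. It uses RAII to restore the value, as an alternative to the `SetUp`/`TearDown` pair in the patch's test fixture.

```cpp
#include <cassert>

namespace demo {
// Header side: declaration only, visible to any file that includes it.
extern float OpcWeight;
// Source side: the single definition, without `static`, so that other
// translation units (such as a unit test) can link against it.
float OpcWeight = 1.0f;
} // namespace demo

// Hypothetical RAII guard: overrides a global weight for the current scope
// and restores the saved value on destruction.
class ScopedWeight {
  float &Ref;
  float Saved;

public:
  ScopedWeight(float &Target, float NewValue) : Ref(Target), Saved(Target) {
    Ref = NewValue;
  }
  ~ScopedWeight() { Ref = Saved; }
  ScopedWeight(const ScopedWeight &) = delete;
  ScopedWeight &operator=(const ScopedWeight &) = delete;
};
```

Restoring the previous value matters because the weights are process-wide globals: a test that changes them without restoring would affect every test that runs after it.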

@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-06-embedding branch from dfe59f2 to 7a33ef8 Compare June 10, 2025 18:14
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-06-vocab_changes1 branch 2 times, most recently from d639ce4 to c200f67 Compare June 10, 2025 21:22
Base automatically changed from users/svkeerthy/06-06-embedding to main June 10, 2025 22:12
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-06-vocab_changes1 branch from c200f67 to 8685c74 Compare June 10, 2025 22:13
@svkeerthy svkeerthy requested a review from boomanaiden154 June 12, 2025 18:14
Contributor Author

svkeerthy commented Jun 13, 2025

Merge activity

  • Jun 13, 5:41 PM UTC: A user started a stack merge that includes this pull request via Graphite.
  • Jun 13, 5:43 PM UTC: @svkeerthy merged this pull request with Graphite.

@svkeerthy svkeerthy merged commit 09c54c2 into main Jun 13, 2025
7 checks passed
@svkeerthy svkeerthy deleted the users/svkeerthy/06-06-vocab_changes1 branch June 13, 2025 17:43
tomtor pushed a commit to tomtor/llvm-project that referenced this pull request Jun 14, 2025
This PR changes some asserts in Vocab into hard checks that emit errors, and exposes the weight flags and a constructor to help with unit tests.

(Tracking issue - llvm#141817)