Skip to content

Commit 8e5c13a

Browse files
committed
Embedding
1 parent d8bfb47 commit 8e5c13a

File tree

3 files changed

+111
-41
lines changed

3 files changed

+111
-41
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,24 @@ class raw_ostream;
5353
enum class IR2VecKind { Symbolic };
5454

5555
namespace ir2vec {
56-
using Embedding = std::vector<double>;
56+
/// Embedding is an ADT that wraps std::vector<double>. It provides
57+
/// additional functionality for arithmetic and comparison operations.
58+
struct Embedding : public std::vector<double> {
59+
using std::vector<double>::vector;
60+
61+
/// Arithmetic operators
62+
Embedding &operator+=(const Embedding &RHS);
63+
Embedding &operator-=(const Embedding &RHS);
64+
65+
/// Adds the Src Embedding, scaled by Factor, to this Embedding.
66+
/// Called_Embedding += Src * Factor
67+
void scaleAndAdd(const Embedding &Src, float Factor);
68+
69+
/// Returns true if the embedding is approximately equal to the RHS embedding
70+
/// within the specified tolerance.
71+
bool approximatelyEquals(const Embedding &RHS, double Tolerance = 1e-6) const;
72+
};
73+
5774
using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
5875
using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
5976
// FIXME: Currently the keys are strings. This can be changed to
@@ -99,13 +116,6 @@ class Embedder {
99116
/// zero vector.
100117
Embedding lookupVocab(const std::string &Key) const;
101118

102-
/// Adds two vectors: Dst += Src
103-
static void addVectors(Embedding &Dst, const Embedding &Src);
104-
105-
/// Adds Src vector scaled by Factor to Dst vector: Dst += Src * Factor
106-
static void addScaledVector(Embedding &Dst, const Embedding &Src,
107-
float Factor);
108-
109119
public:
110120
virtual ~Embedder() = default;
111121

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,40 @@ static cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional,
5555

5656
AnalysisKey IR2VecVocabAnalysis::Key;
5757

58+
//===----------------------------------------------------------------------===//
59+
// Embedding
60+
//===----------------------------------------------------------------------===//
61+
62+
Embedding &Embedding::operator+=(const Embedding &RHS) {
63+
assert(this->size() == RHS.size() && "Vectors must have the same dimension");
64+
std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
65+
std::plus<double>());
66+
return *this;
67+
}
68+
69+
Embedding &Embedding::operator-=(const Embedding &RHS) {
70+
assert(this->size() == RHS.size() && "Vectors must have the same dimension");
71+
std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
72+
std::minus<double>());
73+
return *this;
74+
}
75+
76+
void Embedding::scaleAndAdd(const Embedding &Src, float Factor) {
77+
assert(this->size() == Src.size() && "Vectors must have the same dimension");
78+
for (size_t i = 0; i < this->size(); ++i) {
79+
(*this)[i] += Src[i] * Factor;
80+
}
81+
}
82+
83+
bool Embedding::approximatelyEquals(const Embedding &RHS,
84+
double Tolerance) const {
85+
assert(this->size() == RHS.size() && "Vectors must have the same dimension");
86+
for (size_t i = 0; i < this->size(); ++i)
87+
if (std::abs((*this)[i] - RHS[i]) > Tolerance)
88+
return false;
89+
return true;
90+
}
91+
5892
//===----------------------------------------------------------------------===//
5993
// Embedder and its subclasses
6094
//===----------------------------------------------------------------------===//
@@ -73,20 +107,6 @@ Embedder::create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary) {
73107
return make_error<StringError>("Unknown IR2VecKind", errc::invalid_argument);
74108
}
75109

76-
void Embedder::addVectors(Embedding &Dst, const Embedding &Src) {
77-
assert(Dst.size() == Src.size() && "Vectors must have the same dimension");
78-
std::transform(Dst.begin(), Dst.end(), Src.begin(), Dst.begin(),
79-
std::plus<double>());
80-
}
81-
82-
void Embedder::addScaledVector(Embedding &Dst, const Embedding &Src,
83-
float Factor) {
84-
assert(Dst.size() == Src.size() && "Vectors must have the same dimension");
85-
for (size_t i = 0; i < Dst.size(); ++i) {
86-
Dst[i] += Src[i] * Factor;
87-
}
88-
}
89-
90110
// FIXME: Currently lookups are string based. Use numeric Keys
91111
// for efficiency
92112
Embedding Embedder::lookupVocab(const std::string &Key) const {
@@ -164,20 +184,20 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
164184
Embedding InstVector(Dimension, 0);
165185

166186
const auto OpcVec = lookupVocab(I.getOpcodeName());
167-
addScaledVector(InstVector, OpcVec, OpcWeight);
187+
InstVector.scaleAndAdd(OpcVec, OpcWeight);
168188

169189
// FIXME: Currently lookups are string based. Use numeric Keys
170190
// for efficiency.
171191
const auto Type = I.getType();
172192
const auto TypeVec = getTypeEmbedding(Type);
173-
addScaledVector(InstVector, TypeVec, TypeWeight);
193+
InstVector.scaleAndAdd(TypeVec, TypeWeight);
174194

175195
for (const auto &Op : I.operands()) {
176196
const auto OperandVec = getOperandEmbedding(Op.get());
177-
addScaledVector(InstVector, OperandVec, ArgWeight);
197+
InstVector.scaleAndAdd(OperandVec, ArgWeight);
178198
}
179199
InstVecMap[&I] = InstVector;
180-
addVectors(BBVector, InstVector);
200+
BBVector += InstVector;
181201
}
182202
BBVecMap[&BB] = BBVector;
183203
}
@@ -187,7 +207,7 @@ void SymbolicEmbedder::computeEmbeddings() const {
187207
return;
188208
for (const auto &BB : F) {
189209
computeEmbeddings(BB);
190-
addVectors(FuncVector, BBVecMap[&BB]);
210+
FuncVector += BBVecMap[&BB];
191211
}
192212
}
193213

llvm/unittests/Analysis/IR2VecTest.cpp

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,6 @@ class TestableEmbedder : public Embedder {
3232
void computeEmbeddings() const override {}
3333
void computeEmbeddings(const BasicBlock &BB) const override {}
3434
using Embedder::lookupVocab;
35-
static void addVectors(Embedding &Dst, const Embedding &Src) {
36-
Embedder::addVectors(Dst, Src);
37-
}
38-
static void addScaledVector(Embedding &Dst, const Embedding &Src,
39-
float Factor) {
40-
Embedder::addScaledVector(Dst, Src, Factor);
41-
}
4235
};
4336

4437
TEST(IR2VecTest, CreateSymbolicEmbedder) {
@@ -79,37 +72,83 @@ TEST(IR2VecTest, AddVectors) {
7972
Embedding E1 = {1.0, 2.0, 3.0};
8073
Embedding E2 = {0.5, 1.5, -1.0};
8174

82-
TestableEmbedder::addVectors(E1, E2);
75+
E1 += E2;
8376
EXPECT_THAT(E1, ElementsAre(1.5, 3.5, 2.0));
8477

8578
// Check that E2 is unchanged
8679
EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
8780
}
8881

82+
TEST(IR2VecTest, SubtractVectors) {
83+
Embedding E1 = {1.0, 2.0, 3.0};
84+
Embedding E2 = {0.5, 1.5, -1.0};
85+
86+
E1 -= E2;
87+
EXPECT_THAT(E1, ElementsAre(0.5, 0.5, 4.0));
88+
89+
// Check that E2 is unchanged
90+
EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
91+
}
92+
8993
TEST(IR2VecTest, AddScaledVector) {
9094
Embedding E1 = {1.0, 2.0, 3.0};
9195
Embedding E2 = {2.0, 0.5, -1.0};
9296

93-
TestableEmbedder::addScaledVector(E1, E2, 0.5f);
97+
E1.scaleAndAdd(E2, 0.5f);
9498
EXPECT_THAT(E1, ElementsAre(2.0, 2.25, 2.5));
9599

96100
// Check that E2 is unchanged
97101
EXPECT_THAT(E2, ElementsAre(2.0, 0.5, -1.0));
98102
}
99103

104+
TEST(IR2VecTest, ApproximatelyEqual) {
105+
Embedding E1 = {1.0, 2.0, 3.0};
106+
Embedding E2 = {1.0000001, 2.0000001, 3.0000001};
107+
EXPECT_TRUE(E1.approximatelyEquals(E2)); // Diff = 1e-7
108+
109+
Embedding E3 = {1.00002, 2.00002, 3.00002}; // Diff = 2e-5
110+
EXPECT_FALSE(E1.approximatelyEquals(E3));
111+
EXPECT_TRUE(E1.approximatelyEquals(E3, 3e-5));
112+
113+
Embedding E_clearly_within = {1.0000005, 2.0000005, 3.0000005}; // Diff = 5e-7
114+
EXPECT_TRUE(E1.approximatelyEquals(E_clearly_within));
115+
116+
Embedding E_clearly_outside = {1.00001, 2.00001, 3.00001}; // Diff = 1e-5
117+
EXPECT_FALSE(E1.approximatelyEquals(E_clearly_outside));
118+
119+
Embedding E4 = {1.0, 2.0, 3.5}; // Large diff
120+
EXPECT_FALSE(E1.approximatelyEquals(E4, 0.01));
121+
122+
Embedding E5 = {1.0, 2.0, 3.0};
123+
EXPECT_TRUE(E1.approximatelyEquals(E5, 0.0));
124+
EXPECT_TRUE(E1.approximatelyEquals(E5));
125+
}
126+
100127
#if GTEST_HAS_DEATH_TEST
101128
#ifndef NDEBUG
102129
TEST(IR2VecTest, MismatchedDimensionsAddVectors) {
103130
Embedding E1 = {1.0, 2.0};
104131
Embedding E2 = {1.0};
105-
EXPECT_DEATH(TestableEmbedder::addVectors(E1, E2),
106-
"Vectors must have the same dimension");
132+
EXPECT_DEATH(E1 += E2, "Vectors must have the same dimension");
133+
}
134+
135+
TEST(IR2VecTest, MismatchedDimensionsSubtractVectors) {
136+
Embedding E1 = {1.0, 2.0};
137+
Embedding E2 = {1.0};
138+
EXPECT_DEATH(E1 -= E2, "Vectors must have the same dimension");
107139
}
108140

109141
TEST(IR2VecTest, MismatchedDimensionsAddScaledVector) {
110142
Embedding E1 = {1.0, 2.0};
111143
Embedding E2 = {1.0};
112-
EXPECT_DEATH(TestableEmbedder::addScaledVector(E1, E2, 1.0f),
144+
EXPECT_DEATH(E1.scaleAndAdd(E2, 1.0f),
145+
"Vectors must have the same dimension");
146+
}
147+
148+
TEST(IR2VecTest, MismatchedDimensionsApproximatelyEqual) {
149+
Embedding E1 = {1.0, 2.0};
150+
Embedding E2 = {1.010};
151+
EXPECT_DEATH(E1.approximatelyEquals(E2),
113152
"Vectors must have the same dimension");
114153
}
115154
#endif // NDEBUG
@@ -136,8 +175,9 @@ TEST(IR2VecTest, ZeroDimensionEmbedding) {
136175
Embedding E1;
137176
Embedding E2;
138177
// Should be no-op, but not crash
139-
TestableEmbedder::addVectors(E1, E2);
140-
TestableEmbedder::addScaledVector(E1, E2, 1.0f);
178+
E1 += E2;
179+
E1 -= E2;
180+
E1.scaleAndAdd(E2, 1.0f);
141181
EXPECT_TRUE(E1.empty());
142182
}
143183

0 commit comments

Comments
 (0)