Skip to content

Commit 9c05884

Browse files
committed
Embedding
1 parent 34b985f commit 9c05884

File tree

3 files changed

+261
-74
lines changed

3 files changed

+261
-74
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,72 @@ class raw_ostream;
5353
enum class IR2VecKind { Symbolic };
5454

5555
namespace ir2vec {
56-
using Embedding = std::vector<double>;
56+
/// Embedding is an ADT that wraps std::vector<double>. It provides
57+
/// additional functionality for arithmetic and comparison operations.
58+
/// It is meant to be used *like* std::vector<double> but is more restrictive
59+
/// in the sense that it does not allow the user to change the size of the
60+
/// embedding vector. The dimension of the embedding is fixed at the time of
61+
/// construction of Embedding object. But the elements can be modified in-place.
62+
struct Embedding {
private:
  /// Backing storage for the embedding values.
  std::vector<double> Data;

public:
  /// Creates an empty embedding (size() == 0).
  Embedding() = default;
  /// Creates an embedding holding a copy of the values in \p V.
  Embedding(const std::vector<double> &V) : Data(V) {}
  /// Creates an embedding that takes over the contents of \p V.
  Embedding(std::vector<double> &&V) : Data(std::move(V)) {}
  /// Creates an embedding from a braced list of values.
  Embedding(std::initializer_list<double> IL) : Data(IL) {}

  /// Creates an embedding of \p Size value-initialized elements.
  explicit Embedding(size_t Size) : Data(Size) {}
  /// Creates an embedding of \p Size elements, each set to \p InitialValue.
  Embedding(size_t Size, double InitialValue) : Data(Size, InitialValue) {}

  /// Returns the number of elements in the embedding.
  size_t size() const { return Data.size(); }
  /// Returns true if the embedding has no elements.
  [[nodiscard]] bool empty() const { return Data.empty(); }

  /// Returns a mutable reference to the element at index \p Itr.
  /// Asserts that \p Itr is within bounds.
  double &operator[](size_t Itr) {
    assert(Itr < Data.size() && "Index out of bounds");
    return Data[Itr];
  }

  /// Returns a read-only reference to the element at index \p Itr.
  /// Asserts that \p Itr is within bounds.
  const double &operator[](size_t Itr) const {
    assert(Itr < Data.size() && "Index out of bounds");
    return Data[Itr];
  }

  // std::vector<double> is not a dependent type here, so no `typename`
  // disambiguator is needed on these aliases.
  using iterator = std::vector<double>::iterator;
  using const_iterator = std::vector<double>::const_iterator;

  /// Iteration over the elements. The mutable overloads allow in-place
  /// modification of the values; none of them can change the size.
  iterator begin() { return Data.begin(); }
  iterator end() { return Data.end(); }
  const_iterator begin() const { return Data.begin(); }
  const_iterator end() const { return Data.end(); }
  const_iterator cbegin() const { return Data.cbegin(); }
  const_iterator cend() const { return Data.cend(); }

  /// Returns read-only access to the underlying vector.
  const std::vector<double> &getData() const { return Data; }

  /// Arithmetic operators. Both operate element-wise, in place, and return
  /// a reference to this embedding.
  Embedding &operator+=(const Embedding &RHS);
  Embedding &operator-=(const Embedding &RHS);

  /// Adds \p Src, scaled by \p Factor, to this embedding in place:
  ///   *this += Src * Factor
  /// Returns a reference to this embedding.
  Embedding &scaleAndAdd(const Embedding &Src, float Factor);

  /// Returns true if the embedding is approximately equal to the RHS embedding
  /// within the specified tolerance.
  [[nodiscard]] bool approximatelyEquals(const Embedding &RHS,
                                         double Tolerance = 1e-6) const;
};
112+
57113
using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
58114
using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
59115
// FIXME: Currently the keys are strings. This can be changed to
60116
// use integers for cheaper lookups.
61117
using Vocab = std::map<std::string, Embedding>;
62118

63119
/// Embedder provides the interface to generate embeddings (vector
64-
/// representations) for instructions, basic blocks, and functions. The vector
65-
/// representations are generated using IR2Vec algorithms.
120+
/// representations) for instructions, basic blocks, and functions. The
121+
/// vector representations are generated using IR2Vec algorithms.
66122
///
67123
/// The Embedder class is an abstract class and it is intended to be
68124
/// subclassed for different IR2Vec algorithms like Symbolic and Flow-aware.
@@ -99,13 +155,6 @@ class Embedder {
99155
/// zero vector.
100156
Embedding lookupVocab(const std::string &Key) const;
101157

102-
/// Adds two vectors: Dst += Src
103-
static void addVectors(Embedding &Dst, const Embedding &Src);
104-
105-
/// Adds Src vector scaled by Factor to Dst vector: Dst += Src * Factor
106-
static void addScaledVector(Embedding &Dst, const Embedding &Src,
107-
float Factor);
108-
109158
public:
110159
virtual ~Embedder() = default;
111160

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 50 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,51 @@ static cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional,
5555

5656
AnalysisKey IR2VecVocabAnalysis::Key;
5757

58+
namespace llvm::json {
59+
inline bool fromJSON(const llvm::json::Value &E, Embedding &Out,
60+
llvm::json::Path P) {
61+
std::vector<double> TempOut;
62+
if (!llvm::json::fromJSON(E, TempOut, P))
63+
return false;
64+
Out = Embedding(std::move(TempOut));
65+
return true;
66+
}
67+
} // namespace llvm::json
68+
69+
//===----------------------------------------------------------------------===//
70+
// Embedding
71+
//===----------------------------------------------------------------------===//
72+
73+
/// Element-wise in-place addition: every element of \p RHS is added to the
/// element of this embedding at the same position. Both embeddings must have
/// the same dimension (checked by assertion). Returns *this.
Embedding &Embedding::operator+=(const Embedding &RHS) {
  assert(this->size() == RHS.size() && "Vectors must have the same dimension");
  auto RHSIt = RHS.begin();
  for (double &Elem : *this)
    Elem += *RHSIt++;
  return *this;
}
79+
80+
/// Element-wise in-place subtraction: every element of \p RHS is subtracted
/// from the element of this embedding at the same position. Both embeddings
/// must have the same dimension (checked by assertion). Returns *this.
Embedding &Embedding::operator-=(const Embedding &RHS) {
  assert(this->size() == RHS.size() && "Vectors must have the same dimension");
  auto RHSIt = RHS.begin();
  for (auto It = this->begin(), End = this->end(); It != End; ++It, ++RHSIt)
    *It -= *RHSIt;
  return *this;
}
86+
87+
/// Accumulates \p Src scaled by \p Factor into this embedding, element by
/// element: (*this)[i] += Src[i] * Factor. Both embeddings must have the same
/// dimension (checked by assertion). Returns *this.
Embedding &Embedding::scaleAndAdd(const Embedding &Src, float Factor) {
  assert(this->size() == Src.size() && "Vectors must have the same dimension");
  std::transform(
      this->begin(), this->end(), Src.begin(), this->begin(),
      [Factor](double Acc, double SrcVal) { return Acc + SrcVal * Factor; });
  return *this;
}
93+
94+
bool Embedding::approximatelyEquals(const Embedding &RHS,
95+
double Tolerance) const {
96+
assert(this->size() == RHS.size() && "Vectors must have the same dimension");
97+
for (size_t Itr = 0; Itr < this->size(); ++Itr)
98+
if (std::abs((*this)[Itr] - RHS[Itr]) > Tolerance)
99+
return false;
100+
return true;
101+
}
102+
58103
//===----------------------------------------------------------------------===//
59104
// Embedder and its subclasses
60105
//===----------------------------------------------------------------------===//
@@ -73,20 +118,6 @@ Embedder::create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary) {
73118
return make_error<StringError>("Unknown IR2VecKind", errc::invalid_argument);
74119
}
75120

76-
void Embedder::addVectors(Embedding &Dst, const Embedding &Src) {
77-
assert(Dst.size() == Src.size() && "Vectors must have the same dimension");
78-
std::transform(Dst.begin(), Dst.end(), Src.begin(), Dst.begin(),
79-
std::plus<double>());
80-
}
81-
82-
void Embedder::addScaledVector(Embedding &Dst, const Embedding &Src,
83-
float Factor) {
84-
assert(Dst.size() == Src.size() && "Vectors must have the same dimension");
85-
for (size_t i = 0; i < Dst.size(); ++i) {
86-
Dst[i] += Src[i] * Factor;
87-
}
88-
}
89-
90121
// FIXME: Currently lookups are string based. Use numeric Keys
91122
// for efficiency
92123
Embedding Embedder::lookupVocab(const std::string &Key) const {
@@ -164,20 +195,20 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
164195
Embedding InstVector(Dimension, 0);
165196

166197
const auto OpcVec = lookupVocab(I.getOpcodeName());
167-
addScaledVector(InstVector, OpcVec, OpcWeight);
198+
InstVector.scaleAndAdd(OpcVec, OpcWeight);
168199

169200
// FIXME: Currently lookups are string based. Use numeric Keys
170201
// for efficiency.
171202
const auto Type = I.getType();
172203
const auto TypeVec = getTypeEmbedding(Type);
173-
addScaledVector(InstVector, TypeVec, TypeWeight);
204+
InstVector.scaleAndAdd(TypeVec, TypeWeight);
174205

175206
for (const auto &Op : I.operands()) {
176207
const auto OperandVec = getOperandEmbedding(Op.get());
177-
addScaledVector(InstVector, OperandVec, ArgWeight);
208+
InstVector.scaleAndAdd(OperandVec, ArgWeight);
178209
}
179210
InstVecMap[&I] = InstVector;
180-
addVectors(BBVector, InstVector);
211+
BBVector += InstVector;
181212
}
182213
BBVecMap[&BB] = BBVector;
183214
}
@@ -187,7 +218,7 @@ void SymbolicEmbedder::computeEmbeddings() const {
187218
return;
188219
for (const auto &BB : F) {
189220
computeEmbeddings(BB);
190-
addVectors(FuncVector, BBVecMap[&BB]);
221+
FuncVector += BBVecMap[&BB];
191222
}
192223
}
193224

0 commit comments

Comments
 (0)