@@ -55,6 +55,51 @@ static cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional,
55
55
56
56
AnalysisKey IR2VecVocabAnalysis::Key;
57
57
58
+ namespace llvm ::json {
59
+ inline bool fromJSON (const llvm::json::Value &E, Embedding &Out,
60
+ llvm::json::Path P) {
61
+ std::vector<double > TempOut;
62
+ if (!llvm::json::fromJSON (E, TempOut, P))
63
+ return false ;
64
+ Out = Embedding (std::move (TempOut));
65
+ return true ;
66
+ }
67
+ } // namespace llvm::json
68
+
69
+ // ==----------------------------------------------------------------------===//
70
+ // Embedding
71
+ // ===----------------------------------------------------------------------===//
72
+
73
+ Embedding &Embedding::operator +=(const Embedding &RHS) {
74
+ assert (this ->size () == RHS.size () && " Vectors must have the same dimension" );
75
+ std::transform (this ->begin (), this ->end (), RHS.begin (), this ->begin (),
76
+ std::plus<double >());
77
+ return *this ;
78
+ }
79
+
80
+ Embedding &Embedding::operator -=(const Embedding &RHS) {
81
+ assert (this ->size () == RHS.size () && " Vectors must have the same dimension" );
82
+ std::transform (this ->begin (), this ->end (), RHS.begin (), this ->begin (),
83
+ std::minus<double >());
84
+ return *this ;
85
+ }
86
+
87
+ Embedding &Embedding::scaleAndAdd (const Embedding &Src, float Factor) {
88
+ assert (this ->size () == Src.size () && " Vectors must have the same dimension" );
89
+ for (size_t Itr = 0 ; Itr < this ->size (); ++Itr)
90
+ (*this )[Itr] += Src[Itr] * Factor;
91
+ return *this ;
92
+ }
93
+
94
+ bool Embedding::approximatelyEquals (const Embedding &RHS,
95
+ double Tolerance) const {
96
+ assert (this ->size () == RHS.size () && " Vectors must have the same dimension" );
97
+ for (size_t Itr = 0 ; Itr < this ->size (); ++Itr)
98
+ if (std::abs ((*this )[Itr] - RHS[Itr]) > Tolerance)
99
+ return false ;
100
+ return true ;
101
+ }
102
+
58
103
// ==----------------------------------------------------------------------===//
59
104
// Embedder and its subclasses
60
105
// ===----------------------------------------------------------------------===//
@@ -73,20 +118,6 @@ Embedder::create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary) {
73
118
return make_error<StringError>(" Unknown IR2VecKind" , errc::invalid_argument);
74
119
}
75
120
76
- void Embedder::addVectors (Embedding &Dst, const Embedding &Src) {
77
- assert (Dst.size () == Src.size () && " Vectors must have the same dimension" );
78
- std::transform (Dst.begin (), Dst.end (), Src.begin (), Dst.begin (),
79
- std::plus<double >());
80
- }
81
-
82
- void Embedder::addScaledVector (Embedding &Dst, const Embedding &Src,
83
- float Factor) {
84
- assert (Dst.size () == Src.size () && " Vectors must have the same dimension" );
85
- for (size_t i = 0 ; i < Dst.size (); ++i) {
86
- Dst[i] += Src[i] * Factor;
87
- }
88
- }
89
-
90
121
// FIXME: Currently lookups are string based. Use numeric Keys
91
122
// for efficiency
92
123
Embedding Embedder::lookupVocab (const std::string &Key) const {
@@ -164,20 +195,20 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
164
195
Embedding InstVector (Dimension, 0 );
165
196
166
197
const auto OpcVec = lookupVocab (I.getOpcodeName ());
167
- addScaledVector ( InstVector, OpcVec, OpcWeight);
198
+ InstVector. scaleAndAdd ( OpcVec, OpcWeight);
168
199
169
200
// FIXME: Currently lookups are string based. Use numeric Keys
170
201
// for efficiency.
171
202
const auto Type = I.getType ();
172
203
const auto TypeVec = getTypeEmbedding (Type);
173
- addScaledVector ( InstVector, TypeVec, TypeWeight);
204
+ InstVector. scaleAndAdd ( TypeVec, TypeWeight);
174
205
175
206
for (const auto &Op : I.operands ()) {
176
207
const auto OperandVec = getOperandEmbedding (Op.get ());
177
- addScaledVector ( InstVector, OperandVec, ArgWeight);
208
+ InstVector. scaleAndAdd ( OperandVec, ArgWeight);
178
209
}
179
210
InstVecMap[&I] = InstVector;
180
- addVectors ( BBVector, InstVector) ;
211
+ BBVector += InstVector;
181
212
}
182
213
BBVecMap[&BB] = BBVector;
183
214
}
@@ -187,7 +218,7 @@ void SymbolicEmbedder::computeEmbeddings() const {
187
218
return ;
188
219
for (const auto &BB : F) {
189
220
computeEmbeddings (BB);
190
- addVectors ( FuncVector, BBVecMap[&BB]) ;
221
+ FuncVector += BBVecMap[&BB];
191
222
}
192
223
}
193
224
0 commit comments