@@ -199,25 +199,30 @@ TEST(IR2VecTest, IR2VecVocabResultValidity) {
199
199
EXPECT_EQ (validResult.getDimension (), 2u );
200
200
}
201
201
202
- // Helper to create a minimal function and embedder for getter tests
203
- struct GetterTestEnv {
204
- Vocab V = {};
202
+ // Fixture for IR2Vec tests requiring IR setup and weight management.
203
+ class IR2VecTestFixture : public ::testing::Test {
204
+ protected:
205
+ Vocab V;
205
206
LLVMContext Ctx;
206
- std::unique_ptr<Module> M = nullptr ;
207
+ std::unique_ptr<Module> M;
207
208
Function *F = nullptr ;
208
209
BasicBlock *BB = nullptr ;
209
- Instruction *Add = nullptr ;
210
- Instruction *Ret = nullptr ;
211
- std::unique_ptr<Embedder> Emb = nullptr ;
210
+ Instruction *AddInst = nullptr ;
211
+ Instruction *RetInst = nullptr ;
212
212
213
- GetterTestEnv () {
213
+ float OriginalOpcWeight = ::OpcWeight;
214
+ float OriginalTypeWeight = ::TypeWeight;
215
+ float OriginalArgWeight = ::ArgWeight;
216
+
217
+ void SetUp () override {
214
218
V = {{" add" , {1.0 , 2.0 }},
215
219
{" integerTy" , {0.5 , 0.5 }},
216
220
{" constant" , {0.2 , 0.3 }},
217
221
{" variable" , {0.0 , 0.0 }},
218
222
{" unknownTy" , {0.0 , 0.0 }}};
219
223
220
- M = std::make_unique<Module>(" M" , Ctx);
224
+ // Setup IR
225
+ M = std::make_unique<Module>(" TestM" , Ctx);
221
226
FunctionType *FTy = FunctionType::get (
222
227
Type::getInt32Ty (Ctx), {Type::getInt32Ty (Ctx), Type::getInt32Ty (Ctx)},
223
228
false );
@@ -226,61 +231,82 @@ struct GetterTestEnv {
226
231
Argument *Arg = F->getArg (0 );
227
232
llvm::Value *Const = ConstantInt::get (Type::getInt32Ty (Ctx), 42 );
228
233
229
- Add = BinaryOperator::CreateAdd (Arg, Const, " add" , BB);
230
- Ret = ReturnInst::Create (Ctx, Add, BB);
234
+ AddInst = BinaryOperator::CreateAdd (Arg, Const, " add" , BB);
235
+ RetInst = ReturnInst::Create (Ctx, AddInst, BB);
236
+ }
237
+
238
+ void setWeights (float OpcWeight, float TypeWeight, float ArgWeight) {
239
+ ::OpcWeight = OpcWeight;
240
+ ::TypeWeight = TypeWeight;
241
+ ::ArgWeight = ArgWeight;
242
+ }
231
243
232
- auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
233
- EXPECT_TRUE (static_cast <bool >(Result));
234
- Emb = std::move (*Result);
244
+ void TearDown () override {
245
+ // Restore original global weights
246
+ ::OpcWeight = OriginalOpcWeight;
247
+ ::TypeWeight = OriginalTypeWeight;
248
+ ::ArgWeight = OriginalArgWeight;
235
249
}
236
250
};
237
251
238
- TEST (IR2VecTest, GetInstVecMap) {
239
- GetterTestEnv Env;
240
- const auto &InstMap = Env.Emb ->getInstVecMap ();
252
+ TEST_F (IR2VecTestFixture, GetInstVecMap) {
253
+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
254
+ ASSERT_TRUE (static_cast <bool >(Result));
255
+ auto Emb = std::move (*Result);
256
+
257
+ const auto &InstMap = Emb->getInstVecMap ();
241
258
242
259
EXPECT_EQ (InstMap.size (), 2u );
243
- EXPECT_TRUE (InstMap.count (Env. Add ));
244
- EXPECT_TRUE (InstMap.count (Env. Ret ));
260
+ EXPECT_TRUE (InstMap.count (AddInst ));
261
+ EXPECT_TRUE (InstMap.count (RetInst ));
245
262
246
- EXPECT_EQ (InstMap.at (Env. Add ).size (), 2u );
247
- EXPECT_EQ (InstMap.at (Env. Ret ).size (), 2u );
263
+ EXPECT_EQ (InstMap.at (AddInst ).size (), 2u );
264
+ EXPECT_EQ (InstMap.at (RetInst ).size (), 2u );
248
265
249
266
// Check values for add: {1.29, 2.31}
250
- EXPECT_THAT (InstMap.at (Env. Add ),
267
+ EXPECT_THAT (InstMap.at (AddInst ),
251
268
ElementsAre (DoubleNear (1.29 , 1e-6 ), DoubleNear (2.31 , 1e-6 )));
252
269
253
270
// Check values for ret: {0.0, 0.}; Neither ret nor voidTy are present in
254
271
// vocab
255
- EXPECT_THAT (InstMap.at (Env. Ret ), ElementsAre (0.0 , 0.0 ));
272
+ EXPECT_THAT (InstMap.at (RetInst ), ElementsAre (0.0 , 0.0 ));
256
273
}
257
274
258
- TEST (IR2VecTest, GetBBVecMap) {
259
- GetterTestEnv Env;
260
- const auto &BBMap = Env.Emb ->getBBVecMap ();
275
+ TEST_F (IR2VecTestFixture, GetBBVecMap) {
276
+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
277
+ ASSERT_TRUE (static_cast <bool >(Result));
278
+ auto Emb = std::move (*Result);
279
+
280
+ const auto &BBMap = Emb->getBBVecMap ();
261
281
262
282
EXPECT_EQ (BBMap.size (), 1u );
263
- EXPECT_TRUE (BBMap.count (Env. BB ));
264
- EXPECT_EQ (BBMap.at (Env. BB ).size (), 2u );
283
+ EXPECT_TRUE (BBMap.count (BB));
284
+ EXPECT_EQ (BBMap.at (BB).size (), 2u );
265
285
266
286
// BB vector should be sum of add and ret: {1.29, 2.31} + {0.0, 0.0} =
267
287
// {1.29, 2.31}
268
- EXPECT_THAT (BBMap.at (Env. BB ),
288
+ EXPECT_THAT (BBMap.at (BB),
269
289
ElementsAre (DoubleNear (1.29 , 1e-6 ), DoubleNear (2.31 , 1e-6 )));
270
290
}
271
291
272
- TEST (IR2VecTest, GetBBVector) {
273
- GetterTestEnv Env;
274
- const auto &BBVec = Env.Emb ->getBBVector (*Env.BB );
292
+ TEST_F (IR2VecTestFixture, GetBBVector) {
293
+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
294
+ ASSERT_TRUE (static_cast <bool >(Result));
295
+ auto Emb = std::move (*Result);
296
+
297
+ const auto &BBVec = Emb->getBBVector (*BB);
275
298
276
299
EXPECT_EQ (BBVec.size (), 2u );
277
300
EXPECT_THAT (BBVec,
278
301
ElementsAre (DoubleNear (1.29 , 1e-6 ), DoubleNear (2.31 , 1e-6 )));
279
302
}
280
303
281
- TEST (IR2VecTest, GetFunctionVector) {
282
- GetterTestEnv Env;
283
- const auto &FuncVec = Env.Emb ->getFunctionVector ();
304
+ TEST_F (IR2VecTestFixture, GetFunctionVector) {
305
+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
306
+ ASSERT_TRUE (static_cast <bool >(Result));
307
+ auto Emb = std::move (*Result);
308
+
309
+ const auto &FuncVec = Emb->getFunctionVector ();
284
310
285
311
EXPECT_EQ (FuncVec.size (), 2u );
286
312
@@ -289,4 +315,45 @@ TEST(IR2VecTest, GetFunctionVector) {
289
315
ElementsAre (DoubleNear (1.29 , 1e-6 ), DoubleNear (2.31 , 1e-6 )));
290
316
}
291
317
318
+ TEST_F (IR2VecTestFixture, GetFunctionVectorWithCustomWeights) {
319
+ setWeights (1.0 , 1.0 , 1.0 );
320
+
321
+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
322
+ ASSERT_TRUE (static_cast <bool >(Result));
323
+ auto Emb = std::move (*Result);
324
+
325
+ const auto &FuncVec = Emb->getFunctionVector ();
326
+
327
+ EXPECT_EQ (FuncVec.size (), 2u );
328
+
329
+ // Expected: 1*([1.0 2.0] + [0.0 0.0]) + 1*([0.5 0.5] + [0.0 0.0]) + 1*([0.2
330
+ // 0.3] + [0.0 0.0])
331
+ EXPECT_THAT (FuncVec,
332
+ ElementsAre (DoubleNear (1.7 , 1e-6 ), DoubleNear (2.8 , 1e-6 )));
333
+ }
334
+
335
+ TEST (IR2VecTest, IR2VecVocabAnalysisWithPrepopulatedVocab) {
336
+ Vocab InitialVocab = {{" key1" , {1.1 , 2.2 }}, {" key2" , {3.3 , 4.4 }}};
337
+ Vocab ExpectedVocab = InitialVocab;
338
+ unsigned ExpectedDim = InitialVocab.begin ()->second .size ();
339
+
340
+ IR2VecVocabAnalysis VocabAnalysis (std::move (InitialVocab));
341
+
342
+ LLVMContext TestCtx;
343
+ Module TestMod (" TestModuleForVocabAnalysis" , TestCtx);
344
+ ModuleAnalysisManager MAM;
345
+ IR2VecVocabResult Result = VocabAnalysis.run (TestMod, MAM);
346
+
347
+ EXPECT_TRUE (Result.isValid ());
348
+ ASSERT_FALSE (Result.getVocabulary ().empty ());
349
+ EXPECT_EQ (Result.getDimension (), ExpectedDim);
350
+
351
+ const auto &ResultVocab = Result.getVocabulary ();
352
+ EXPECT_EQ (ResultVocab.size (), ExpectedVocab.size ());
353
+ for (const auto &pair : ExpectedVocab) {
354
+ EXPECT_TRUE (ResultVocab.count (pair.first ));
355
+ EXPECT_THAT (ResultVocab.at (pair.first ), ElementsAreArray (pair.second ));
356
+ }
357
+ }
358
+
292
359
} // end anonymous namespace
0 commit comments