@@ -281,25 +281,30 @@ TEST(IR2VecTest, IR2VecVocabResultValidity) {
281
281
EXPECT_EQ (validResult.getDimension (), 2u );
282
282
}
283
283
284
- // Helper to create a minimal function and embedder for getter tests
285
- struct GetterTestEnv {
286
- Vocab V = {};
284
+ // Fixture for IR2Vec tests requiring IR setup and weight management.
285
+ class IR2VecTestFixture : public ::testing::Test {
286
+ protected:
287
+ Vocab V;
287
288
LLVMContext Ctx;
288
- std::unique_ptr<Module> M = nullptr ;
289
+ std::unique_ptr<Module> M;
289
290
Function *F = nullptr ;
290
291
BasicBlock *BB = nullptr ;
291
- Instruction *Add = nullptr ;
292
- Instruction *Ret = nullptr ;
293
- std::unique_ptr<Embedder> Emb = nullptr ;
292
+ Instruction *AddInst = nullptr ;
293
+ Instruction *RetInst = nullptr ;
294
294
295
- GetterTestEnv () {
295
+ float OriginalOpcWeight = ::OpcWeight;
296
+ float OriginalTypeWeight = ::TypeWeight;
297
+ float OriginalArgWeight = ::ArgWeight;
298
+
299
+ void SetUp () override {
296
300
V = {{" add" , {1.0 , 2.0 }},
297
301
{" integerTy" , {0.5 , 0.5 }},
298
302
{" constant" , {0.2 , 0.3 }},
299
303
{" variable" , {0.0 , 0.0 }},
300
304
{" unknownTy" , {0.0 , 0.0 }}};
301
305
302
- M = std::make_unique<Module>(" M" , Ctx);
306
+ // Setup IR
307
+ M = std::make_unique<Module>(" TestM" , Ctx);
303
308
FunctionType *FTy = FunctionType::get (
304
309
Type::getInt32Ty (Ctx), {Type::getInt32Ty (Ctx), Type::getInt32Ty (Ctx)},
305
310
false );
@@ -308,61 +313,82 @@ struct GetterTestEnv {
308
313
Argument *Arg = F->getArg (0 );
309
314
llvm::Value *Const = ConstantInt::get (Type::getInt32Ty (Ctx), 42 );
310
315
311
- Add = BinaryOperator::CreateAdd (Arg, Const, " add" , BB);
312
- Ret = ReturnInst::Create (Ctx, Add, BB);
316
+ AddInst = BinaryOperator::CreateAdd (Arg, Const, " add" , BB);
317
+ RetInst = ReturnInst::Create (Ctx, AddInst, BB);
318
+ }
319
+
320
+ void setWeights (float OpcWeight, float TypeWeight, float ArgWeight) {
321
+ ::OpcWeight = OpcWeight;
322
+ ::TypeWeight = TypeWeight;
323
+ ::ArgWeight = ArgWeight;
324
+ }
313
325
314
- auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
315
- EXPECT_TRUE (static_cast <bool >(Result));
316
- Emb = std::move (*Result);
326
+ void TearDown () override {
327
+ // Restore original global weights
328
+ ::OpcWeight = OriginalOpcWeight;
329
+ ::TypeWeight = OriginalTypeWeight;
330
+ ::ArgWeight = OriginalArgWeight;
317
331
}
318
332
};
319
333
320
- TEST (IR2VecTest, GetInstVecMap) {
321
- GetterTestEnv Env;
322
- const auto &InstMap = Env.Emb ->getInstVecMap ();
334
+ TEST_F (IR2VecTestFixture, GetInstVecMap) {
335
+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
336
+ ASSERT_TRUE (static_cast <bool >(Result));
337
+ auto Emb = std::move (*Result);
338
+
339
+ const auto &InstMap = Emb->getInstVecMap ();
323
340
324
341
EXPECT_EQ (InstMap.size (), 2u );
325
- EXPECT_TRUE (InstMap.count (Env. Add ));
326
- EXPECT_TRUE (InstMap.count (Env. Ret ));
342
+ EXPECT_TRUE (InstMap.count (AddInst ));
343
+ EXPECT_TRUE (InstMap.count (RetInst ));
327
344
328
- EXPECT_EQ (InstMap.at (Env. Add ).size (), 2u );
329
- EXPECT_EQ (InstMap.at (Env. Ret ).size (), 2u );
345
+ EXPECT_EQ (InstMap.at (AddInst ).size (), 2u );
346
+ EXPECT_EQ (InstMap.at (RetInst ).size (), 2u );
330
347
331
348
// Check values for add: {1.29, 2.31}
332
- EXPECT_THAT (InstMap.at (Env. Add ),
349
+ EXPECT_THAT (InstMap.at (AddInst ),
333
350
ElementsAre (DoubleNear (1.29 , 1e-6 ), DoubleNear (2.31 , 1e-6 )));
334
351
335
352
// Check values for ret: {0.0, 0.}; Neither ret nor voidTy are present in
336
353
// vocab
337
- EXPECT_THAT (InstMap.at (Env. Ret ), ElementsAre (0.0 , 0.0 ));
354
+ EXPECT_THAT (InstMap.at (RetInst ), ElementsAre (0.0 , 0.0 ));
338
355
}
339
356
340
- TEST (IR2VecTest, GetBBVecMap) {
341
- GetterTestEnv Env;
342
- const auto &BBMap = Env.Emb ->getBBVecMap ();
357
+ TEST_F (IR2VecTestFixture, GetBBVecMap) {
358
+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
359
+ ASSERT_TRUE (static_cast <bool >(Result));
360
+ auto Emb = std::move (*Result);
361
+
362
+ const auto &BBMap = Emb->getBBVecMap ();
343
363
344
364
EXPECT_EQ (BBMap.size (), 1u );
345
- EXPECT_TRUE (BBMap.count (Env. BB ));
346
- EXPECT_EQ (BBMap.at (Env. BB ).size (), 2u );
365
+ EXPECT_TRUE (BBMap.count (BB));
366
+ EXPECT_EQ (BBMap.at (BB).size (), 2u );
347
367
348
368
// BB vector should be sum of add and ret: {1.29, 2.31} + {0.0, 0.0} =
349
369
// {1.29, 2.31}
350
- EXPECT_THAT (BBMap.at (Env. BB ),
370
+ EXPECT_THAT (BBMap.at (BB),
351
371
ElementsAre (DoubleNear (1.29 , 1e-6 ), DoubleNear (2.31 , 1e-6 )));
352
372
}
353
373
354
- TEST (IR2VecTest, GetBBVector) {
355
- GetterTestEnv Env;
356
- const auto &BBVec = Env.Emb ->getBBVector (*Env.BB );
374
+ TEST_F (IR2VecTestFixture, GetBBVector) {
375
+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
376
+ ASSERT_TRUE (static_cast <bool >(Result));
377
+ auto Emb = std::move (*Result);
378
+
379
+ const auto &BBVec = Emb->getBBVector (*BB);
357
380
358
381
EXPECT_EQ (BBVec.size (), 2u );
359
382
EXPECT_THAT (BBVec,
360
383
ElementsAre (DoubleNear (1.29 , 1e-6 ), DoubleNear (2.31 , 1e-6 )));
361
384
}
362
385
363
- TEST (IR2VecTest, GetFunctionVector) {
364
- GetterTestEnv Env;
365
- const auto &FuncVec = Env.Emb ->getFunctionVector ();
386
+ TEST_F (IR2VecTestFixture, GetFunctionVector) {
387
+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
388
+ ASSERT_TRUE (static_cast <bool >(Result));
389
+ auto Emb = std::move (*Result);
390
+
391
+ const auto &FuncVec = Emb->getFunctionVector ();
366
392
367
393
EXPECT_EQ (FuncVec.size (), 2u );
368
394
@@ -371,4 +397,45 @@ TEST(IR2VecTest, GetFunctionVector) {
371
397
ElementsAre (DoubleNear (1.29 , 1e-6 ), DoubleNear (2.31 , 1e-6 )));
372
398
}
373
399
400
+ TEST_F (IR2VecTestFixture, GetFunctionVectorWithCustomWeights) {
401
+ setWeights (1.0 , 1.0 , 1.0 );
402
+
403
+ auto Result = Embedder::create (IR2VecKind::Symbolic, *F, V);
404
+ ASSERT_TRUE (static_cast <bool >(Result));
405
+ auto Emb = std::move (*Result);
406
+
407
+ const auto &FuncVec = Emb->getFunctionVector ();
408
+
409
+ EXPECT_EQ (FuncVec.size (), 2u );
410
+
411
+ // Expected: 1*([1.0 2.0] + [0.0 0.0]) + 1*([0.5 0.5] + [0.0 0.0]) + 1*([0.2
412
+ // 0.3] + [0.0 0.0])
413
+ EXPECT_THAT (FuncVec,
414
+ ElementsAre (DoubleNear (1.7 , 1e-6 ), DoubleNear (2.8 , 1e-6 )));
415
+ }
416
+
417
+ TEST (IR2VecTest, IR2VecVocabAnalysisWithPrepopulatedVocab) {
418
+ Vocab InitialVocab = {{" key1" , {1.1 , 2.2 }}, {" key2" , {3.3 , 4.4 }}};
419
+ Vocab ExpectedVocab = InitialVocab;
420
+ unsigned ExpectedDim = InitialVocab.begin ()->second .size ();
421
+
422
+ IR2VecVocabAnalysis VocabAnalysis (std::move (InitialVocab));
423
+
424
+ LLVMContext TestCtx;
425
+ Module TestMod (" TestModuleForVocabAnalysis" , TestCtx);
426
+ ModuleAnalysisManager MAM;
427
+ IR2VecVocabResult Result = VocabAnalysis.run (TestMod, MAM);
428
+
429
+ EXPECT_TRUE (Result.isValid ());
430
+ ASSERT_FALSE (Result.getVocabulary ().empty ());
431
+ EXPECT_EQ (Result.getDimension (), ExpectedDim);
432
+
433
+ const auto &ResultVocab = Result.getVocabulary ();
434
+ EXPECT_EQ (ResultVocab.size (), ExpectedVocab.size ());
435
+ for (const auto &pair : ExpectedVocab) {
436
+ EXPECT_TRUE (ResultVocab.count (pair.first ));
437
+ EXPECT_THAT (ResultVocab.at (pair.first ), ElementsAreArray (pair.second ));
438
+ }
439
+ }
440
+
374
441
} // end anonymous namespace
0 commit comments