Skip to content

Commit 9124e83

Browse files
committed
[MLInliner][IR2Vec] Integrating IR2Vec with MLInliner
1 parent e29bb9a commit 9124e83

File tree

8 files changed

+338
-24
lines changed

8 files changed

+338
-24
lines changed

llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#define LLVM_ANALYSIS_FUNCTIONPROPERTIESANALYSIS_H
1616

1717
#include "llvm/ADT/DenseSet.h"
18+
#include "llvm/Analysis/IR2Vec.h"
1819
#include "llvm/IR/Dominators.h"
1920
#include "llvm/IR/PassManager.h"
2021
#include "llvm/Support/Compiler.h"
@@ -32,17 +33,19 @@ class FunctionPropertiesInfo {
3233
void updateAggregateStats(const Function &F, const LoopInfo &LI);
3334
void reIncludeBB(const BasicBlock &BB);
3435

36+
ir2vec::Embedding FunctionEmbedding = ir2vec::Embedding(0.0);
37+
std::optional<ir2vec::Vocab> IR2VecVocab;
38+
3539
public:
3640
LLVM_ABI static FunctionPropertiesInfo
3741
getFunctionPropertiesInfo(const Function &F, const DominatorTree &DT,
38-
const LoopInfo &LI);
42+
const LoopInfo &LI,
43+
const IR2VecVocabResult *VocabResult);
3944

4045
LLVM_ABI static FunctionPropertiesInfo
4146
getFunctionPropertiesInfo(Function &F, FunctionAnalysisManager &FAM);
4247

43-
bool operator==(const FunctionPropertiesInfo &FPI) const {
44-
return std::memcmp(this, &FPI, sizeof(FunctionPropertiesInfo)) == 0;
45-
}
48+
bool operator==(const FunctionPropertiesInfo &FPI) const;
4649

4750
bool operator!=(const FunctionPropertiesInfo &FPI) const {
4851
return !(*this == FPI);
@@ -137,6 +140,19 @@ class FunctionPropertiesInfo {
137140
int64_t CallReturnsVectorPointerCount = 0;
138141
int64_t CallWithManyArgumentsCount = 0;
139142
int64_t CallWithPointerArgumentCount = 0;
143+
144+
const ir2vec::Embedding &getFunctionEmbedding() const {
145+
return FunctionEmbedding;
146+
}
147+
148+
const std::optional<ir2vec::Vocab> &getIR2VecVocab() const {
149+
return IR2VecVocab;
150+
}
151+
152+
// Helper intended to be useful for unittests
153+
void setFunctionEmbeddingForTest(const ir2vec::Embedding &Embedding) {
154+
FunctionEmbedding = Embedding;
155+
}
140156
};
141157

142158
// Analysis pass
@@ -192,7 +208,7 @@ class FunctionPropertiesUpdater {
192208

193209
DominatorTree &getUpdatedDominatorTree(FunctionAnalysisManager &FAM) const;
194210

195-
DenseSet<const BasicBlock *> Successors;
211+
DenseSet<const BasicBlock *> Successors, CallUsers;
196212

197213
// Edges we might potentially need to remove from the dominator tree.
198214
SmallVector<DominatorTree::UpdateType, 2> DomTreeUpdates;

llvm/include/llvm/Analysis/InlineAdvisor.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,10 @@ class InlineAdvisorAnalysis : public AnalysisInfoMixin<InlineAdvisorAnalysis> {
331331
};
332332

333333
Result run(Module &M, ModuleAnalysisManager &MAM) { return Result(M, MAM); }
334+
335+
private:
336+
static bool initializeIR2VecVocabIfRequested(Module &M,
337+
ModuleAnalysisManager &MAM);
334338
};
335339

336340
/// Printer pass for the InlineAdvisorAnalysis results.

llvm/include/llvm/Analysis/InlineModelFeatureMaps.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,12 @@ enum class FeatureIndex : size_t {
142142
INLINE_FEATURE_ITERATOR(POPULATE_INDICES)
143143
#undef POPULATE_INDICES
144144

145+
// IR2Vec embeddings
146+
// Dimensions of embeddings are not known at compile time (until the vocab is
147+
// read). Hence macros cannot be used here.
148+
callee_embedding,
149+
caller_embedding,
150+
145151
NumberOfFeatures
146152
};
147153
// clang-format on
@@ -154,7 +160,7 @@ inlineCostFeatureToMlFeature(InlineCostFeatureIndex Feature) {
154160
constexpr size_t NumberOfFeatures =
155161
static_cast<size_t>(FeatureIndex::NumberOfFeatures);
156162

157-
LLVM_ABI extern const std::vector<TensorSpec> FeatureMap;
163+
LLVM_ABI extern std::vector<TensorSpec> FeatureMap;
158164

159165
LLVM_ABI extern const char *const DecisionName;
160166
LLVM_ABI extern const TensorSpec InlineDecisionSpec;

llvm/include/llvm/Analysis/MLInlineAdvisor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ class MLInlineAdvisor : public InlineAdvisor {
8282
int64_t NodeCount = 0;
8383
int64_t EdgeCount = 0;
8484
int64_t EdgesOfLastSeenNodes = 0;
85+
const bool UseIR2Vec;
8586

8687
std::map<const LazyCallGraph::Node *, unsigned> FunctionLevels;
8788
const int32_t InitialIRSize = 0;

llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp

Lines changed: 112 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,29 @@ void FunctionPropertiesInfo::updateForBB(const BasicBlock &BB,
199199
#undef CHECK_OPERAND
200200
}
201201
}
202+
203+
if (IR2VecVocab) {
204+
// We instantiate the IR2Vec embedder each time, as having a unique
205+
// pointer to the embedder as member of the class would make it
206+
// non-copyable. Instantiating the embedder in itself is not costly.
207+
auto EmbOrErr = ir2vec::Embedder::create(IR2VecKind::Symbolic,
208+
*BB.getParent(), *IR2VecVocab);
209+
if (Error Err = EmbOrErr.takeError()) {
210+
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
211+
BB.getContext().emitError("Error creating IR2Vec embeddings: " +
212+
EI.message());
213+
});
214+
return;
215+
}
216+
auto Embedder = std::move(*EmbOrErr);
217+
const auto &BBEmbedding = Embedder->getBBVector(BB);
218+
// Subtract BBEmbedding from Function embedding if the direction is -1,
219+
// and add it if the direction is +1.
220+
if (Direction == -1)
221+
FunctionEmbedding -= BBEmbedding;
222+
else
223+
FunctionEmbedding += BBEmbedding;
224+
}
202225
}
203226

204227
void FunctionPropertiesInfo::updateAggregateStats(const Function &F,
@@ -220,21 +243,91 @@ void FunctionPropertiesInfo::updateAggregateStats(const Function &F,
220243

221244
FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo(
222245
Function &F, FunctionAnalysisManager &FAM) {
246+
// We use the cached result of the IR2VecVocabAnalysis run by
247+
// InlineAdvisorAnalysis. If the IR2VecVocabAnalysis is not run, we don't
248+
// use IR2Vec embeddings.
249+
auto VocabResult = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F)
250+
.getCachedResult<IR2VecVocabAnalysis>(*F.getParent());
223251
return getFunctionPropertiesInfo(F, FAM.getResult<DominatorTreeAnalysis>(F),
224-
FAM.getResult<LoopAnalysis>(F));
252+
FAM.getResult<LoopAnalysis>(F), VocabResult);
225253
}
226254

227255
FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo(
228-
const Function &F, const DominatorTree &DT, const LoopInfo &LI) {
256+
const Function &F, const DominatorTree &DT, const LoopInfo &LI,
257+
const IR2VecVocabResult *VocabResult) {
229258

230259
FunctionPropertiesInfo FPI;
260+
if (VocabResult && VocabResult->isValid()) {
261+
FPI.IR2VecVocab = VocabResult->getVocabulary();
262+
FPI.FunctionEmbedding = ir2vec::Embedding(VocabResult->getDimension(), 0.0);
263+
}
231264
for (const auto &BB : F)
232265
if (DT.isReachableFromEntry(&BB))
233266
FPI.reIncludeBB(BB);
234267
FPI.updateAggregateStats(F, LI);
235268
return FPI;
236269
}
237270

271+
bool FunctionPropertiesInfo::operator==(
272+
const FunctionPropertiesInfo &FPI) const {
273+
if (BasicBlockCount != FPI.BasicBlockCount ||
274+
BlocksReachedFromConditionalInstruction !=
275+
FPI.BlocksReachedFromConditionalInstruction ||
276+
Uses != FPI.Uses ||
277+
DirectCallsToDefinedFunctions != FPI.DirectCallsToDefinedFunctions ||
278+
LoadInstCount != FPI.LoadInstCount ||
279+
StoreInstCount != FPI.StoreInstCount ||
280+
MaxLoopDepth != FPI.MaxLoopDepth ||
281+
TopLevelLoopCount != FPI.TopLevelLoopCount ||
282+
TotalInstructionCount != FPI.TotalInstructionCount ||
283+
BasicBlocksWithSingleSuccessor != FPI.BasicBlocksWithSingleSuccessor ||
284+
BasicBlocksWithTwoSuccessors != FPI.BasicBlocksWithTwoSuccessors ||
285+
BasicBlocksWithMoreThanTwoSuccessors !=
286+
FPI.BasicBlocksWithMoreThanTwoSuccessors ||
287+
BasicBlocksWithSinglePredecessor !=
288+
FPI.BasicBlocksWithSinglePredecessor ||
289+
BasicBlocksWithTwoPredecessors != FPI.BasicBlocksWithTwoPredecessors ||
290+
BasicBlocksWithMoreThanTwoPredecessors !=
291+
FPI.BasicBlocksWithMoreThanTwoPredecessors ||
292+
BigBasicBlocks != FPI.BigBasicBlocks ||
293+
MediumBasicBlocks != FPI.MediumBasicBlocks ||
294+
SmallBasicBlocks != FPI.SmallBasicBlocks ||
295+
CastInstructionCount != FPI.CastInstructionCount ||
296+
FloatingPointInstructionCount != FPI.FloatingPointInstructionCount ||
297+
IntegerInstructionCount != FPI.IntegerInstructionCount ||
298+
ConstantIntOperandCount != FPI.ConstantIntOperandCount ||
299+
ConstantFPOperandCount != FPI.ConstantFPOperandCount ||
300+
ConstantOperandCount != FPI.ConstantOperandCount ||
301+
InstructionOperandCount != FPI.InstructionOperandCount ||
302+
BasicBlockOperandCount != FPI.BasicBlockOperandCount ||
303+
GlobalValueOperandCount != FPI.GlobalValueOperandCount ||
304+
InlineAsmOperandCount != FPI.InlineAsmOperandCount ||
305+
ArgumentOperandCount != FPI.ArgumentOperandCount ||
306+
UnknownOperandCount != FPI.UnknownOperandCount ||
307+
CriticalEdgeCount != FPI.CriticalEdgeCount ||
308+
ControlFlowEdgeCount != FPI.ControlFlowEdgeCount ||
309+
UnconditionalBranchCount != FPI.UnconditionalBranchCount ||
310+
IntrinsicCount != FPI.IntrinsicCount ||
311+
DirectCallCount != FPI.DirectCallCount ||
312+
IndirectCallCount != FPI.IndirectCallCount ||
313+
CallReturnsIntegerCount != FPI.CallReturnsIntegerCount ||
314+
CallReturnsFloatCount != FPI.CallReturnsFloatCount ||
315+
CallReturnsPointerCount != FPI.CallReturnsPointerCount ||
316+
CallReturnsVectorIntCount != FPI.CallReturnsVectorIntCount ||
317+
CallReturnsVectorFloatCount != FPI.CallReturnsVectorFloatCount ||
318+
CallReturnsVectorPointerCount != FPI.CallReturnsVectorPointerCount ||
319+
CallWithManyArgumentsCount != FPI.CallWithManyArgumentsCount ||
320+
CallWithPointerArgumentCount != FPI.CallWithPointerArgumentCount) {
321+
return false;
322+
}
323+
// Check the equality of the function embeddings. We don't check the equality
324+
// of Vocabulary as it remains the same.
325+
if (!FunctionEmbedding.approximatelyEquals(FPI.FunctionEmbedding))
326+
return false;
327+
328+
return true;
329+
}
330+
238331
void FunctionPropertiesInfo::print(raw_ostream &OS) const {
239332
#define PRINT_PROPERTY(PROP_NAME) OS << #PROP_NAME ": " << PROP_NAME << "\n";
240333

@@ -322,6 +415,16 @@ FunctionPropertiesUpdater::FunctionPropertiesUpdater(
322415
// The caller's entry BB may change due to new alloca instructions.
323416
LikelyToChangeBBs.insert(&*Caller.begin());
324417

418+
// The users of the value returned by the call instruction can change,
419+
// leading to a change in the embeddings being computed, when used.
420+
// We conservatively add the BBs with such uses to LikelyToChangeBBs.
421+
for (const auto *User : CB.users())
422+
// Users of the call are expected to be Instructions; use cast<> (which
// asserts on a mismatch) instead of dereferencing an unchecked dyn_cast<>
// result, which would be a null-pointer dereference if the cast failed.
CallUsers.insert(cast<Instruction>(User)->getParent());
423+
// CallSiteBB can be removed from CallUsers if present; it's taken care of
424+
// separately.
425+
CallUsers.erase(&CallSiteBB);
426+
LikelyToChangeBBs.insert_range(CallUsers);
427+
325428
// The successors may become unreachable in the case of `invoke` inlining.
326429
// We track successors separately, too, because they form a boundary, together
327430
// with the CB BB ('Entry') between which the inlined callee will be pasted.
@@ -435,6 +538,9 @@ void FunctionPropertiesUpdater::finish(FunctionAnalysisManager &FAM) const {
435538
if (&CallSiteBB != &*Caller.begin())
436539
Reinclude.insert(&*Caller.begin());
437540

541+
// Reinclude the BBs which use the values returned by the call instruction
542+
Reinclude.insert_range(CallUsers);
543+
438544
// Distribute the successors to the 2 buckets.
439545
for (const auto *Succ : Successors)
440546
if (DT.isReachableFromEntry(Succ))
@@ -486,6 +592,9 @@ bool FunctionPropertiesUpdater::isUpdateValid(Function &F,
486592
return false;
487593
DominatorTree DT(F);
488594
LoopInfo LI(DT);
489-
auto Fresh = FunctionPropertiesInfo::getFunctionPropertiesInfo(F, DT, LI);
595+
auto VocabResult = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F)
596+
.getCachedResult<IR2VecVocabAnalysis>(*F.getParent());
597+
auto Fresh =
598+
FunctionPropertiesInfo::getFunctionPropertiesInfo(F, DT, LI, VocabResult);
490599
return FPI == Fresh;
491600
}

llvm/lib/Analysis/InlineAdvisor.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "llvm/ADT/StringExtras.h"
1717
#include "llvm/Analysis/AssumptionCache.h"
1818
#include "llvm/Analysis/EphemeralValuesCache.h"
19+
#include "llvm/Analysis/IR2Vec.h"
1920
#include "llvm/Analysis/InlineCost.h"
2021
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
2122
#include "llvm/Analysis/ProfileSummaryInfo.h"
@@ -64,6 +65,13 @@ static cl::opt<bool>
6465
cl::desc("If true, annotate inline advisor remarks "
6566
"with LTO and pass information."));
6667

68+
// This flag is used to enable IR2Vec embeddings in the ML inliner; only valid
69+
// with ML inliner. The vocab file is used to initialize the embeddings.
70+
static cl::opt<std::string> IR2VecVocabFile(
71+
"ml-inliner-ir2vec-vocab-file", cl::Hidden,
72+
cl::desc("Vocab file for IR2Vec; Setting this enables "
73+
"configuring the model to use IR2Vec embeddings."));
74+
6775
namespace llvm {
6876
extern cl::opt<InlinerFunctionImportStatsOpts> InlinerFunctionImportStats;
6977
} // namespace llvm
@@ -206,6 +214,20 @@ void InlineAdvice::recordInliningWithCalleeDeleted() {
206214
AnalysisKey InlineAdvisorAnalysis::Key;
207215
AnalysisKey PluginInlineAdvisorAnalysis::Key;
208216

217+
bool InlineAdvisorAnalysis::initializeIR2VecVocabIfRequested(
218+
Module &M, ModuleAnalysisManager &MAM) {
219+
if (!IR2VecVocabFile.empty()) {
220+
auto IR2VecVocabResult = MAM.getResult<IR2VecVocabAnalysis>(M);
221+
if (!IR2VecVocabResult.isValid()) {
222+
M.getContext().emitError("Failed to load IR2Vec vocabulary");
223+
return false;
224+
}
225+
}
226+
// No vocab file specified is OK; we just don't use IR2Vec
227+
// embeddings.
228+
return true;
229+
}
230+
209231
bool InlineAdvisorAnalysis::Result::tryCreate(
210232
InlineParams Params, InliningAdvisorMode Mode,
211233
const ReplayInlinerSettings &ReplaySettings, InlineContext IC) {
@@ -231,14 +253,21 @@ bool InlineAdvisorAnalysis::Result::tryCreate(
231253
/* EmitRemarks =*/true, IC);
232254
}
233255
break;
256+
// Run IR2VecVocabAnalysis once per module to get the vocabulary.
257+
// We run it here because it is immutable and we want to avoid running it
258+
// multiple times.
234259
case InliningAdvisorMode::Development:
235260
#ifdef LLVM_HAVE_TFLITE
236261
LLVM_DEBUG(dbgs() << "Using development-mode inliner policy.\n");
262+
if (!InlineAdvisorAnalysis::initializeIR2VecVocabIfRequested(M, MAM))
263+
return false;
237264
Advisor = llvm::getDevelopmentModeAdvisor(M, MAM, GetDefaultAdvice);
238265
#endif
239266
break;
240267
case InliningAdvisorMode::Release:
241268
LLVM_DEBUG(dbgs() << "Using release-mode inliner policy.\n");
269+
if (!InlineAdvisorAnalysis::initializeIR2VecVocabIfRequested(M, MAM))
270+
return false;
242271
Advisor = llvm::getReleaseModeAdvisor(M, MAM, GetDefaultAdvice);
243272
break;
244273
}

llvm/lib/Analysis/MLInlineAdvisor.cpp

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ static cl::opt<bool> KeepFPICache(
107107
cl::init(false));
108108

109109
// clang-format off
110-
const std::vector<TensorSpec> llvm::FeatureMap{
110+
std::vector<TensorSpec> llvm::FeatureMap{
111111
#define POPULATE_NAMES(DTYPE, SHAPE, NAME, __) TensorSpec::createSpec<DTYPE>(#NAME, SHAPE),
112112
// InlineCost features - these must come first
113113
INLINE_COST_FEATURE_ITERATOR(POPULATE_NAMES)
@@ -144,6 +144,7 @@ MLInlineAdvisor::MLInlineAdvisor(
144144
M, MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager()),
145145
ModelRunner(std::move(Runner)), GetDefaultAdvice(GetDefaultAdvice),
146146
CG(MAM.getResult<LazyCallGraphAnalysis>(M)),
147+
UseIR2Vec(MAM.getCachedResult<IR2VecVocabAnalysis>(M) != nullptr),
147148
InitialIRSize(getModuleIRSize()), CurrentIRSize(InitialIRSize),
148149
PSI(MAM.getResult<ProfileSummaryAnalysis>(M)) {
149150
assert(ModelRunner);
@@ -186,6 +187,19 @@ MLInlineAdvisor::MLInlineAdvisor(
186187
EdgeCount += getLocalCalls(KVP.first->getFunction());
187188
}
188189
NodeCount = AllNodes.size();
190+
191+
if (auto IR2VecVocabResult = MAM.getCachedResult<IR2VecVocabAnalysis>(M)) {
192+
if (!IR2VecVocabResult->isValid()) {
193+
M.getContext().emitError("IR2VecVocabAnalysis is not valid");
194+
return;
195+
}
196+
// Add the IR2Vec features to the feature map
197+
auto IR2VecDim = IR2VecVocabResult->getDimension();
198+
FeatureMap.push_back(
199+
TensorSpec::createSpec<float>("callee_embedding", {IR2VecDim}));
200+
FeatureMap.push_back(
201+
TensorSpec::createSpec<float>("caller_embedding", {IR2VecDim}));
202+
}
189203
}
190204

191205
unsigned MLInlineAdvisor::getInitialFunctionLevel(const Function &F) const {
@@ -433,6 +447,24 @@ std::unique_ptr<InlineAdvice> MLInlineAdvisor::getAdviceImpl(CallBase &CB) {
433447
*ModelRunner->getTensor<int64_t>(FeatureIndex::is_caller_avail_external) =
434448
Caller.hasAvailableExternallyLinkage();
435449

450+
if (UseIR2Vec) {
451+
// Python side expects float embeddings. The IR2Vec embeddings are doubles
452+
// as of now due to the restriction of fromJSON method used by the
453+
// readVocabulary method in ir2vec::Embeddings.
454+
auto setEmbedding = [&](const ir2vec::Embedding &Embedding,
455+
FeatureIndex Index) {
456+
auto Embedding_float =
457+
std::vector<float>(Embedding.begin(), Embedding.end());
458+
std::memcpy(ModelRunner->getTensor<float>(Index), Embedding_float.data(),
459+
Embedding.size() * sizeof(float));
460+
};
461+
462+
setEmbedding(CalleeBefore.getFunctionEmbedding(),
463+
FeatureIndex::callee_embedding);
464+
setEmbedding(CallerBefore.getFunctionEmbedding(),
465+
FeatureIndex::caller_embedding);
466+
}
467+
436468
// Add the cost features
437469
for (size_t I = 0;
438470
I < static_cast<size_t>(InlineCostFeatureIndex::NumberOfFeatures); ++I) {

0 commit comments

Comments
 (0)