Skip to content

Refactor inliner to go over the blocks in topo sort order. #7890

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: ilya/inlined_blocks
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
333 changes: 150 additions & 183 deletions crates/cairo-lang-lowering/src/inline/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ mod test;

pub mod statements_weights;

use std::collections::{HashMap, VecDeque};
use std::collections::HashMap;

use cairo_lang_defs::diagnostic_utils::StableLocation;
use cairo_lang_defs::ids::LanguageElementId;
Expand All @@ -12,9 +12,9 @@ use cairo_lang_semantic::items::functions::InlineConfiguration;
use cairo_lang_utils::LookupIntern;
use cairo_lang_utils::casts::IntoOrPanic;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use itertools::zip_eq;
use itertools::{Itertools, zip_eq};

use crate::blocks::{Blocks, BlocksBuilder};
use crate::blocks::BlocksBuilder;
use crate::db::LoweringGroup;
use crate::diagnostic::{
LoweringDiagnostic, LoweringDiagnosticKind, LoweringDiagnostics, LoweringDiagnosticsBuilder,
Expand Down Expand Up @@ -89,70 +89,49 @@ fn should_inline_lowered(
let weight_of_blocks = db.estimate_size(function_id)?;
Ok(weight_of_blocks < inline_small_functions_threshold.into_or_panic())
}

// TODO(ilya): Add Rewriter trait.

/// A rewriter that inlines functions annotated with #[inline(always)].
pub struct FunctionInlinerRewriter<'db> {
/// The variable allocator used while building the new blocks.
variables: VariableAllocator<'db>,
/// A queue of blocks on which we want to apply the FunctionInlinerRewriter.
block_queue: BlockRewriteQueue,
/// The rewritten statements of the current block.
statements: Vec<Statement>,
/// The end of the current block.
block_end: BlockEnd,

/// Indicates whether the current block should be finalized or added to the block_queue.
///
/// When we split a block for inlining, the beginning of the block should be finalized to keep
/// the block_id, but the following blocks are created through the queue to avoid shifting
/// the ids of the blocks in the queue.
finalize: bool,

/// The statements of the current block that were not processed yet.
unprocessed_statements: <Vec<Statement> as IntoIterator>::IntoIter,
/// Indicates that the inlining process was successful.
inlining_success: Maybe<()>,
/// The id of the function calling the possibly inlined functions.
calling_function_id: ConcreteFunctionWithBodyId,
}

/// Holds the blocks that are pending a rewrite, alongside the builder of the finalized blocks.
pub struct BlockRewriteQueue {
/// A queue of blocks that require processing.
block_queue: VecDeque<Block>,
/// The builder collecting the blocks that were finalized.
blocks: BlocksBuilder,
}
impl BlockRewriteQueue {
/// Enqueues the block for processing by pushing it to the back of the queue.
fn enqueue_block(&mut self, block: Block) {
self.block_queue.push_back(block);
}
/// Pops the block at the front of the queue, or returns `None` if the queue is empty.
fn dequeue(&mut self) -> Option<Block> {
self.block_queue.pop_front()
}
/// Finalizes a block by allocating it in `self.blocks`.
fn finalize(&mut self, block: Block) {
self.blocks.alloc(block);
}
}

/// Context for mapping ids from `lowered` to a new `Lowered` object.
pub struct Mapper<'a, 'b> {
variables: &'a mut VariableAllocator<'b>,
lowered: &'a Lowered,
renamed_vars: HashMap<VariableId, VariableId>,
return_block_id: BlockId,
outputs: &'a [id_arena::Id<crate::Variable>],

outputs: Vec<VariableId>,
inlining_location: StableLocation,

/// An offset that is added to all the block IDs in order to translate them into the new
/// lowering representation.
block_id_offset: BlockId,

// Return statements are replaced with goto to this block with the appropriate remapping.
return_block_id: BlockId,
}

impl<'a, 'b> Mapper<'a, 'b> {
/// Creates a mapper for inlining `lowered` at the call site `call_stmt`.
///
/// `block_id_offset` is stored as the offset for the block ids of `lowered`; the return block
/// id is set to `block_id_offset + lowered.blocks.len()`, i.e. right after the inlined blocks.
///
/// Note: `zip_eq` panics if the number of parameters of `lowered` differs from the number of
/// inputs of `call_stmt`.
pub fn new(
variables: &'a mut VariableAllocator<'b>,
lowered: &'a Lowered,
call_stmt: &StatementCall,
block_id_offset: usize,
) -> Self {
// The input variables need to be renamed to match the inputs to the function call.
let renamed_vars = HashMap::<VariableId, VariableId>::from_iter(zip_eq(
lowered.parameters.iter().cloned(),
call_stmt.inputs.iter().map(|var_usage| var_usage.var_id),
));

let db = variables.db;
let inlining_location = call_stmt.location.lookup_intern(db).stable_location;

Self {
variables,
lowered,
renamed_vars,
block_id_offset: BlockId(block_id_offset),
return_block_id: BlockId(block_id_offset + lowered.blocks.len()),
outputs: call_stmt.outputs.clone(),
inlining_location,
}
}
}

impl Rebuilder for Mapper<'_, '_> {
Expand Down Expand Up @@ -198,146 +177,134 @@ impl Rebuilder for Mapper<'_, '_> {
}
}

impl<'db> FunctionInlinerRewriter<'db> {
fn apply(
variables: VariableAllocator<'db>,
lowered: &Lowered,
calling_function_id: ConcreteFunctionWithBodyId,
) -> Maybe<Lowered> {
let mut rewriter = Self {
variables,
block_queue: BlockRewriteQueue {
block_queue: lowered.blocks.iter().map(|(_, b)| b.clone()).collect(),
blocks: BlocksBuilder::new(),
},
statements: vec![],
block_end: BlockEnd::NotSet,
unprocessed_statements: Default::default(),
inlining_success: lowered.blocks.has_root(),
calling_function_id,
finalize: true,
};

rewriter.variables.variables = lowered.variables.clone();
while let Some(block) = rewriter.block_queue.dequeue() {
rewriter.block_end = block.end;
rewriter.unprocessed_statements = block.statements.into_iter();

while let Some(statement) = rewriter.unprocessed_statements.next() {
rewriter.rewrite(statement)?;
fn inner_apply_inlining(
mut variables: VariableAllocator<'_>,
lowered: &Lowered,
calling_function_id: ConcreteFunctionWithBodyId,
) -> Maybe<Lowered> {
lowered.blocks.has_root()?;

variables.variables = lowered.variables.clone();
let mut blocks = BlocksBuilder::new();

let mut stack: Vec<std::vec::IntoIter<BlockId>> = vec![
lowered
.blocks
.iter()
.map(|(_, block)| blocks.alloc(block.clone()))
.collect_vec()
.into_iter(),
];

// Used to keep track of the next block_id while `block` is borrowed.
let mut next_block_id = blocks.len();

while let Some(mut func_blocks) = stack.pop() {
for block_id in func_blocks.by_ref() {
let block = blocks.get_mut_block(block_id);

let mut opt_inline_info = None;

for (idx, statement) in block.statements.iter().enumerate() {
if let Some((call_stmt, called_func)) =
should_inline(variables.db, calling_function_id, statement)?
{
opt_inline_info = Some((idx, call_stmt, called_func));
break;
}
}

let new_block = Block {
statements: std::mem::take(&mut rewriter.statements),
end: rewriter.block_end,
let Some((call_stmt_idx, call_stmt, called_func)) = opt_inline_info else {
// Nothing to inline in this block, go to the next block.
continue;
};

if !rewriter.finalize {
rewriter.block_queue.enqueue_block(new_block);
rewriter.finalize = true;
} else {
rewriter.block_queue.finalize(new_block);
}
}
let inlined_lowered =
variables.db.lowered_body(called_func, LoweringStage::PostBaseline)?;
inlined_lowered.blocks.has_root()?;

let blocks = rewriter
.inlining_success
.map(|()| rewriter.block_queue.blocks.build().unwrap())
.unwrap_or_else(Blocks::new_errored);

Ok(Lowered {
diagnostics: lowered.diagnostics.clone(),
variables: rewriter.variables.variables,
blocks,
parameters: lowered.parameters.clone(),
signature: lowered.signature.clone(),
})
}
let mut inline_mapper =
Mapper::new(&mut variables, &inlined_lowered, call_stmt, next_block_id);

/// Rewrites a statement and either appends it to self.statements or adds new statements to
/// self.statements_rewrite_stack.
fn rewrite(&mut self, statement: Statement) -> Maybe<()> {
if let Statement::Call(ref stmt) = statement {
if !stmt.with_coupon {
if let Some(called_func) = stmt.function.body(self.variables.db)? {
if let crate::ids::ConcreteFunctionWithBodyLongId::Specialized(specialized) =
self.calling_function_id.lookup_intern(self.variables.db)
{
if specialized.base == called_func {
// A specialized function should always inline its base.
return self.inline_function(called_func, stmt);
}
}

// TODO: Implement better logic to avoid inlining of destructors that call
// themselves.
if called_func != self.calling_function_id
&& self.variables.db.priv_should_inline(called_func)?
{
return self.inline_function(called_func, stmt);
}
}
}
}
// drain the statements starting at the call to the inlined function and replace the end
// of the block with a goto to the root block of the inlined function.
let remaining_statements =
block.statements.drain(call_stmt_idx..).skip(1).collect_vec();
let orig_block_end = std::mem::replace(
&mut block.end,
BlockEnd::Goto(inline_mapper.block_id_offset, VarRemapping::default()),
);

self.statements.push(statement);
Ok(())
// Apply the mapper to the inlined blocks and add them as a contiguous chunk to the
// blocks builder.
let mut inlined_blocks_ids = inlined_lowered
.blocks
.iter()
.map(|(_block_id, block)| blocks.alloc(inline_mapper.rebuild_block(block)))
.collect_vec();

// Move the remaining statements and the original block end to a new return block, that
// is right after the inlined blocks.
let return_block_id =
blocks.alloc(Block { statements: remaining_statements, end: orig_block_end });

// Append the id of the return block to the list of blocks in the inlined function.
// It is not part of that function, but we want to visit it right after the inlined
// function blocks.
inlined_blocks_ids.push(return_block_id);

// Update the `next_block_id`
next_block_id = blocks.len();

// Return the remaining blocks from the current function to the stack and add the blocks
// of the inlined function to the top of the stack.
stack.push(func_blocks);
stack.push(inlined_blocks_ids.into_iter());

break;
}
}

/// Inlines the given function call.
pub fn inline_function(
&mut self,
function_id: ConcreteFunctionWithBodyId,
call_stmt: &StatementCall,
) -> Maybe<()> {
let lowered = self.variables.db.lowered_body(function_id, LoweringStage::PostBaseline)?;
lowered.blocks.has_root()?;

// As the block_ids and variable_ids are per function, we need to rename all
// the blocks and variables before we enqueue the blocks from the function that
// we are inlining.

// The input variables need to be renamed to match the inputs to the function call.
let renamed_vars = HashMap::<VariableId, VariableId>::from_iter(zip_eq(
lowered.parameters.iter().cloned(),
call_stmt.inputs.iter().map(|var_usage| var_usage.var_id),
));

let db = self.variables.db;
let inlining_location = call_stmt.location.lookup_intern(db).stable_location;
Ok(Lowered {
diagnostics: lowered.diagnostics.clone(),
variables: variables.variables,
blocks: blocks.build().unwrap(),
parameters: lowered.parameters.clone(),
signature: lowered.signature.clone(),
})
}

// The block_id_offset is the id of the first block in the new function, there is a `+1`
// because of the `new_block` below.
let block_id_offset =
self.block_queue.blocks.len() + self.block_queue.block_queue.len() + 1;
let new_block = Block {
statements: std::mem::take(&mut self.statements),
end: BlockEnd::Goto(BlockId(block_id_offset), VarRemapping::default()),
};
if self.finalize {
self.block_queue.finalize(new_block);
self.finalize = false;
} else {
self.block_queue.enqueue_block(new_block);
/// Checks whether `statement` is a function call that should be inlined into
/// `calling_function_id`, returning the call statement and the called function if so.
fn should_inline<'a>(
db: &dyn LoweringGroup,
calling_function_id: ConcreteFunctionWithBodyId,
statement: &'a Statement,
) -> Maybe<Option<(&'a StatementCall, ConcreteFunctionWithBodyId)>> {
if let Statement::Call(stmt) = statement {
if stmt.with_coupon {
return Ok(None);
}

let mut mapper = Mapper {
variables: &mut self.variables,
lowered: &lowered,
renamed_vars,
block_id_offset: BlockId(block_id_offset),
return_block_id: BlockId(block_id_offset + lowered.blocks.len()),
outputs: &call_stmt.outputs,
inlining_location,
};
if let Some(called_func) = stmt.function.body(db)? {
if let crate::ids::ConcreteFunctionWithBodyLongId::Specialized(specialized) =
calling_function_id.lookup_intern(db)
{
if specialized.base == called_func {
// A specialized function should always inline its base.
return Ok(Some((stmt, called_func)));
}
}

for (_, block) in lowered.blocks.iter() {
let block = mapper.rebuild_block(block);
self.block_queue.enqueue_block(block);
// TODO: Implement better logic to avoid inlining of destructors that call
// themselves.
if called_func != calling_function_id && db.priv_should_inline(called_func)? {
return Ok(Some((stmt, called_func)));
}
}

Ok(())
}

Ok(None)
}

pub fn apply_inlining(
Expand All @@ -350,7 +317,7 @@ pub fn apply_inlining(
function_id.base_semantic_function(db).function_with_body_id(db),
lowered.variables.clone(),
)?;
if let Ok(new_lowered) = FunctionInlinerRewriter::apply(variables, lowered, function_id) {
if let Ok(new_lowered) = inner_apply_inlining(variables, lowered, function_id) {
*lowered = new_lowered;
}
Ok(())
Expand Down
Loading