Skip to content

Commit 55db1e0

Browse files
authored
Unrolled build for rust-lang#139700
Rollup merge of rust-lang#139700 - EnzymeAD:autodiff-flags, r=oli-obk Autodiff flags Interestingly, it seems that some other projects have conflicts with exactly the same LLVM optimization passes as autodiff. At least `LLVMRustOptimize` has exactly the flags that we need to disable problematic opt passes. This PR enables us to compile code where users differentiate two identical functions in the same module. This has been especially common in test cases, but it's not impossible to encounter in the wild. It also enables two new flags for testing/debugging. I consider writing an MCP to upgrade PrintPasses to be a standalone -Z flag, since it is *not* the same as `-Z print-llvm-passes`, which IMHO gives less useful output. A discussion can be found here: [#t-compiler/llvm > Print llvm passes. @ 💬](https://rust-lang.zulipchat.com/#narrow/channel/187780-t-compiler.2Fllvm/topic/Print.20llvm.20passes.2E/near/511533038) Finally, it improves `PrintModBefore` and `PrintModAfter`. They used to work reliable, but now we just schedule enzyme as part of an existing ModulePassManager (MPM). Since Enzyme is last in the MPM scheduling, PrintModBefore became very inaccurate. It used to print the input module, which we gave to the Enzyme and was great to create llvm-ir reproducer. However, lately the MPM would run the whole `default<O3>` pipeline, which heavily modifies the llvm module, before we pass it to Enzyme. That made it impossible to use the flag to create llvm-ir reproducers for Enzyme bugs. We now schedule a PrintModule pass just before Enzyme, solving this problem. Based on the PrintPass output, it also _seems_ like changing `registerEnzymeAndPassPipeline(PB, true);` to `registerEnzymeAndPassPipeline(PB, false);` has no effect. In theory, the bool should tell Enzyme to schedule some helpful passes in the PassBuilder. However, since it doesn't do anything and I'm not 100% sure anymore on whether we really need it, I'll just disable it for now and postpone investigations. r? ``@oli-obk`` closes rust-lang#139471 Tracking: - rust-lang#124509
2 parents 3c877f6 + f79a992 commit 55db1e0

File tree

8 files changed

+123
-23
lines changed

8 files changed

+123
-23
lines changed

compiler/rustc_codegen_llvm/src/back/lto.rs

+22-18
Original file line numberDiff line numberDiff line change
@@ -584,12 +584,10 @@ fn thin_lto(
584584
}
585585
}
586586

587-
fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<ModuleLlvm>) {
587+
fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
588588
for &val in ad {
589+
// We intentionally don't use a wildcard, to not forget handling anything new.
589590
match val {
590-
config::AutoDiff::PrintModBefore => {
591-
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
592-
}
593591
config::AutoDiff::PrintPerf => {
594592
llvm::set_print_perf(true);
595593
}
@@ -603,17 +601,23 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<
603601
llvm::set_inline(true);
604602
}
605603
config::AutoDiff::LooseTypes => {
606-
llvm::set_loose_types(false);
604+
llvm::set_loose_types(true);
607605
}
608606
config::AutoDiff::PrintSteps => {
609607
llvm::set_print(true);
610608
}
611-
// We handle this below
609+
// We handle this in the PassWrapper.cpp
610+
config::AutoDiff::PrintPasses => {}
611+
// We handle this in the PassWrapper.cpp
612+
config::AutoDiff::PrintModBefore => {}
613+
// We handle this in the PassWrapper.cpp
612614
config::AutoDiff::PrintModAfter => {}
613-
// We handle this below
615+
// We handle this in the PassWrapper.cpp
614616
config::AutoDiff::PrintModFinal => {}
615617
// This is required and already checked
616618
config::AutoDiff::Enable => {}
619+
// We handle this below
620+
config::AutoDiff::NoPostopt => {}
617621
}
618622
}
619623
// This helps with handling enums for now.
@@ -647,27 +651,27 @@ pub(crate) fn run_pass_manager(
647651
// We then run the llvm_optimize function a second time, to optimize the code which we generated
648652
// in the enzyme differentiation pass.
649653
let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable);
650-
let stage =
651-
if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD };
654+
let stage = if thin {
655+
write::AutodiffStage::PreAD
656+
} else {
657+
if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD }
658+
};
652659

653660
if enable_ad {
654-
enable_autodiff_settings(&config.autodiff, module);
661+
enable_autodiff_settings(&config.autodiff);
655662
}
656663

657664
unsafe {
658665
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
659666
}
660667

661-
if cfg!(llvm_enzyme) && enable_ad {
662-
// This is the post-autodiff IR, mainly used for testing and educational purposes.
663-
if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
664-
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
665-
}
666-
668+
if cfg!(llvm_enzyme) && enable_ad && !thin {
667669
let opt_stage = llvm::OptStage::FatLTO;
668670
let stage = write::AutodiffStage::PostAD;
669-
unsafe {
670-
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
671+
if !config.autodiff.contains(&config::AutoDiff::NoPostopt) {
672+
unsafe {
673+
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
674+
}
671675
}
672676

673677
// This is the final IR, so people should be able to inspect the optimized autodiff output,

compiler/rustc_codegen_llvm/src/back/write.rs

+15-1
Original file line numberDiff line numberDiff line change
@@ -572,20 +572,31 @@ pub(crate) unsafe fn llvm_optimize(
572572

573573
let consider_ad = cfg!(llvm_enzyme) && config.autodiff.contains(&config::AutoDiff::Enable);
574574
let run_enzyme = autodiff_stage == AutodiffStage::DuringAD;
575+
let print_before_enzyme = config.autodiff.contains(&config::AutoDiff::PrintModBefore);
576+
let print_after_enzyme = config.autodiff.contains(&config::AutoDiff::PrintModAfter);
577+
let print_passes = config.autodiff.contains(&config::AutoDiff::PrintPasses);
578+
let merge_functions;
575579
let unroll_loops;
576580
let vectorize_slp;
577581
let vectorize_loop;
578582

579583
// When we build rustc with enzyme/autodiff support, we want to postpone size-increasing
580584
// optimizations until after differentiation. Our pipeline is thus: (opt + enzyme), (full opt).
581585
// We therefore have two calls to llvm_optimize, if autodiff is used.
586+
//
587+
// We also must disable merge_functions, since autodiff placeholder/dummy bodies tend to be
588+
// identical. We run opts before AD, so there is a chance that LLVM will merge our dummies.
589+
// In that case, we lack some dummy bodies and can't replace them with the real AD code anymore.
590+
// We then would need to abort compilation. This was especially common in test cases.
582591
if consider_ad && autodiff_stage != AutodiffStage::PostAD {
592+
merge_functions = false;
583593
unroll_loops = false;
584594
vectorize_slp = false;
585595
vectorize_loop = false;
586596
} else {
587597
unroll_loops =
588598
opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin;
599+
merge_functions = config.merge_functions;
589600
vectorize_slp = config.vectorize_slp;
590601
vectorize_loop = config.vectorize_loop;
591602
}
@@ -663,13 +674,16 @@ pub(crate) unsafe fn llvm_optimize(
663674
thin_lto_buffer,
664675
config.emit_thin_lto,
665676
config.emit_thin_lto_summary,
666-
config.merge_functions,
677+
merge_functions,
667678
unroll_loops,
668679
vectorize_slp,
669680
vectorize_loop,
670681
config.no_builtins,
671682
config.emit_lifetime_markers,
672683
run_enzyme,
684+
print_before_enzyme,
685+
print_after_enzyme,
686+
print_passes,
673687
sanitizer_options.as_ref(),
674688
pgo_gen_path.as_ref().map_or(std::ptr::null(), |s| s.as_ptr()),
675689
pgo_use_path.as_ref().map_or(std::ptr::null(), |s| s.as_ptr()),

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@ pub(crate) fn differentiate<'ll>(
473473
return Err(diag_handler.handle().emit_almost_fatal(AutoDiffWithoutEnable));
474474
}
475475

476-
// Before dumping the module, we want all the TypeTrees to become part of the module.
476+
// Here we replace the placeholder code with the actual autodiff code, which calls Enzyme.
477477
for item in diff_items.iter() {
478478
let name = item.source.clone();
479479
let fn_def: Option<&llvm::Value> = cx.get_function(&name);

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

+3
Original file line numberDiff line numberDiff line change
@@ -2454,6 +2454,9 @@ unsafe extern "C" {
24542454
DisableSimplifyLibCalls: bool,
24552455
EmitLifetimeMarkers: bool,
24562456
RunEnzyme: bool,
2457+
PrintBeforeEnzyme: bool,
2458+
PrintAfterEnzyme: bool,
2459+
PrintPasses: bool,
24572460
SanitizerOptions: Option<&SanitizerOptions>,
24582461
PGOGenPath: *const c_char,
24592462
PGOUsePath: *const c_char,

compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp

+28-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "llvm/IR/LegacyPassManager.h"
1515
#include "llvm/IR/PassManager.h"
1616
#include "llvm/IR/Verifier.h"
17+
#include "llvm/IRPrinter/IRPrintingPasses.h"
1718
#include "llvm/LTO/LTO.h"
1819
#include "llvm/MC/MCSubtargetInfo.h"
1920
#include "llvm/MC/TargetRegistry.h"
@@ -703,7 +704,8 @@ extern "C" LLVMRustResult LLVMRustOptimize(
703704
bool LintIR, LLVMRustThinLTOBuffer **ThinLTOBufferRef, bool EmitThinLTO,
704705
bool EmitThinLTOSummary, bool MergeFunctions, bool UnrollLoops,
705706
bool SLPVectorize, bool LoopVectorize, bool DisableSimplifyLibCalls,
706-
bool EmitLifetimeMarkers, bool RunEnzyme,
707+
bool EmitLifetimeMarkers, bool RunEnzyme, bool PrintBeforeEnzyme,
708+
bool PrintAfterEnzyme, bool PrintPasses,
707709
LLVMRustSanitizerOptions *SanitizerOptions, const char *PGOGenPath,
708710
const char *PGOUsePath, bool InstrumentCoverage,
709711
const char *InstrProfileOutput, const char *PGOSampleUsePath,
@@ -1048,14 +1050,38 @@ extern "C" LLVMRustResult LLVMRustOptimize(
10481050
// now load "-enzyme" pass:
10491051
#ifdef ENZYME
10501052
if (RunEnzyme) {
1051-
registerEnzymeAndPassPipeline(PB, true);
1053+
1054+
if (PrintBeforeEnzyme) {
1055+
// Handle the Rust flag `-Zautodiff=PrintModBefore`.
1056+
std::string Banner = "Module before EnzymeNewPM";
1057+
MPM.addPass(PrintModulePass(outs(), Banner, true, false));
1058+
}
1059+
1060+
registerEnzymeAndPassPipeline(PB, false);
10521061
if (auto Err = PB.parsePassPipeline(MPM, "enzyme")) {
10531062
std::string ErrMsg = toString(std::move(Err));
10541063
LLVMRustSetLastError(ErrMsg.c_str());
10551064
return LLVMRustResult::Failure;
10561065
}
1066+
1067+
if (PrintAfterEnzyme) {
1068+
// Handle the Rust flag `-Zautodiff=PrintModAfter`.
1069+
std::string Banner = "Module after EnzymeNewPM";
1070+
MPM.addPass(PrintModulePass(outs(), Banner, true, false));
1071+
}
10571072
}
10581073
#endif
1074+
if (PrintPasses) {
1075+
// Print all passes from the PM:
1076+
std::string Pipeline;
1077+
raw_string_ostream SOS(Pipeline);
1078+
MPM.printPipeline(SOS, [&PIC](StringRef ClassName) {
1079+
auto PassName = PIC.getPassNameForClassName(ClassName);
1080+
return PassName.empty() ? ClassName : PassName;
1081+
});
1082+
outs() << Pipeline;
1083+
outs() << "\n";
1084+
}
10591085

10601086
// Upgrade all calls to old intrinsics first.
10611087
for (Module::iterator I = TheModule->begin(), E = TheModule->end(); I != E;)

compiler/rustc_session/src/config.rs

+4
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,10 @@ pub enum AutoDiff {
246246
/// Print the module after running autodiff and optimizations.
247247
PrintModFinal,
248248

249+
/// Print all passes scheduled by LLVM
250+
PrintPasses,
251+
/// Disable extra opt run after running autodiff
252+
NoPostopt,
249253
/// Enzyme's loose type debug helper (can cause incorrect gradients!!)
250254
/// Usable in cases where Enzyme errors with `can not deduce type of X`.
251255
LooseTypes,

compiler/rustc_session/src/options.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -711,7 +711,7 @@ mod desc {
711711
pub(crate) const parse_list: &str = "a space-separated list of strings";
712712
pub(crate) const parse_list_with_polarity: &str =
713713
"a comma-separated list of strings, with elements beginning with + or -";
714-
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `LooseTypes`, `Inline`";
714+
pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`";
715715
pub(crate) const parse_comma_list: &str = "a comma-separated list of strings";
716716
pub(crate) const parse_opt_comma_list: &str = parse_comma_list;
717717
pub(crate) const parse_number: &str = "a number";
@@ -1360,6 +1360,8 @@ pub mod parse {
13601360
"PrintModBefore" => AutoDiff::PrintModBefore,
13611361
"PrintModAfter" => AutoDiff::PrintModAfter,
13621362
"PrintModFinal" => AutoDiff::PrintModFinal,
1363+
"NoPostopt" => AutoDiff::NoPostopt,
1364+
"PrintPasses" => AutoDiff::PrintPasses,
13631365
"LooseTypes" => AutoDiff::LooseTypes,
13641366
"Inline" => AutoDiff::Inline,
13651367
_ => {
@@ -2098,6 +2100,8 @@ options! {
20982100
`=PrintModBefore`
20992101
`=PrintModAfter`
21002102
`=PrintModFinal`
2103+
`=PrintPasses`,
2104+
`=NoPostopt`
21012105
`=LooseTypes`
21022106
`=Inline`
21032107
Multiple options can be combined with commas."),
+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
2+
//@ no-prefer-dynamic
3+
//@ needs-enzyme
4+
//
5+
// Each autodiff invocation creates a new placeholder function, which we will replace on llvm-ir
6+
// level. If a user tries to differentiate two identical functions within the same compilation unit,
7+
// then LLVM might merge them in release mode before AD. In that case we can't rewrite one of the
8+
// merged placeholder function anymore, and compilation would fail. We prevent this by disabling
9+
// LLVM's merge_function pass before AD. Here we implicetely test that our solution keeps working.
10+
// We also explicetly test that we keep running merge_function after AD, by checking for two
11+
// identical function calls in the LLVM-IR, while having two different calls in the Rust code.
12+
#![feature(autodiff)]
13+
14+
use std::autodiff::autodiff;
15+
16+
#[autodiff(d_square, Reverse, Duplicated, Active)]
17+
fn square(x: &f64) -> f64 {
18+
x * x
19+
}
20+
21+
#[autodiff(d_square2, Reverse, Duplicated, Active)]
22+
fn square2(x: &f64) -> f64 {
23+
x * x
24+
}
25+
26+
// CHECK:; identical_fnc::main
27+
// CHECK-NEXT:; Function Attrs:
28+
// CHECK-NEXT:define internal void @_ZN13identical_fnc4main17hf4dbc69c8d2f9130E()
29+
// CHECK-NEXT:start:
30+
// CHECK-NOT:br
31+
// CHECK-NOT:ret
32+
// CHECK:; call identical_fnc::d_square
33+
// CHECK-NEXT: call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx1)
34+
// CHECK-NEXT:; call identical_fnc::d_square
35+
// CHECK-NEXT: call fastcc void @_ZN13identical_fnc8d_square17h4c364207a2f8e06dE(double %x.val, ptr noalias noundef nonnull align 8 dereferenceable(8) %dx2)
36+
37+
fn main() {
38+
let x = std::hint::black_box(3.0);
39+
let mut dx1 = std::hint::black_box(1.0);
40+
let mut dx2 = std::hint::black_box(1.0);
41+
let _ = d_square(&x, &mut dx1, 1.0);
42+
let _ = d_square2(&x, &mut dx2, 1.0);
43+
assert_eq!(dx1, 6.0);
44+
assert_eq!(dx2, 6.0);
45+
}

0 commit comments

Comments
 (0)