Skip to content

Commit 248dbbb

Browse files
authored
Fix insert choice (rust-lang#988)
* fix insert choice * fix tests
1 parent 23d6ecc commit 248dbbb

File tree

7 files changed

+55
-88
lines changed

7 files changed

+55
-88
lines changed

enzyme/Enzyme/TraceUtils.h

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -489,36 +489,35 @@ class TraceUtils {
489489

490490
CallInst *InsertChoice(IRBuilder<> &Builder, Value *address, Value *score,
491491
Value *choice) {
492-
auto size = choice->getType()->getPrimitiveSizeInBits() / 8;
493492
Type *size_type = interface->getChoiceTy()->getParamType(3);
494-
495-
auto M = interface->getSampleFunction()->getParent();
496-
auto &DL = M->getDataLayout();
497-
auto pointersize = DL.getPointerSizeInBits();
493+
auto choicesize = choice->getType()->getPrimitiveSizeInBits();
498494

499495
Value *retval;
500496
if (choice->getType()->isPointerTy()) {
501497
retval = Builder.CreatePointerCast(choice, Builder.getInt8PtrTy());
502498
} else {
503-
IRBuilder<> AllocaBuilder(
504-
newFunc->getEntryBlock().getFirstNonPHIOrDbgOrLifetime());
505-
auto alloca = AllocaBuilder.CreateAlloca(choice->getType(), nullptr,
506-
choice->getName() + ".ptr");
507-
Builder.CreateStore(choice, alloca);
508-
bool fitsInPointer =
509-
choice->getType()->getPrimitiveSizeInBits() == pointersize;
510-
if (fitsInPointer) {
511-
auto dblptr =
512-
PointerType::get(Builder.getInt8PtrTy(), DL.getAllocaAddrSpace());
513-
retval = Builder.CreateLoad(Builder.getInt8PtrTy(),
514-
Builder.CreatePointerCast(alloca, dblptr));
499+
auto M = interface->getSampleFunction()->getParent();
500+
auto &DL = M->getDataLayout();
501+
auto pointersize = DL.getPointerSizeInBits();
502+
if (choicesize <= pointersize) {
503+
auto cast = Builder.CreateBitCast(
504+
choice, IntegerType::get(M->getContext(), choicesize));
505+
cast = choicesize == pointersize
506+
? cast
507+
: Builder.CreateZExt(cast, Builder.getIntPtrTy(DL));
508+
retval = Builder.CreateIntToPtr(cast, Builder.getInt8PtrTy());
515509
} else {
510+
IRBuilder<> AllocaBuilder(
511+
newFunc->getEntryBlock().getFirstNonPHIOrDbgOrLifetime());
512+
auto alloca = AllocaBuilder.CreateAlloca(choice->getType(), nullptr,
513+
choice->getName() + ".ptr");
514+
Builder.CreateStore(choice, alloca);
516515
retval = alloca;
517516
}
518517
}
519518

520519
Value *args[] = {trace, address, score, retval,
521-
ConstantInt::get(size_type, size)};
520+
ConstantInt::get(size_type, choicesize / 8)};
522521

523522
auto call = Builder.CreateCall(interface->insertChoiceTy(),
524523
interface->insertChoice(), args);

enzyme/test/Enzyme/ProbProg/condition-dynamic.ll

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,10 @@ entry:
8686
; CHECK-NEXT: %4 = getelementptr inbounds i8*, i8** %interface, i32 6
8787
; CHECK-NEXT: %5 = load i8*, i8** %4
8888
; CHECK-NEXT: %has_call = bitcast i8* %5 to i1 (i8*, i8*)*
89-
; CHECK-NEXT: %call1.ptr3 = alloca double
9089
; CHECK-NEXT: %call1.ptr = alloca double
9190
; CHECK-NEXT: %6 = getelementptr inbounds i8*, i8** %interface, i32 3
9291
; CHECK-NEXT: %7 = load i8*, i8** %6
9392
; CHECK-NEXT: %insert_choice = bitcast i8* %7 to void (i8*, i8*, double, i8*, i64)*
94-
; CHECK-NEXT: %call.ptr2 = alloca double
9593
; CHECK-NEXT: %8 = getelementptr inbounds i8*, i8** %interface, i32 1
9694
; CHECK-NEXT: %9 = load i8*, i8** %8
9795
; CHECK-NEXT: %get_choice = bitcast i8* %9 to i64 (i8*, i8*, i8*, i64)*
@@ -119,9 +117,8 @@ entry:
119117
; CHECK: entry.cntd: ; preds = %condition.call.without.trace, %condition.call.with.trace
120118
; CHECK-NEXT: %call = phi double [ %from.trace.call, %condition.call.with.trace ], [ %sample.call, %condition.call.without.trace ]
121119
; CHECK-NEXT: %likelihood.call = call double @normal_logpdf(double 0.000000e+00, double 1.000000e+00, double %call)
122-
; CHECK-NEXT: store double %call, double* %call.ptr2
123-
; CHECK-NEXT: %15 = bitcast double* %call.ptr2 to i8**
124-
; CHECK-NEXT: %16 = load i8*, i8** %15
120+
; CHECK-NEXT: %15 = bitcast double %call to i64
121+
; CHECK-NEXT: %16 = inttoptr i64 %15 to i8*
125122
; CHECK-NEXT: call void %insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.1, i64 0, i64 0), double %likelihood.call, i8* %16, i64 8)
126123
; CHECK-NEXT: %has.choice.call1 = call i1 %has_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.2, i64 0, i64 0))
127124
; CHECK-NEXT: br i1 %has.choice.call1, label %condition.call1.with.trace, label %condition.call1.without.trace
@@ -139,9 +136,8 @@ entry:
139136
; CHECK: entry.cntd.cntd: ; preds = %condition.call1.without.trace, %condition.call1.with.trace
140137
; CHECK-NEXT: %call1 = phi double [ %from.trace.call1, %condition.call1.with.trace ], [ %sample.call1, %condition.call1.without.trace ]
141138
; CHECK-NEXT: %likelihood.call1 = call double @normal_logpdf(double 0.000000e+00, double 1.000000e+00, double %call1)
142-
; CHECK-NEXT: store double %call1, double* %call1.ptr3
143-
; CHECK-NEXT: %18 = bitcast double* %call1.ptr3 to i8**
144-
; CHECK-NEXT: %19 = load i8*, i8** %18
139+
; CHECK-NEXT: %18 = bitcast double %call1 to i64
140+
; CHECK-NEXT: %19 = inttoptr i64 %18 to i8*
145141
; CHECK-NEXT: call void %insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.2, i64 0, i64 0), double %likelihood.call1, i8* %19, i64 8)
146142
; CHECK-NEXT: %has.call.call2 = call i1 %has_call(i8* %observations, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0))
147143
; CHECK-NEXT: br i1 %has.call.call2, label %condition.call2.with.trace, label %condition.call2.without.trace
@@ -156,9 +152,9 @@ entry:
156152
; CHECK-NEXT: br label %entry.cntd.cntd.cntd
157153

158154
; CHECK: entry.cntd.cntd.cntd: ; preds = %condition.call2.without.trace, %condition.call2.with.trace
159-
; CHECK-NEXT: %call24 = phi { double, i8* } [ %condition.calculate_loss, %condition.call2.with.trace ], [ %trace.calculate_loss, %condition.call2.without.trace ]
160-
; CHECK-NEXT: %call2 = extractvalue { double, i8* } %call24, 0
161-
; CHECK-NEXT: %newtrace.calculate_loss = extractvalue { double, i8* } %call24, 1
155+
; CHECK-NEXT: %call22 = phi { double, i8* } [ %condition.calculate_loss, %condition.call2.with.trace ], [ %trace.calculate_loss, %condition.call2.without.trace ]
156+
; CHECK-NEXT: %call2 = extractvalue { double, i8* } %call22, 0
157+
; CHECK-NEXT: %newtrace.calculate_loss = extractvalue { double, i8* } %call22, 1
162158
; CHECK-NEXT: call void %insert_call(i8* %trace, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0), i8* %newtrace.calculate_loss)
163159
; CHECK-NEXT: %mrv = insertvalue { double, i8* } {{(undef|poison)}}, double %call2, 0
164160
; CHECK-NEXT: %mrv1 = insertvalue { double, i8* } %mrv, i8* %trace, 1
@@ -171,7 +167,6 @@ entry:
171167
; CHECK-NEXT: %0 = getelementptr inbounds i8*, i8** %interface, i32 3
172168
; CHECK-NEXT: %1 = load i8*, i8** %0
173169
; CHECK-NEXT: %insert_choice = bitcast i8* %1 to void (i8*, i8*, double, i8*, i64)*
174-
; CHECK-NEXT: %call.ptr2 = alloca double
175170
; CHECK-NEXT: %2 = getelementptr inbounds i8*, i8** %interface, i32 1
176171
; CHECK-NEXT: %3 = load i8*, i8** %2
177172
; CHECK-NEXT: %get_choice = bitcast i8* %3 to i64 (i8*, i8*, i8*, i64)*
@@ -219,9 +214,8 @@ entry:
219214
; CHECK: for.body.cntd: ; preds = %condition.call.without.trace, %condition.call.with.trace
220215
; CHECK-NEXT: %call = phi double [ %from.trace.call, %condition.call.with.trace ], [ %sample.call, %condition.call.without.trace ]
221216
; CHECK-NEXT: %likelihood.call = call double @normal_logpdf(double %9, double 1.000000e+00, double %call)
222-
; CHECK-NEXT: store double %call, double* %call.ptr2
223-
; CHECK-NEXT: %11 = bitcast double* %call.ptr2 to i8**
224-
; CHECK-NEXT: %12 = load i8*, i8** %11
217+
; CHECK-NEXT: %11 = bitcast double %call to i64
218+
; CHECK-NEXT: %12 = inttoptr i64 %11 to i8*
225219
; CHECK-NEXT: call void %insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0), double %likelihood.call, i8* %12, i64 8)
226220
; CHECK-NEXT: %arrayidx3 = getelementptr inbounds double, double* %data, i64 %indvars.iv
227221
; CHECK-NEXT: %13 = load double, double* %arrayidx3

enzyme/test/Enzyme/ProbProg/condition-static.ll

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,7 @@ entry:
8080

8181
; CHECK: define internal { double, i8* } @condition_loss(double* %data, i32 %n, i8* %observations)
8282
; CHECK-NEXT: entry:
83-
; CHECK-NEXT: %call1.ptr3 = alloca double
8483
; CHECK-NEXT: %call1.ptr = alloca double
85-
; CHECK-NEXT: %call.ptr2 = alloca double
8684
; CHECK-NEXT: %call.ptr = alloca double
8785
; CHECK-NEXT: %trace = call i8* @__enzyme_newtrace()
8886
; CHECK-NEXT: %has.choice.call = call i1 @__enzyme_has_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.1, i64 0, i64 0))
@@ -101,9 +99,8 @@ entry:
10199
; CHECK: entry.cntd: ; preds = %condition.call.without.trace, %condition.call.with.trace
102100
; CHECK-NEXT: %call = phi double [ %from.trace.call, %condition.call.with.trace ], [ %sample.call, %condition.call.without.trace ]
103101
; CHECK-NEXT: %likelihood.call = call double @normal_logpdf(double 0.000000e+00, double 1.000000e+00, double %call)
104-
; CHECK-NEXT: store double %call, double* %call.ptr2
105-
; CHECK-NEXT: %1 = bitcast double* %call.ptr2 to i8**
106-
; CHECK-NEXT: %2 = load i8*, i8** %1
102+
; CHECK-NEXT: %1 = bitcast double %call to i64
103+
; CHECK-NEXT: %2 = inttoptr i64 %1 to i8*
107104
; CHECK-NEXT: call void @__enzyme_insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.1, i64 0, i64 0), double %likelihood.call, i8* %2, i64 8)
108105
; CHECK-NEXT: %has.choice.call1 = call i1 @__enzyme_has_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.2, i64 0, i64 0))
109106
; CHECK-NEXT: br i1 %has.choice.call1, label %condition.call1.with.trace, label %condition.call1.without.trace
@@ -121,9 +118,8 @@ entry:
121118
; CHECK: entry.cntd.cntd: ; preds = %condition.call1.without.trace, %condition.call1.with.trace
122119
; CHECK-NEXT: %call1 = phi double [ %from.trace.call1, %condition.call1.with.trace ], [ %sample.call1, %condition.call1.without.trace ]
123120
; CHECK-NEXT: %likelihood.call1 = call double @normal_logpdf(double 0.000000e+00, double 1.000000e+00, double %call1)
124-
; CHECK-NEXT: store double %call1, double* %call1.ptr3
125-
; CHECK-NEXT: %4 = bitcast double* %call1.ptr3 to i8**
126-
; CHECK-NEXT: %5 = load i8*, i8** %4
121+
; CHECK-NEXT: %4 = bitcast double %call1 to i64
122+
; CHECK-NEXT: %5 = inttoptr i64 %4 to i8*
127123
; CHECK-NEXT: call void @__enzyme_insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.2, i64 0, i64 0), double %likelihood.call1, i8* %5, i64 8)
128124
; CHECK-NEXT: %has.call.call2 = call i1 @__enzyme_has_call(i8* %observations, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0))
129125
; CHECK-NEXT: br i1 %has.call.call2, label %condition.call2.with.trace, label %condition.call2.without.trace
@@ -138,9 +134,9 @@ entry:
138134
; CHECK-NEXT: br label %entry.cntd.cntd.cntd
139135

140136
; CHECK: entry.cntd.cntd.cntd: ; preds = %condition.call2.without.trace, %condition.call2.with.trace
141-
; CHECK-NEXT: %call24 = phi { double, i8* } [ %condition.calculate_loss, %condition.call2.with.trace ], [ %trace.calculate_loss, %condition.call2.without.trace ]
142-
; CHECK-NEXT: %call2 = extractvalue { double, i8* } %call24, 0
143-
; CHECK-NEXT: %newtrace.calculate_loss = extractvalue { double, i8* } %call24, 1
137+
; CHECK-NEXT: %call22 = phi { double, i8* } [ %condition.calculate_loss, %condition.call2.with.trace ], [ %trace.calculate_loss, %condition.call2.without.trace ]
138+
; CHECK-NEXT: %call2 = extractvalue { double, i8* } %call22, 0
139+
; CHECK-NEXT: %newtrace.calculate_loss = extractvalue { double, i8* } %call22, 1
144140
; CHECK-NEXT: call void @__enzyme_insert_call(i8* %trace, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @0, i32 0, i32 0), i8* %newtrace.calculate_loss)
145141
; CHECK-NEXT: %mrv = insertvalue { double, i8* } {{(undef|poison)}}, double %call2, 0
146142
; CHECK-NEXT: %mrv1 = insertvalue { double, i8* } %mrv, i8* %trace, 1
@@ -150,7 +146,6 @@ entry:
150146

151147
; CHECK: define internal { double, i8* } @condition_calculate_loss(double %m, double %b, double* %data, i32 %n, i8* %observations)
152148
; CHECK-NEXT: entry:
153-
; CHECK-NEXT: %call.ptr2 = alloca double
154149
; CHECK-NEXT: %call.ptr = alloca double
155150
; CHECK-NEXT: %trace = call i8* @__enzyme_newtrace()
156151
; CHECK-NEXT: %cmp19 = icmp sgt i32 %n, 0
@@ -189,9 +184,8 @@ entry:
189184
; CHECK: for.body.cntd: ; preds = %condition.call.without.trace, %condition.call.with.trace
190185
; CHECK-NEXT: %call = phi double [ %from.trace.call, %condition.call.with.trace ], [ %sample.call, %condition.call.without.trace ]
191186
; CHECK-NEXT: %likelihood.call = call double @normal_logpdf(double %1, double 1.000000e+00, double %call)
192-
; CHECK-NEXT: store double %call, double* %call.ptr2
193-
; CHECK-NEXT: %3 = bitcast double* %call.ptr2 to i8**
194-
; CHECK-NEXT: %4 = load i8*, i8** %3
187+
; CHECK-NEXT: %3 = bitcast double %call to i64
188+
; CHECK-NEXT: %4 = inttoptr i64 %3 to i8*
195189
; CHECK-NEXT: call void @__enzyme_insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([11 x i8], [11 x i8]* @.str, i64 0, i64 0), double %likelihood.call, i8* %4, i64 8)
196190
; CHECK-NEXT: %arrayidx3 = getelementptr inbounds double, double* %data, i64 %indvars.iv
197191
; CHECK-NEXT: %5 = load double, double* %arrayidx3

enzyme/test/Enzyme/ProbProg/simple-condition.ll

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,7 @@ entry:
5151

5252
; CHECK: define internal i8* @condition_test(i8* %observations)
5353
; CHECK-NEXT: entry:
54-
; CHECK-NEXT: %x.ptr2 = alloca double
5554
; CHECK-NEXT: %x.ptr = alloca double
56-
; CHECK-NEXT: %mu.ptr1 = alloca double
5755
; CHECK-NEXT: %mu.ptr = alloca double
5856
; CHECK-NEXT: %trace = call i8* @__enzyme_newtrace()
5957
; CHECK-NEXT: %has.choice.mu = call i1 @__enzyme_has_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([3 x i8], [3 x i8]* @.str, i64 0, i64 0))
@@ -72,9 +70,8 @@ entry:
7270
; CHECK: entry.cntd: ; preds = %condition.mu.without.trace, %condition.mu.with.trace
7371
; CHECK-NEXT: %mu = phi double [ %from.trace.mu, %condition.mu.with.trace ], [ %sample.mu, %condition.mu.without.trace ]
7472
; CHECK-NEXT: %likelihood.mu = call double @normal_logpdf(double 0.000000e+00, double 1.000000e+00, double %mu)
75-
; CHECK-NEXT: store double %mu, double* %mu.ptr1
76-
; CHECK-NEXT: %1 = bitcast double* %mu.ptr1 to i8**
77-
; CHECK-NEXT: %2 = load i8*, i8** %1
73+
; CHECK-NEXT: %1 = bitcast double %mu to i64
74+
; CHECK-NEXT: %2 = inttoptr i64 %1 to i8*
7875
; CHECK-NEXT: call void @__enzyme_insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([3 x i8], [3 x i8]* @.str, i64 0, i64 0), double %likelihood.mu, i8* %2, i64 8)
7976
; CHECK-NEXT: %has.choice.x = call i1 @__enzyme_has_choice(i8* %observations, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.1, i64 0, i64 0))
8077
; CHECK-NEXT: br i1 %has.choice.x, label %condition.x.with.trace, label %condition.x.without.trace
@@ -92,9 +89,8 @@ entry:
9289
; CHECK: entry.cntd.cntd: ; preds = %condition.x.without.trace, %condition.x.with.trace
9390
; CHECK-NEXT: %x = phi double [ %from.trace.x, %condition.x.with.trace ], [ %sample.x, %condition.x.without.trace ]
9491
; CHECK-NEXT: %likelihood.x = call double @normal_logpdf(double %mu, double 1.000000e+00, double %x)
95-
; CHECK-NEXT: store double %x, double* %x.ptr2
96-
; CHECK-NEXT: %4 = bitcast double* %x.ptr2 to i8**
97-
; CHECK-NEXT: %5 = load i8*, i8** %4
92+
; CHECK-NEXT: %4 = bitcast double %x to i64
93+
; CHECK-NEXT: %5 = inttoptr i64 %4 to i8*
9894
; CHECK-NEXT: call void @__enzyme_insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.1, i64 0, i64 0), double %likelihood.x, i8* %5, i64 8)
9995
; CHECK-NEXT: ret i8* %trace
10096
; CHECK-NEXT: }

enzyme/test/Enzyme/ProbProg/simple-trace.ll

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,20 +40,16 @@ entry:
4040

4141
; CHECK: define internal i8* @trace_test()
4242
; CHECK-NEXT: entry:
43-
; CHECK-NEXT: %x.ptr = alloca double
44-
; CHECK-NEXT: %mu.ptr = alloca double
4543
; CHECK-NEXT: %trace = call i8* @__enzyme_newtrace()
4644
; CHECK-NEXT: %mu = call double @normal(double 0.000000e+00, double 1.000000e+00)
4745
; CHECK-NEXT: %likelihood.mu = call double @normal_logpdf(double 0.000000e+00, double 1.000000e+00, double %mu)
48-
; CHECK-NEXT: store double %mu, double* %mu.ptr
49-
; CHECK-NEXT: %0 = bitcast double* %mu.ptr to i8**
50-
; CHECK-NEXT: %1 = load i8*, i8** %0
46+
; CHECK-NEXT: %0 = bitcast double %mu to i64
47+
; CHECK-NEXT: %1 = inttoptr i64 %0 to i8*
5148
; CHECK-NEXT: call void @__enzyme_insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([3 x i8], [3 x i8]* @.str, i64 0, i64 0), double %likelihood.mu, i8* %1, i64 8)
5249
; CHECK-NEXT: %x = call double @normal(double %mu, double 1.000000e+00)
5350
; CHECK-NEXT: %likelihood.x = call double @normal_logpdf(double %mu, double 1.000000e+00, double %x)
54-
; CHECK-NEXT: store double %x, double* %x.ptr
55-
; CHECK-NEXT: %2 = bitcast double* %x.ptr to i8**
56-
; CHECK-NEXT: %3 = load i8*, i8** %2
51+
; CHECK-NEXT: %2 = bitcast double %x to i64
52+
; CHECK-NEXT: %3 = inttoptr i64 %2 to i8*
5753
; CHECK-NEXT: call void @__enzyme_insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.1, i64 0, i64 0), double %likelihood.x, i8* %3, i64 8)
5854
; CHECK-NEXT: ret i8* %trace
5955
; CHECK-NEXT: }

0 commit comments

Comments
 (0)