Skip to content

Commit 63e7039

Browse files
committed
breaking change: introduce OverridesError() to mark unsafe wrappers
1 parent 605cbb7 commit 63e7039

File tree

4 files changed

+138
-0
lines changed

4 files changed

+138
-0
lines changed

bind.go

+5
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ func doBind(sc *Collection, originalInvokeF *provider, originalInitF *provider,
3434
return err
3535
}
3636

37+
err = checkForMissingOverridesError(afterInvoke)
38+
if err != nil {
39+
return err
40+
}
41+
3742
// Add debugging provider
3843
{
3944
d := newProvider(func() *Debugging { return nil }, -1, "Debugging")

nject.go

+7
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ type provider struct {
2525
callsInner bool
2626
memoize bool
2727
loose bool
28+
overridesError bool
2829
desired bool
2930
shun bool
3031
notCacheable bool
@@ -363,6 +364,9 @@ func (fm provider) DownFlows() ([]reflect.Type, []reflect.Type) {
363364
}
364365
t := v.Type()
365366
if t.Kind() == reflect.Func {
367+
if fm.group == finalGroup {
368+
return typesIn(t), nil
369+
}
366370
return effectiveOutputs(t)
367371
}
368372
return nil, []reflect.Type{t}
@@ -439,6 +443,9 @@ func (fm provider) UpFlows() ([]reflect.Type, []reflect.Type) {
439443
}
440444
t := v.Type()
441445
if t.Kind() == reflect.Func {
446+
if fm.group == finalGroup {
447+
return nil, typesOut(t)
448+
}
442449
return effectiveReturns(t)
443450
}
444451
return nil, []reflect.Type{t}

overrides_error.go

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
package nject
2+
3+
import (
4+
"fmt"
5+
"reflect"
6+
)
7+
8+
// OverridesError marks a provider that is okay for that provider to override
9+
// error returns. Without this decorator, a wrapper that returns error but
10+
// does not expect to receive an error will cause the injection chain
11+
// compilation to fail.
12+
//
13+
// A common mistake is to have an wrapper that accidently returns error. It
14+
// looks like this:
15+
//
16+
// func AutoCloseThing(inner func(someType), param anotherType) error {
17+
// thing, err := getThing(param)
18+
// if err != nil {
19+
// return err
20+
// }
21+
// defer thing.Close()
22+
// inner(thing)
23+
// return nil
24+
// }
25+
//
26+
// The above function has two problems. The big problem is that it will
27+
// override any returned errors coming up from below in the call chain
28+
// by returning nil. The fix for this is to have the inner function return
29+
// error. If you aren't sure there will be something below that will
30+
// definitely return error, then you can inject something to provide a nil
31+
// error. Put the following at the end of the sequence:
32+
//
33+
// nject.Shun(nject.NotFinal(func () error { return nil }))
34+
//
35+
// The second issue is that thing.Close() probably returns error. A correct
36+
// wrapper for this looks like this:
37+
//
38+
// func AutoCloseThing(inner func(someType) error, param anotherType) (err error) {
39+
// var thing someType
40+
// thing, err = getThing(param)
41+
// if err != nil {
42+
// return err
43+
// }
44+
// defer func() {
45+
// e := thing.Close()
46+
// if err == nil && e != nil {
47+
// err = e
48+
// }
49+
// }()
50+
// return inner(thing)
51+
// }
52+
//
53+
func OverridesError(fn interface{}) Provider {
54+
return newThing(fn).modify(func(fm *provider) {
55+
fm.overridesError = true
56+
})
57+
}
58+
59+
func checkForMissingOverridesError(collection []*provider) error {
60+
var errorReturnSeen bool
61+
for i := len(collection) - 1; i >= 0; i-- {
62+
fm := collection[i]
63+
if errorReturnSeen && !fm.overridesError && fm.class == wrapperFunc {
64+
consumes, returns := fm.UpFlows()
65+
if hasError(returns) && !hasError(consumes) {
66+
return fmt.Errorf("wrapper returns error but does not consume error. Decorate with OverridesError() if this is intentional. %s", fm)
67+
}
68+
}
69+
if !errorReturnSeen {
70+
_, returns := fm.UpFlows()
71+
if hasError(returns) {
72+
errorReturnSeen = true
73+
}
74+
}
75+
}
76+
return nil
77+
}
78+
79+
func hasError(types []reflect.Type) bool {
80+
for _, typ := range types {
81+
if typ == errorType {
82+
return true
83+
}
84+
}
85+
return false
86+
}

overrides_error_test.go

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package nject
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
)
8+
9+
func TestOverridesError(t *testing.T) {
10+
type someType string
11+
type anotherType string
12+
getThing := func(_ anotherType) (someType, error) { return "", nil }
13+
danger := Required(func(inner func(someType), param anotherType) error {
14+
thing, err := getThing(param)
15+
if err != nil {
16+
return err
17+
}
18+
inner(thing)
19+
return nil
20+
})
21+
finalWithError := func() error { return nil }
22+
returnsTerminal := Required(func() TerminalError { return nil })
23+
finalWithoutError := func() {}
24+
var target func(anotherType) error
25+
26+
t.Log("test: okay because no error bubbling up")
27+
assert.NoError(t, Sequence("A", danger, finalWithoutError).Bind(&target, nil))
28+
29+
t.Log("test: should fail because the final function returns error that gets clobbered")
30+
assert.Error(t, Sequence("B", danger, finalWithError).Bind(&target, nil))
31+
32+
t.Log("test: should fail because there is a terminal-error injector that gets clobbered")
33+
assert.Error(t, Sequence("C", danger, returnsTerminal, finalWithoutError).Bind(&target, nil))
34+
35+
t.Log("test: okay because marked even thoguh the final function returns error that gets clobbered")
36+
assert.Error(t, Sequence("B", OverridesError(danger), finalWithError).Bind(&target, nil))
37+
38+
t.Log("test: okay because marked even though there is a terminal-error injector that gets clobbered")
39+
assert.Error(t, Sequence("C", OverridesError(danger), returnsTerminal, finalWithoutError).Bind(&target, nil))
40+
}

0 commit comments

Comments
 (0)