Skip to content

Commit eb16c23

Browse files
authored
Add a Reorder() decorator (#31)
* WIP * WIP, 2 tests pass * reorder now works * bugfix: eliminate shun functions first * lint
1 parent c8be771 commit eb16c23

12 files changed

+667
-21
lines changed

api.go

+8-1
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,14 @@ func NonFinal(fn interface{}) Provider {
334334
// head of the provider chain. The static portion of the provider
335335
// chain will run once. The values returned from the initialization
336336
// function come from the values available after the static portion
337-
// of the provider chain runs.
337+
// of the provider chain runs. For example, if the static portion
338+
// of an injection chain consists of:
339+
//
340+
// func(int) string { ... }
341+
// func(string) int64 { ... }
342+
//
343+
// Then the return value from the initialization could include int,
344+
// int64, and string but no other types.
338345
//
339346
// Bind pre-computes as much as possible so that the invokeFunc is fast.
340347
//

bind.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,7 @@ func doBind(sc *Collection, originalInvokeF *provider, originalInitF *provider,
6464
funcs = append(funcs, invokeF)
6565
funcs = append(funcs, afterInvoke...)
6666

67-
for i, fm := range funcs {
68-
fm.chainPosition = i
67+
for _, fm := range funcs {
6968
if fm.required {
7069
fm.include = true
7170
}
@@ -93,7 +92,8 @@ func doBind(sc *Collection, originalInvokeF *provider, originalInitF *provider,
9392

9493
// Compute dependencies: set fm.downRmap, fm.upRmap, fm.cannotInclude,
9594
// fm.whyIncluded, fm.include
96-
err := computeDependenciesAndInclusion(funcs, initF)
95+
var err error
96+
funcs, err = computeDependenciesAndInclusion(funcs, initF)
9797
if err != nil {
9898
return err
9999
}

characterize.go

+10-2
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ var (
145145
markedMemoized = predicate("is not marked Memoized", func(a testArgs) bool { return a.fm.memoize })
146146
markedCacheable = predicate("is not marked Cacheable", func(a testArgs) bool { return a.fm.cacheable })
147147
markedSingleton = predicate("is not marked Singleton", func(a testArgs) bool { return a.fm.singleton })
148+
notMarkedReorder = predicate("is marked Reorder", func(a testArgs) bool { return !a.fm.reorder })
148149
notMarkedSingleton = predicate("is marked Singleton", func(a testArgs) bool { return !a.fm.singleton })
149150
notMarkedNoCache = predicate("is marked NotCacheable", func(a testArgs) bool { return !a.fm.notCacheable })
150151
mappableInputs = predicate("has inputs that cannot be map keys", func(a testArgs) bool { return mappable(typesIn(a.t)...) })
@@ -244,6 +245,7 @@ var handlerRegistry = typeRegistry{
244245
mappableInputs,
245246
notMarkedNoCache,
246247
mustNotMemoize,
248+
notMarkedReorder,
247249
},
248250
mutate: func(a testArgs) {
249251
a.fm.group = staticGroup
@@ -265,6 +267,7 @@ var handlerRegistry = typeRegistry{
265267
mappableInputs,
266268
notMarkedNoCache,
267269
mustNotMemoize,
270+
notMarkedReorder,
268271
},
269272
mutate: func(a testArgs) {
270273
a.fm.group = staticGroup
@@ -288,6 +291,7 @@ var handlerRegistry = typeRegistry{
288291
notMarkedNoCache,
289292
possibleMapKey,
290293
notMarkedSingleton,
294+
notMarkedReorder,
291295
},
292296
mutate: func(a testArgs) {
293297
a.fm.group = staticGroup
@@ -313,6 +317,7 @@ var handlerRegistry = typeRegistry{
313317
notMarkedNoCache,
314318
possibleMapKey,
315319
notMarkedSingleton,
320+
notMarkedReorder,
316321
},
317322
mutate: func(a testArgs) {
318323
a.fm.group = staticGroup
@@ -336,6 +341,7 @@ var handlerRegistry = typeRegistry{
336341
mustNotMemoize,
337342
notMarkedNoCache,
338343
notMarkedSingleton,
344+
notMarkedReorder,
339345
},
340346
mutate: func(a testArgs) {
341347
a.fm.group = staticGroup
@@ -504,8 +510,10 @@ func (reg typeRegistry) characterizeFuncDetails(fm *provider, cc charContext) (*
504510
var isNil bool
505511
// nolint:exhaustive
506512
switch v.Type().Kind() {
507-
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
513+
case reflect.Chan, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
508514
isNil = v.IsNil()
515+
default:
516+
isNil = false
509517
}
510518
a = testArgs{
511519
fm: fm.copy(),
@@ -531,7 +539,7 @@ Match:
531539
}
532540

533541
// panic(fmt.Sprintf("%s: %s - %s", fm.describe(), t, strings.Join(rejectReasons, "; ")))
534-
return nil, fm.errorf("Could not type %s to any prototype: %s", a.t, strings.Join(rejectReasons, "; "))
542+
return nil, fm.errorf("Could not match type %s to any prototype: %s", a.t, strings.Join(rejectReasons, "; "))
535543
}
536544

537545
func characterizeInitInvoke(fm *provider, context charContext) (*provider, error) {

debug.go

+12-4
Original file line numberDiff line numberDiff line change
@@ -188,16 +188,24 @@ func formatFlow(flow []typeCode) string {
188188
return strings.Join(types, ", ")
189189
}
190190

191+
func elem(i interface{}) reflect.Type {
192+
t := reflect.TypeOf(i)
193+
if t.Kind() == reflect.Ptr {
194+
return t.Elem()
195+
}
196+
return t
197+
}
198+
191199
func generateReproduce(funcs []*provider, invokeF *provider, initF *provider) string {
192200
subs := make(map[typeCode]string)
193201
t := ""
194202
f := "func TestRegression(t *testing.T) {\n"
195203
f += "\twrapTest(t, func(t *testing.T) {\n"
196204
f += "\t\tcalled := make(map[string]int)\n"
197-
f += "\t\tvar invoker " + funcSig(subs, &t, reflect.TypeOf(invokeF.fn).Elem()) + "\n"
205+
f += "\t\tvar invoker " + funcSig(subs, &t, elem(invokeF.fn)) + "\n"
198206
initName := "nil"
199207
if initF != nil {
200-
f += "\t\tvar initer " + funcSig(subs, &t, reflect.TypeOf(initF.fn).Elem()) + "\n"
208+
f += "\t\tvar initer " + funcSig(subs, &t, elem(initF.fn)) + "\n"
201209
initName = "&initer"
202210
}
203211
f += "\t\trequire.NoError(t,\n"
@@ -277,9 +285,9 @@ func generateReproduce(funcs []*provider, invokeF *provider, initF *provider) st
277285
}
278286
f += "\t\t\t).Bind(&invoker, " + initName + "))\n"
279287
if initF != nil {
280-
f += "\t\tiniter(" + strings.Join(substituteDefaults(subs, typesIn(reflect.TypeOf(initF.fn).Elem())), ", ") + ")\n"
288+
f += "\t\tiniter(" + strings.Join(substituteDefaults(subs, typesIn(elem(initF.fn))), ", ") + ")\n"
281289
}
282-
f += "\t\tinvoker(" + strings.Join(substituteDefaults(subs, typesIn(reflect.TypeOf(invokeF.fn).Elem())), ", ") + ")\n"
290+
f += "\t\tinvoker(" + strings.Join(substituteDefaults(subs, typesIn(elem(invokeF.fn))), ", ") + ")\n"
283291
f += "\t})\n"
284292
f += "}\n"
285293
return t + "\n" + f

example_provider_test.go

+25
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,28 @@ func ExampleNonFinal() {
110110
// Output: final 20 some string
111111
// <nil>
112112
}
113+
114+
// This demonstrates how it to have a default that gets overridden by
115+
// by later inputs.
116+
func ExampleReorder() {
117+
seq1 := nject.Sequence("example",
118+
nject.Shun(func() string {
119+
fmt.Println("fallback default included")
120+
return "fallback default"
121+
}),
122+
)
123+
seq2 := nject.Sequence("later inputs",
124+
nject.Reorder(func() string {
125+
return "override value"
126+
}),
127+
)
128+
fmt.Println(nject.Run("combination",
129+
seq1,
130+
seq2,
131+
func(s string) {
132+
fmt.Println(s)
133+
},
134+
))
135+
// Output: override value
136+
// <nil>
137+
}

include.go

+19-8
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ type includeWorkingData struct {
1414
excluded error
1515
clusterMembers []*provider
1616
wantedInCluster bool
17+
hasFlow [lastFlowType]func(typeCode) bool // populated and used in reorder
1718
}
1819

1920
//
@@ -64,7 +65,12 @@ type includeWorkingData struct {
6465
// fm.wanted
6566
//
6667

67-
func computeDependenciesAndInclusion(funcs []*provider, initF *provider) error {
68+
func computeDependenciesAndInclusion(funcs []*provider, initF *provider) ([]*provider, error) {
69+
var err error
70+
funcs = reorder(funcs, initF)
71+
for i, fm := range funcs {
72+
fm.chainPosition = i
73+
}
6874
debugln("initial set of functions")
6975
clusterLeaders := make(map[int32]*provider)
7076
for _, fm := range funcs {
@@ -98,15 +104,15 @@ func computeDependenciesAndInclusion(funcs []*provider, initF *provider) error {
98104
}
99105
}
100106
debugln("calculate flows, initial")
101-
err := providesReturns(funcs, initF)
107+
err = providesReturns(funcs, initF)
102108
if err != nil {
103-
return err
109+
return nil, err
104110
}
105111

106112
debugln("check chain validity, no provider excluded")
107113
err = validateChainMarkIncludeExclude(funcs, true)
108114
if err != nil {
109-
return err
115+
return nil, err
110116
}
111117

112118
for _, fm := range funcs {
@@ -187,15 +193,15 @@ func computeDependenciesAndInclusion(funcs []*provider, initF *provider) error {
187193
debugln("final calculate flows")
188194
err = providesReturns(funcs, initF)
189195
if err != nil {
190-
return fmt.Errorf("internal error: uh oh")
196+
return nil, fmt.Errorf("internal error: uh oh")
191197
}
192198
debugf("final check chain validity")
193199
err = validateChainMarkIncludeExclude(funcs, true)
194200
if err != nil {
195-
return fmt.Errorf("internal error: uh oh #2")
201+
return nil, fmt.Errorf("internal error: uh oh #2")
196202
}
197203

198-
return nil
204+
return funcs, nil
199205
}
200206

201207
func validateChainMarkIncludeExclude(funcs []*provider, canRemoveDesired bool) error {
@@ -508,8 +514,13 @@ func proposeEliminations(funcs []*provider) []*provider {
508514
}
509515
}
510516
proposal := make([]*provider, 0, len(funcs))
517+
for _, fm := range funcs {
518+
if fm.shun {
519+
proposal = append(proposal, fm)
520+
}
521+
}
511522
for i, fm := range funcs {
512-
if !kept[i] || fm.shun {
523+
if !kept[i] && !fm.shun {
513524
proposal = append(proposal, fm)
514525
}
515526
}

intheap.go

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package nject
2+
3+
// Code below originated with the container/heap documentation
4+
5+
type IntHeap []int
6+
7+
func (h IntHeap) Len() int { return len(h) }
8+
func (h IntHeap) Less(i, j int) bool { return h[i] < h[j] }
9+
func (h IntHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
10+
11+
func (h *IntHeap) Push(x interface{}) {
12+
// Push and Pop use pointer receivers because they modify the slice's length,
13+
// not just its contents.
14+
*h = append(*h, x.(int))
15+
}
16+
17+
func (h *IntHeap) Pop() interface{} {
18+
old := *h
19+
n := len(old)
20+
x := old[n-1]
21+
*h = old[0 : n-1]
22+
return x
23+
}

nject.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ type provider struct {
2525
callsInner bool
2626
memoize bool
2727
loose bool
28+
reorder bool
2829
overridesError bool
2930
desired bool
3031
shun bool
@@ -51,9 +52,9 @@ type provider struct {
5152
bypassRmap map[typeCode]typeCode // overrides types of returning parameters
5253
include bool
5354
d includeWorkingData
55+
chainPosition int
5456

5557
// added during binding
56-
chainPosition int
5758
mustZeroIfRemainderSkipped []typeCode
5859
mustZeroIfInnerNotCalled []typeCode
5960
vmapCount int
@@ -84,6 +85,7 @@ func (fm *provider) copy() *provider {
8485
callsInner: fm.callsInner,
8586
memoize: fm.memoize,
8687
loose: fm.loose,
88+
reorder: fm.reorder,
8789
desired: fm.desired,
8890
shun: fm.shun,
8991
notCacheable: fm.notCacheable,

0 commit comments

Comments
 (0)