Skip to content

Commit 94eaab5

Browse files
authored
Allow func fields (#599)
1 parent f05ace9 commit 94eaab5

File tree

2 files changed

+82
-4
lines changed

2 files changed

+82
-4
lines changed

graphql_test.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5699,3 +5699,63 @@ func TestGraphqlNames(t *testing.T) {
56995699
},
57005700
})
57015701
}
5702+
5703+
func Test_fieldFunc(t *testing.T) {
5704+
sdl := `
5705+
type Query {
5706+
hello(name: String!): String!
5707+
}
5708+
`
5709+
gqltesting.RunTests(t, []*gqltesting.Test{
5710+
{
5711+
Schema: graphql.MustParseSchema(sdl,
5712+
func() interface{} {
5713+
type helloTagResolver struct {
5714+
Hello func(args struct{ Name string }) string
5715+
}
5716+
fn := func(args struct{ Name string }) string {
5717+
return "Hello, " + args.Name + "!"
5718+
}
5719+
return &helloTagResolver{
5720+
Hello: fn,
5721+
}
5722+
}(),
5723+
graphql.UseFieldResolvers()),
5724+
Query: `
5725+
{
5726+
hello(name: "GraphQL")
5727+
}
5728+
`,
5729+
ExpectedResult: `
5730+
{
5731+
"hello": "Hello, GraphQL!"
5732+
}
5733+
`,
5734+
},
5735+
{
5736+
Schema: graphql.MustParseSchema(sdl,
5737+
func() interface{} {
5738+
type helloTagResolver struct {
5739+
Greet func(ctx context.Context, args struct{ Name string }) (string, error) `graphql:"hello"`
5740+
}
5741+
fn := func(_ context.Context, args struct{ Name string }) (string, error) {
5742+
return "Hello, " + args.Name + "!", nil
5743+
}
5744+
return &helloTagResolver{
5745+
Greet: fn,
5746+
}
5747+
}(),
5748+
graphql.UseFieldResolvers()),
5749+
Query: `
5750+
{
5751+
hello(name: "GraphQL")
5752+
}
5753+
`,
5754+
ExpectedResult: `
5755+
{
5756+
"hello": "Hello, GraphQL!"
5757+
}
5758+
`,
5759+
},
5760+
})
5761+
}

internal/exec/resolvable/resolvable.go

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ type Field struct {
4747
FieldIndex []int
4848
HasContext bool
4949
HasError bool
50+
IsFieldFunc bool
5051
ArgsPacker *packer.StructPacker
5152
Visitors *FieldVisitors
5253
ValueExec Resolvable
@@ -59,7 +60,7 @@ type FieldVisitors struct {
5960
}
6061

6162
func (f *Field) UseMethodResolver() bool {
62-
return len(f.FieldIndex) == 0
63+
return f.MethodIndex != -1 || f.IsFieldFunc
6364
}
6465

6566
func (f *Field) Resolve(ctx context.Context, resolver reflect.Value, args interface{}) (output interface{}, err error) {
@@ -126,7 +127,15 @@ func (f *Field) resolve(ctx context.Context, resolver reflect.Value, args interf
126127
in = append(in, reflect.ValueOf(args))
127128
}
128129

129-
callOut = resolver.Method(f.MethodIndex).Call(in)
130+
if f.IsFieldFunc { // resolver is a struct field of type func
131+
res := resolver
132+
if res.Kind() == reflect.Pointer {
133+
res = resolver.Elem()
134+
}
135+
callOut = res.FieldByIndex(f.FieldIndex).Call(in)
136+
} else {
137+
callOut = resolver.Method(f.MethodIndex).Call(in)
138+
}
130139
result := callOut[0]
131140

132141
if f.HasError && !callOut[1].IsNil() {
@@ -537,9 +546,17 @@ func (b *execBuilder) makeFieldExec(typeName string, f *ast.FieldDefinition, m r
537546
var argsPacker *packer.StructPacker
538547
var hasError bool
539548
var hasContext bool
549+
var isFieldFunc bool
540550

551+
if methodIndex == -1 && len(fieldIndex) > 0 {
552+
if sf.Type.Kind() == reflect.Func {
553+
m.Type = sf.Type
554+
methodHasReceiver = false
555+
isFieldFunc = true
556+
}
557+
}
541558
// Validate resolver method only when there is one
542-
if methodIndex != -1 {
559+
if methodIndex != -1 || isFieldFunc {
543560
in := make([]reflect.Type, m.Type.NumIn())
544561
for i := range in {
545562
in[i] = m.Type.In(i)
@@ -596,6 +613,7 @@ func (b *execBuilder) makeFieldExec(typeName string, f *ast.FieldDefinition, m r
596613
TypeName: typeName,
597614
MethodIndex: methodIndex,
598615
FieldIndex: fieldIndex,
616+
IsFieldFunc: isFieldFunc,
599617
HasContext: hasContext,
600618
ArgsPacker: argsPacker,
601619
Visitors: visitors,
@@ -604,7 +622,7 @@ func (b *execBuilder) makeFieldExec(typeName string, f *ast.FieldDefinition, m r
604622
}
605623

606624
var out reflect.Type
607-
if methodIndex != -1 {
625+
if methodIndex != -1 || isFieldFunc {
608626
out = m.Type.Out(0)
609627
sub, ok := b.schema.RootOperationTypes["subscription"]
610628
if ok && typeName == sub.TypeName() && out.Kind() == reflect.Chan {

0 commit comments

Comments
 (0)