Skip to content

Add graphql reflect tag #596

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -433,3 +433,59 @@ func ExampleUseStringDescriptions() {
// field: "title", description: ""
// field: "tags", description: "Tags of the post."
}

// ExampleFieldTag demonstrates the use of the graphql field tag.
func Example_resolverFieldTag() {
type resolver struct {
Hello string
HelloUnderscore string `graphql:"_hello"`
HelloLower string `graphql:"hello"`
HelloTitle string `graphql:"Hello"`
HelloUpper string `graphql:"HELLO"`
}

sdl := `
type Query {
_hello: String!
hello: String!
Hello: String!
HELLO: String!
}`

r := &resolver{
Hello: "This field is not used during query execution!",
HelloLower: "Hello, graphql!",
HelloTitle: "Hello, GraphQL!",
HelloUnderscore: "Hello, _!",
HelloUpper: "Hello, GRAPHQL!",
}

query := `
{
_hello
hello
Hello
HELLO
}
`

schema := graphql.MustParseSchema(sdl, r, graphql.UseFieldResolvers())
res := schema.Exec(context.Background(), query, "", nil)

enc := json.NewEncoder(os.Stdout)
enc.SetIndent("", " ")
err := enc.Encode(res)
if err != nil {
panic(err)
}

// output:
// {
// "data": {
// "_hello": "Hello, _!",
// "hello": "Hello, graphql!",
// "Hello": "Hello, GraphQL!",
// "HELLO": "Hello, GRAPHQL!"
// }
// }
}
67 changes: 66 additions & 1 deletion graphql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2529,7 +2529,7 @@ func TestInlineFragments(t *testing.T) {
},

{
Schema: socialSchema,
Schema: graphql.MustParseSchema(social.Schema, &social.Resolver{}, graphql.UseFieldResolvers()),
Query: `
query {
admin(id: "0x01") {
Expand Down Expand Up @@ -5634,3 +5634,68 @@ func TestSchemaExtension(t *testing.T) {
t.Fatalf(`expected an "awesome" schema directive, got %q`, dirs[0].Name.Name)
}
}

func TestGraphqlNames(t *testing.T) {
t.Parallel()

sdl1 := `
type Query {
hello: String!
}
`
type invalidResolver1 struct {
Field1 string `graphql:"hello"`
Field2 string `graphql:"hello"`
}

wantErr := fmt.Errorf(`*graphql_test.invalidResolver1 does not resolve "Query": multiple fields have a graphql reflect tag "hello"`)
_, err := graphql.ParseSchema(sdl1, &invalidResolver1{}, graphql.UseFieldResolvers())
if err == nil || err.Error() != wantErr.Error() {
t.Fatalf("want err %q, got %q", wantErr, err)
}

gqltesting.RunTests(t, []*gqltesting.Test{
{
Schema: graphql.MustParseSchema(`
type Query {
_hello: String!
hello: String!
Hello: String!
HELLO: String!
}`,
func() interface{} {
type helloTagResolver struct {
Hello string
HelloUnderscore string `graphql:"_hello"`
HelloLower string `graphql:"hello"`
HelloTitle string `graphql:"Hello"`
HelloUpper string `graphql:"HELLO"`
}
return &helloTagResolver{
Hello: "This field will not be used during query execution!",
HelloLower: "Hello, graphql!",
HelloTitle: "Hello, GraphQL!",
HelloUnderscore: "Hello, _!",
HelloUpper: "Hello, GRAPHQL!",
}
}(),
graphql.UseFieldResolvers()),
Query: `
{
_hello
hello
Hello
HELLO
}
`,
ExpectedResult: `
{
"_hello": "Hello, _!",
"hello": "Hello, graphql!",
"Hello": "Hello, GraphQL!",
"HELLO": "Hello, GRAPHQL!"
}
`,
},
})
}
56 changes: 42 additions & 14 deletions internal/exec/resolvable/resolvable.go
Original file line number Diff line number Diff line change
Expand Up @@ -444,18 +444,22 @@ func (b *execBuilder) makeObjectExec(typeName string, fields ast.FieldsDefinitio

Fields := make(map[string]*Field)
rt := unwrapPtr(resolverType)
fieldsCount := fieldCount(rt, map[string]int{})
fieldsCount, fieldTagsCount := fieldCount(rt, map[string]int{}, map[string]int{})
for _, f := range fields {
var fieldIndex []int
methodIndex := findMethod(resolverType, f.Name)
if b.useFieldResolvers && methodIndex == -1 {
if fieldsCount[strings.ToLower(stripUnderscore(f.Name))] > 1 {
// If a resolver field is ambiguous thrown an error unless there is exactly one field with the given graphql
// reflect tag. In that case use the field with the reflect tag.
if fieldTagsCount[f.Name] > 1 {
return nil, fmt.Errorf("%s does not resolve %q: multiple fields have a graphql reflect tag %q", resolverType, typeName, f.Name)
} else if fieldsCount[strings.ToLower(stripUnderscore(f.Name))] > 1 && fieldTagsCount[f.Name] != 1 {
return nil, fmt.Errorf("%s does not resolve %q: ambiguous field %q", resolverType, typeName, f.Name)
}
fieldIndex = findField(rt, f.Name, []int{})
fieldIndex = findField(rt, f.Name, []int{}, fieldTagsCount)
}
if methodIndex == -1 && len(fieldIndex) == 0 {
hint := ""
var hint string
if findMethod(reflect.PtrTo(resolverType), f.Name) != -1 {
hint = " (hint: the method exists on the pointer type)"
}
Expand Down Expand Up @@ -529,9 +533,7 @@ func (b *execBuilder) makeObjectExec(typeName string, fields ast.FieldsDefinitio
var contextType = reflect.TypeOf((*context.Context)(nil)).Elem()
var errorType = reflect.TypeOf((*error)(nil)).Elem()

func (b *execBuilder) makeFieldExec(typeName string, f *ast.FieldDefinition, m reflect.Method, sf reflect.StructField,
methodIndex int, fieldIndex []int, methodHasReceiver bool) (*Field, error) {

func (b *execBuilder) makeFieldExec(typeName string, f *ast.FieldDefinition, m reflect.Method, sf reflect.StructField, methodIndex int, fieldIndex []int, methodHasReceiver bool) (*Field, error) {
var argsPacker *packer.StructPacker
var hasError bool
var hasContext bool
Expand Down Expand Up @@ -662,17 +664,29 @@ func findMethod(t reflect.Type, name string) int {
return -1
}

func findField(t reflect.Type, name string, index []int) []int {
func findField(t reflect.Type, name string, index []int, matchingTagsCount map[string]int) []int {
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)

if field.Type.Kind() == reflect.Struct && field.Anonymous {
newIndex := findField(field.Type, name, []int{i})
newIndex := findField(field.Type, name, []int{i}, matchingTagsCount)
if len(newIndex) > 1 {
return append(index, newIndex...)
}
}

if gt, ok := field.Tag.Lookup("graphql"); ok {
if name == gt {
return append(index, i)
}
}

// The current field's tag didn't match, however, if the tag of another field matches,
// then skip the name matching until we find the desired field with the correct tag.
if matchingTagsCount[name] > 0 {
continue
}

if strings.EqualFold(stripUnderscore(name), stripUnderscore(field.Name)) {
return append(index, i)
}
Expand All @@ -682,26 +696,40 @@ func findField(t reflect.Type, name string, index []int) []int {
}

// fieldCount helps resolve ambiguity when more than one embedded struct contains fields with the same name.
func fieldCount(t reflect.Type, count map[string]int) map[string]int {
// or when a field has a `graphql` reflect tag with the same name as some other field causing name collision.
func fieldCount(t reflect.Type, count, tagsCount map[string]int) (map[string]int, map[string]int) {
if t.Kind() != reflect.Struct {
return nil
return nil, nil
}

for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
fieldName := strings.ToLower(stripUnderscore(field.Name))
var fieldName, gt string
var hasTag bool
if gt, hasTag = field.Tag.Lookup("graphql"); hasTag && gt != "" {
fieldName = gt
} else {
fieldName = strings.ToLower(stripUnderscore(field.Name))
}

if field.Type.Kind() == reflect.Struct && field.Anonymous {
count = fieldCount(field.Type, count)
count, tagsCount = fieldCount(field.Type, count, tagsCount)
} else {
if _, ok := count[fieldName]; !ok {
count[fieldName] = 0
}
count[fieldName]++
if !hasTag {
continue
}
if _, ok := count[gt]; !ok {
tagsCount[gt] = 0
}
tagsCount[gt]++
}
}

return count
return count, tagsCount
}

func unwrapNonNull(t ast.Type) (ast.Type, bool) {
Expand Down
4 changes: 1 addition & 3 deletions introspection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ import (
"github.com/graph-gophers/graphql-go/example/starwars"
)

var socialSchema = graphql.MustParseSchema(social.Schema, &social.Resolver{}, graphql.UseFieldResolvers())

func TestSchema_ToJSON(t *testing.T) {
t.Parallel()

Expand All @@ -29,7 +27,7 @@ func TestSchema_ToJSON(t *testing.T) {
}{
{
Name: "Social Schema",
Args: args{Schema: socialSchema},
Args: args{Schema: graphql.MustParseSchema(social.Schema, &social.Resolver{}, graphql.UseFieldResolvers())},
Want: want{JSON: mustReadFile("example/social/introspect.json")},
},
{
Expand Down