Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 41 additions & 23 deletions callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,35 +68,53 @@ func getMethod(value reflect.Value, name string) reflect.Value {
return method
}

// Get methods from the given value and any embedded fields.
// getMethods gets all methods with the given name from the given value
// and any embedded fields.
//
// Returns a slice of bound methods that can be called directly.
func getMethods(value reflect.Value, name string) []reflect.Value {
// Collect all possible receivers
receivers := []reflect.Value{value}
if value.Kind() == reflect.Ptr {
value = value.Elem()
}
if value.Kind() == reflect.Struct {
t := value.Type()
for i := 0; i < value.NumField(); i++ {
field := value.Field(i)
fieldType := t.Field(i)
if !fieldType.IsExported() {
continue
}
// Traverses embedded fields of the struct
// starting from the given value to collect all possible receivers
// for the given method name.
var traverse func(value reflect.Value, receivers []reflect.Value) []reflect.Value
traverse = func(value reflect.Value, receivers []reflect.Value) []reflect.Value {
Comment on lines +79 to +80
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing worth noting:
Adding a receivers parameter here allows the recursive calls to reuse the slice's backing storage
so that when one append gives the slice 2x capacity, that is re-used in following appends
before needing another resize further down the stack.

We could change the nil below to something like make([]reflect.Value, 0, 16), and this iteration/collection would be alloc-free for a recursive depth of 16 (which should be more than enough for 90% of the use cases).

// Always consider the current value for hooks.
receivers = append(receivers, value)

// Hooks on exported embedded fields should be called.
if fieldType.Anonymous {
receivers = append(receivers, field)
continue
}
if value.Kind() == reflect.Ptr {
value = value.Elem()
}

// Hooks on exported fields that are not exported,
// but are tagged with `embed:""` should be called.
if _, ok := fieldType.Tag.Lookup("embed"); ok {
receivers = append(receivers, field)
// If the current value is a struct, also consider embedded fields.
// Two kinds of embedded fields are considered if they're exported:
//
// - standard Go embedded fields
// - fields tagged with `embed:""`
if value.Kind() == reflect.Struct {
t := value.Type()
for i := 0; i < value.NumField(); i++ {
fieldValue := value.Field(i)
field := t.Field(i)

if !field.IsExported() {
continue
}

// Consider a field embedded if it's actually embedded
// or if it's tagged with `embed:""`.
_, isEmbedded := field.Tag.Lookup("embed")
isEmbedded = isEmbedded || field.Anonymous
if isEmbedded {
receivers = traverse(fieldValue, receivers)
}
}
}

return receivers
}

receivers := traverse(value, nil /* receivers */)

// Search all receivers for methods
var methods []reflect.Value
for _, receiver := range receivers {
Expand Down
19 changes: 19 additions & 0 deletions kong_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2405,6 +2405,8 @@ func TestProviderMethods(t *testing.T) {
}

type EmbeddedCallback struct {
Nested NestedCallback `embed:""`

Embedded bool
}

Expand All @@ -2414,6 +2416,8 @@ func (e *EmbeddedCallback) AfterApply() error {
}

type taggedEmbeddedCallback struct {
NestedCallback

Tagged bool
}

Expand All @@ -2422,6 +2426,15 @@ func (e *taggedEmbeddedCallback) AfterApply() error {
return nil
}

type NestedCallback struct {
nested bool
}

func (n *NestedCallback) AfterApply() error {
n.nested = true
return nil
}

type EmbeddedRoot struct {
EmbeddedCallback
Tagged taggedEmbeddedCallback `embed:""`
Expand All @@ -2441,9 +2454,15 @@ func TestEmbeddedCallbacks(t *testing.T) {
expected := &EmbeddedRoot{
EmbeddedCallback: EmbeddedCallback{
Embedded: true,
Nested: NestedCallback{
nested: true,
},
},
Tagged: taggedEmbeddedCallback{
Tagged: true,
NestedCallback: NestedCallback{
nested: true,
},
},
Root: true,
}
Expand Down