Skip to content

Commit 18fdf54

Browse files
cmd/protoc-gen-go-grpc: allow hooks to modify client structs and service handlers (#5240)
1 parent 337b815 commit 18fdf54

File tree

1 file changed

+68
-54
lines changed

1 file changed

+68
-54
lines changed

cmd/protoc-gen-go-grpc/grpc.go

Lines changed: 68 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,11 @@ const (
3636

3737
type serviceGenerateHelperInterface interface {
3838
formatFullMethodName(service *protogen.Service, method *protogen.Method) string
39+
generateClientStruct(g *protogen.GeneratedFile, clientName string)
3940
generateNewClientDefinitions(g *protogen.GeneratedFile, service *protogen.Service, clientName string)
4041
generateUnimplementedServerType(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, service *protogen.Service)
4142
generateServerFunctions(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, service *protogen.Service, serverType string, serviceDescVar string)
43+
formatHandlerFuncName(service *protogen.Service, hname string) string
4244
}
4345

4446
type serviceGenerateHelper struct{}
@@ -47,7 +49,15 @@ func (serviceGenerateHelper) formatFullMethodName(service *protogen.Service, met
4749
return fmt.Sprintf("/%s/%s", service.Desc.FullName(), method.Desc.Name())
4850
}
4951

52+
func (serviceGenerateHelper) generateClientStruct(g *protogen.GeneratedFile, clientName string) {
53+
g.P("type ", unexport(clientName), " struct {")
54+
g.P("cc ", grpcPackage.Ident("ClientConnInterface"))
55+
g.P("}")
56+
g.P()
57+
}
58+
5059
func (serviceGenerateHelper) generateNewClientDefinitions(g *protogen.GeneratedFile, service *protogen.Service, clientName string) {
60+
g.P("return &", unexport(clientName), "{cc}")
5161
}
5262

5363
func (serviceGenerateHelper) generateUnimplementedServerType(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, service *protogen.Service) {
@@ -77,6 +87,19 @@ func (serviceGenerateHelper) generateUnimplementedServerType(gen *protogen.Plugi
7787
}
7888

7989
func (serviceGenerateHelper) generateServerFunctions(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, service *protogen.Service, serverType string, serviceDescVar string) {
90+
// Server handler implementations.
91+
handlerNames := make([]string, 0, len(service.Methods))
92+
for _, method := range service.Methods {
93+
hname := genServerMethod(gen, file, g, method, func(hname string) string {
94+
return hname
95+
})
96+
handlerNames = append(handlerNames, hname)
97+
}
98+
genServiceDesc(file, g, serviceDescVar, serverType, service, handlerNames)
99+
}
100+
101+
func (serviceGenerateHelper) formatHandlerFuncName(service *protogen.Service, hname string) string {
102+
return hname
80103
}
81104

82105
var helper serviceGenerateHelperInterface = serviceGenerateHelper{}
@@ -158,18 +181,14 @@ func genService(gen *protogen.Plugin, file *protogen.File, g *protogen.Generated
158181
g.P()
159182

160183
// Client structure.
161-
g.P("type ", unexport(clientName), " struct {")
162-
g.P("cc ", grpcPackage.Ident("ClientConnInterface"))
163-
g.P("}")
164-
g.P()
184+
helper.generateClientStruct(g, clientName)
165185

166186
// NewClient factory.
167187
if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
168188
g.P(deprecationComment)
169189
}
170190
g.P("func New", clientName, " (cc ", grpcPackage.Ident("ClientConnInterface"), ") ", clientName, " {")
171191
helper.generateNewClientDefinitions(g, service, clientName)
172-
g.P("return &", unexport(clientName), "{cc}")
173192
g.P("}")
174193
g.P()
175194

@@ -239,52 +258,6 @@ func genService(gen *protogen.Plugin, file *protogen.File, g *protogen.Generated
239258
g.P()
240259

241260
helper.generateServerFunctions(gen, file, g, service, serverType, serviceDescVar)
242-
243-
// Server handler implementations.
244-
handlerNames := make([]string, 0, len(service.Methods))
245-
for _, method := range service.Methods {
246-
hname := genServerMethod(gen, file, g, method)
247-
handlerNames = append(handlerNames, hname)
248-
}
249-
250-
// Service descriptor.
251-
g.P("// ", serviceDescVar, " is the ", grpcPackage.Ident("ServiceDesc"), " for ", service.GoName, " service.")
252-
g.P("// It's only intended for direct use with ", grpcPackage.Ident("RegisterService"), ",")
253-
g.P("// and not to be introspected or modified (even as a copy)")
254-
g.P("var ", serviceDescVar, " = ", grpcPackage.Ident("ServiceDesc"), " {")
255-
g.P("ServiceName: ", strconv.Quote(string(service.Desc.FullName())), ",")
256-
g.P("HandlerType: (*", serverType, ")(nil),")
257-
g.P("Methods: []", grpcPackage.Ident("MethodDesc"), "{")
258-
for i, method := range service.Methods {
259-
if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
260-
continue
261-
}
262-
g.P("{")
263-
g.P("MethodName: ", strconv.Quote(string(method.Desc.Name())), ",")
264-
g.P("Handler: ", handlerNames[i], ",")
265-
g.P("},")
266-
}
267-
g.P("},")
268-
g.P("Streams: []", grpcPackage.Ident("StreamDesc"), "{")
269-
for i, method := range service.Methods {
270-
if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
271-
continue
272-
}
273-
g.P("{")
274-
g.P("StreamName: ", strconv.Quote(string(method.Desc.Name())), ",")
275-
g.P("Handler: ", handlerNames[i], ",")
276-
if method.Desc.IsStreamingServer() {
277-
g.P("ServerStreams: true,")
278-
}
279-
if method.Desc.IsStreamingClient() {
280-
g.P("ClientStreams: true,")
281-
}
282-
g.P("},")
283-
}
284-
g.P("},")
285-
g.P("Metadata: \"", file.Desc.Path(), "\",")
286-
g.P("}")
287-
g.P()
288261
}
289262

290263
func clientSignature(g *protogen.GeneratedFile, method *protogen.Method) string {
@@ -397,12 +370,53 @@ func serverSignature(g *protogen.GeneratedFile, method *protogen.Method) string
397370
return method.GoName + "(" + strings.Join(reqArgs, ", ") + ") " + ret
398371
}
399372

400-
func genServerMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, method *protogen.Method) string {
373+
func genServiceDesc(file *protogen.File, g *protogen.GeneratedFile, serviceDescVar string, serverType string, service *protogen.Service, handlerNames []string) {
374+
// Service descriptor.
375+
g.P("// ", serviceDescVar, " is the ", grpcPackage.Ident("ServiceDesc"), " for ", service.GoName, " service.")
376+
g.P("// It's only intended for direct use with ", grpcPackage.Ident("RegisterService"), ",")
377+
g.P("// and not to be introspected or modified (even as a copy)")
378+
g.P("var ", serviceDescVar, " = ", grpcPackage.Ident("ServiceDesc"), " {")
379+
g.P("ServiceName: ", strconv.Quote(string(service.Desc.FullName())), ",")
380+
g.P("HandlerType: (*", serverType, ")(nil),")
381+
g.P("Methods: []", grpcPackage.Ident("MethodDesc"), "{")
382+
for i, method := range service.Methods {
383+
if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
384+
continue
385+
}
386+
g.P("{")
387+
g.P("MethodName: ", strconv.Quote(string(method.Desc.Name())), ",")
388+
g.P("Handler: ", handlerNames[i], ",")
389+
g.P("},")
390+
}
391+
g.P("},")
392+
g.P("Streams: []", grpcPackage.Ident("StreamDesc"), "{")
393+
for i, method := range service.Methods {
394+
if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
395+
continue
396+
}
397+
g.P("{")
398+
g.P("StreamName: ", strconv.Quote(string(method.Desc.Name())), ",")
399+
g.P("Handler: ", handlerNames[i], ",")
400+
if method.Desc.IsStreamingServer() {
401+
g.P("ServerStreams: true,")
402+
}
403+
if method.Desc.IsStreamingClient() {
404+
g.P("ClientStreams: true,")
405+
}
406+
g.P("},")
407+
}
408+
g.P("},")
409+
g.P("Metadata: \"", file.Desc.Path(), "\",")
410+
g.P("}")
411+
g.P()
412+
}
413+
414+
func genServerMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.GeneratedFile, method *protogen.Method, hnameFuncNameFormatter func(string) string) string {
401415
service := method.Parent
402416
hname := fmt.Sprintf("_%s_%s_Handler", service.GoName, method.GoName)
403417

404418
if !method.Desc.IsStreamingClient() && !method.Desc.IsStreamingServer() {
405-
g.P("func ", hname, "(srv interface{}, ctx ", contextPackage.Ident("Context"), ", dec func(interface{}) error, interceptor ", grpcPackage.Ident("UnaryServerInterceptor"), ") (interface{}, error) {")
419+
g.P("func ", hnameFuncNameFormatter(hname), "(srv interface{}, ctx ", contextPackage.Ident("Context"), ", dec func(interface{}) error, interceptor ", grpcPackage.Ident("UnaryServerInterceptor"), ") (interface{}, error) {")
406420
g.P("in := new(", method.Input.GoIdent, ")")
407421
g.P("if err := dec(in); err != nil { return nil, err }")
408422
g.P("if interceptor == nil { return srv.(", service.GoName, "Server).", method.GoName, "(ctx, in) }")
@@ -420,7 +434,7 @@ func genServerMethod(gen *protogen.Plugin, file *protogen.File, g *protogen.Gene
420434
return hname
421435
}
422436
streamType := unexport(service.GoName) + method.GoName + "Server"
423-
g.P("func ", hname, "(srv interface{}, stream ", grpcPackage.Ident("ServerStream"), ") error {")
437+
g.P("func ", hnameFuncNameFormatter(hname), "(srv interface{}, stream ", grpcPackage.Ident("ServerStream"), ") error {")
424438
if !method.Desc.IsStreamingClient() {
425439
g.P("m := new(", method.Input.GoIdent, ")")
426440
g.P("if err := stream.RecvMsg(m); err != nil { return err }")

0 commit comments

Comments
 (0)