-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsse.go
More file actions
281 lines (246 loc) · 8.58 KB
/
sse.go
File metadata and controls
281 lines (246 loc) · 8.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
package prefab
import (
"context"
"fmt"
"io"
"net/http"
"regexp"
"strings"
"github.com/dpup/prefab/errors"
"github.com/dpup/prefab/logging"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
)
// ClientStream is the receiving end of a server-streaming gRPC call that
// yields messages of type T. It combines the generic grpc.ClientStream
// behaviour with a typed Recv, which is the shape generated gRPC client stream
// types expose.
type ClientStream[T proto.Message] interface {
	grpc.ClientStream
	Recv() (T, error)
}
// SSEStreamStarter opens the gRPC stream that feeds an SSE endpoint.
//
// It is called once per SSE request. Its arguments are:
//   - the request-scoped context;
//   - the parameter map (path parameters by name, query parameters as
//     "query.<name>");
//   - a gRPC client connection.
//
// It should construct a client, invoke the server-streaming method, and return
// the resulting stream.
//
// Example:
//
//	func(ctx context.Context, params map[string]string, cc grpc.ClientConnInterface) (prefab.ClientStream[*Note], error) {
//		client := NewNotesStreamServiceClient(cc)
//		return client.StreamUpdates(ctx, &StreamRequest{NoteId: params["id"]})
//	}
//
// Here *Note stands for the stream's response message type. The literal must
// declare ClientStream[...] as its result type rather than the generated
// stream type. Go only accepts a function value as an SSEStreamStarter when
// the signatures match exactly. Returning the generated stream from such a
// function converts it to ClientStream implicitly.
type SSEStreamStarter[T proto.Message] func(ctx context.Context, params map[string]string, cc grpc.ClientConnInterface) (ClientStream[T], error)
// pathPattern represents a parsed path pattern with parameter extraction.
// Values are produced by parsePathPattern from patterns such as
// "/notes/{id}/updates".
type pathPattern struct {
	// pattern is the anchored regular expression built by parsePathPattern;
	// each {name} segment contributes one capture group.
	pattern *regexp.Regexp
	// params holds the parameter names in the order their capture groups
	// appear in pattern (extractParams relies on this ordering).
	params []string
	// prefix is the pattern text before the first "{" (the whole pattern when
	// there are no parameters); WithSSEStream registers the handler under it.
	prefix string
}
// parsePathPattern compiles a route pattern such as "/notes/{id}/updates" into
// a pathPattern.
//
// Matching rules:
//   - The pattern must start with "/".
//   - A segment written as {name} matches one non-empty path segment (no "/")
//     and is captured under name.
//   - Every other segment must match literally.
//   - Empty segments (from repeated slashes) are ignored.
//   - A trailing "/" in the pattern makes a trailing slash optional in the
//     request, so "/events/" matches both "/events" and "/events/".
//
// The returned prefix is everything before the first "{" (the whole pattern if
// it has no parameters).
//
// It returns an InvalidArgument error when the pattern is empty, does not start
// with "/", contains an empty parameter name ("{}"), or uses the same parameter
// name twice.
func parsePathPattern(pattern string) (*pathPattern, error) {
	if pattern == "" {
		return nil, errors.NewC("sse: path pattern cannot be empty", codes.InvalidArgument)
	}
	if !strings.HasPrefix(pattern, "/") {
		// Request paths always start with "/", so an anchored regex built from a
		// relative pattern could never match; fail loudly instead.
		return nil, errors.NewC("sse: path pattern must start with \"/\"", codes.InvalidArgument)
	}

	// The handler is mounted at everything before the first parameter.
	prefix, _, _ := strings.Cut(pattern, "{")

	var params []string
	seen := make(map[string]struct{})

	var expr strings.Builder
	expr.WriteString("^")
	for _, segment := range strings.Split(pattern, "/") {
		if segment == "" {
			continue
		}
		expr.WriteString("/")

		if !strings.HasPrefix(segment, "{") || !strings.HasSuffix(segment, "}") {
			// Literal path component.
			expr.WriteString(regexp.QuoteMeta(segment))
			continue
		}

		name := segment[1 : len(segment)-1]
		if name == "" {
			return nil, errors.NewC("sse: empty parameter name in pattern", codes.InvalidArgument)
		}
		if _, dup := seen[name]; dup {
			// A repeated name would make one capture silently overwrite the other.
			return nil, errors.NewC("sse: duplicate parameter name in pattern: "+name, codes.InvalidArgument)
		}
		seen[name] = struct{}{}
		params = append(params, name)
		// Match any non-slash characters.
		expr.WriteString("([^/]+)")
	}
	if strings.HasSuffix(pattern, "/") {
		// The prefix keeps the trailing slash, so the regex must accept it too.
		// It is optional so that requests matched before still match.
		expr.WriteString("/?")
	}
	expr.WriteString("$")

	re, err := regexp.Compile(expr.String())
	if err != nil {
		return nil, errors.WrapPrefix(err, "sse: invalid path pattern", 0)
	}

	return &pathPattern{
		pattern: re,
		params:  params,
		prefix:  prefix,
	}, nil
}
// extractParams matches path against the compiled pattern.
//
// On a match it returns a map from each parameter name to the path segment
// captured for it, plus true. Otherwise it returns nil and false.
func (p *pathPattern) extractParams(path string) (map[string]string, bool) {
	groups := p.pattern.FindStringSubmatch(path)
	if groups == nil {
		return nil, false
	}

	// groups[0] is the whole match; the captures follow in parameter order.
	captured := groups[1:]
	values := make(map[string]string, len(p.params))
	for idx := range p.params {
		values[p.params[idx]] = captured[idx]
	}
	return values, true
}
// createSSEHandler returns an HTTP handler that serves one Server-Sent Events
// endpoint backed by a gRPC stream.
//
// For each GET request whose path matches pattern, the handler:
//   - collects the path parameters;
//   - adds the query parameters, first value only, keyed as "query.<name>";
//   - opens a gRPC stream via starter over the server's shared client
//     connection;
//   - relays the stream's messages to the client as SSE events.
//
// Error responses:
//   - 405 for non-GET methods.
//   - 404 when the path does not match the pattern.
//   - 500 when the ResponseWriter cannot flush or the stream cannot be started.
//   - 503 when the shared connection is not available.
//
// SSE headers are written only after the stream has started, so error
// responses are plain HTTP errors.
func createSSEHandler[T proto.Message](pattern *pathPattern, starter SSEStreamStarter[T], s *Server) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()

		if r.Method != http.MethodGet {
			logging.Warnf(ctx, "sse: invalid method %s for path %s", r.Method, r.URL.Path)
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
			return
		}

		params, ok := pattern.extractParams(r.URL.Path)
		if !ok {
			// A non-matching path is a client error, not a server fault.
			logging.Warnf(ctx, "sse: path %s does not match pattern", r.URL.Path)
			http.Error(w, "Not found", http.StatusNotFound)
			return
		}
		for key, values := range r.URL.Query() {
			if len(values) > 0 {
				params["query."+key] = values[0]
			}
		}

		flusher, ok := w.(http.Flusher)
		if !ok {
			logging.Error(ctx, "sse: streaming not supported")
			http.Error(w, "Streaming not supported", http.StatusInternalServerError)
			return
		}

		if s == nil || s.sseClientConn == nil {
			// Guard against serving before the shared connection exists; passing
			// a nil connection to starter would panic inside the gRPC client.
			logging.Error(ctx, "sse: shared gRPC client connection is not initialized")
			http.Error(w, "Service unavailable", http.StatusServiceUnavailable)
			return
		}

		// r.Context() is already cancelled when the client disconnects. The
		// extra cancel on return also tears down the gRPC stream when this
		// handler exits for any other reason.
		ctx, cancel := context.WithCancel(ctx)
		defer cancel()

		stream, err := starter(ctx, params, s.sseClientConn)
		if err != nil {
			logging.Errorw(ctx, "sse: failed to start stream", "error", err)
			http.Error(w, fmt.Sprintf("Failed to start stream: %v", err), http.StatusInternalServerError)
			return
		}

		// Commit to an SSE response and flush the headers right away, so
		// EventSource clients see the connection open before the first event.
		header := w.Header()
		header.Set("Content-Type", "text/event-stream")
		header.Set("Cache-Control", "no-cache")
		header.Set("Connection", "keep-alive")
		header.Set("X-Accel-Buffering", "no") // Disable nginx buffering
		w.WriteHeader(http.StatusOK)
		flusher.Flush()

		// Query values are deliberately not logged: EventSource cannot send
		// headers, so credentials often travel in the query string.
		logging.Infow(ctx, "sse: client connected", "path", r.URL.Path)
		streamMessages(ctx, stream, r, w, flusher)
	})
}
// streamMessages relays messages from stream to w as SSE "data:" events.
//
// Each message is encoded with protojson, with unpopulated fields emitted and
// lowerCamelCase JSON names. It is written as one event and flushed at once.
// A message that fails to marshal is logged and skipped.
//
// The loop ends in one of four ways:
//   - io.EOF: the stream completed normally.
//   - Recv fails after ctx is done: the client disconnected or the handler
//     finished. This is expected, so it is logged at info and nothing more is
//     written.
//   - Recv fails for any other reason: the error text is sent as a single SSE
//     comment line. EventSource ignores comments, but raw consumers can see it.
//   - A write to the client fails.
func streamMessages[T proto.Message](ctx context.Context, stream ClientStream[T], r *http.Request, w http.ResponseWriter, flusher http.Flusher) {
	marshaler := protojson.MarshalOptions{
		EmitUnpopulated: true,
		UseProtoNames:   false, // lowerCamelCase JSON field names
	}

	for {
		msg, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			// Stream completed normally.
			logging.Infow(ctx, "sse: stream completed", "path", r.URL.Path)
			return
		}
		if err != nil {
			if ctx.Err() != nil {
				// The request context is done (typically the client went away),
				// so the stream error is a consequence and nobody is listening.
				logging.Infow(ctx, "sse: client disconnected", "path", r.URL.Path)
				return
			}
			logging.Errorw(ctx, "sse: stream error", "error", err)
			// An SSE comment must stay on one line. A raw CR/LF in the error
			// text would end the comment and let the rest be parsed as SSE
			// fields, e.g. a forged "data:" event.
			errText := strings.NewReplacer("\r\n", " ", "\r", " ", "\n", " ").Replace(err.Error())
			fmt.Fprintf(w, ": error: %s\n\n", errText)
			flusher.Flush()
			return
		}

		data, err := marshaler.Marshal(msg)
		if err != nil {
			logging.Errorw(ctx, "sse: failed to marshal message", "error", err)
			continue
		}

		if _, err := fmt.Fprintf(w, "data: %s\n\n", data); err != nil {
			logging.Errorw(ctx, "sse: failed to write event", "error", err)
			return
		}
		// Push the event out immediately instead of waiting for buffers to fill.
		flusher.Flush()
	}
}
// WithSSEStream registers a Server-Sent Events endpoint that streams from a
// gRPC streaming method.
//
// The path can include parameters in curly braces, e.g. "/notes/{id}/updates".
// These parameters are extracted and passed to the stream starter function.
// An invalid path pattern causes a panic when the option is applied.
//
// The starter function receives:
//   - ctx: request context (cancelled when the client disconnects)
//   - params: path parameters by name, plus query parameters keyed as
//     "query.<name>" (first value only)
//   - cc: gRPC client connection (connected to this server)
//
// The starter function should create a gRPC client and call the streaming
// method. Its declared result type must be prefab.ClientStream[...] rather
// than the generated stream type, because Go requires an exact function-type
// match.
//
// Example:
//
//	server := prefab.New(
//		prefab.WithSSEStream(
//			"/notes/{id}/updates",
//			func(ctx context.Context, params map[string]string, cc grpc.ClientConnInterface) (prefab.ClientStream[*Note], error) {
//				client := NewNotesStreamServiceClient(cc)
//				return client.StreamUpdates(ctx, &StreamRequest{NoteId: params["id"]})
//			},
//		),
//	)
//
// All stream management (reading, cancellation, error handling, SSE
// formatting) is handled automatically.
//
// Multiple SSE endpoints share a single gRPC client connection for efficiency.
func WithSSEStream[T proto.Message](path string, starter SSEStreamStarter[T]) ServerOption {
	return func(b *builder) {
		pattern, err := parsePathPattern(path)
		if err != nil {
			// Options cannot return errors; an invalid pattern is a programmer error.
			panic(err)
		}

		// sseHandler is assigned by the server builder below, once the Server
		// exists. NOTE(review): this assumes server builders run before any
		// request is served, as the original captured *Server did. Confirm.
		var sseHandler http.Handler

		// The server builder:
		//  1. creates the shared SSE client connection if no earlier SSE
		//     endpoint did;
		//  2. builds this endpoint's handler once, bound to the Server.
		b.serverBuilders = append(b.serverBuilders, func(s *Server) {
			if s.sseClientConn == nil {
				_, _, endpoint, opts := s.GatewayArgs()
				conn, err := grpc.NewClient(endpoint, opts...)
				if err != nil {
					panic(fmt.Sprintf("sse: failed to create shared client connection: %v", err))
				}
				// NOTE(review): nothing visible here closes this connection on
				// server shutdown. Confirm the Server does.
				s.sseClientConn = conn
				logging.Infow(s.baseContext, "sse: created shared gRPC client connection", "endpoint", endpoint)
			}
			sseHandler = createSSEHandler(pattern, starter, s)
		})

		// Register the HTTP handler under the pattern's static prefix.
		b.handlers = append(b.handlers, handler{
			prefix: pattern.prefix,
			httpHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				if sseHandler == nil {
					// The server has not been built yet; avoid a nil dereference.
					http.Error(w, "Service unavailable", http.StatusServiceUnavailable)
					return
				}
				sseHandler.ServeHTTP(w, r)
			}),
		})
	}
}