Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,17 +201,25 @@ func isSafeInHTTPQuotedString(text string) bool { // nolint: gocyclo
// the query parameters, and the JSON content. In particular the version of
// HTTP and the headers aren't protected by the signature.
func VerifyHTTPRequest(
req *http.Request, now time.Time, destination ServerName, keys JSONVerifier,
req *http.Request, now time.Time,
destination ServerName, // the default server name, if none other is given
isLocalServerName func(ServerName) bool, // optional, verify secondary server names
keys JSONVerifier,
) (*FederationRequest, util.JSONResponse) {
request, err := readHTTPRequest(req)
if err != nil {
util.GetLogger(req.Context()).WithError(err).Print("Error parsing HTTP headers")
return nil, util.MessageResponse(400, "Bad Request")
}
if request.fields.Destination != "" && request.fields.Destination != destination {
message := "Unrecognised server name for Destination"
util.GetLogger(req.Context()).WithError(err).Print(message)
return nil, util.MessageResponse(400, message)
if request.fields.Destination != "" {
switch {
case isLocalServerName != nil && !isLocalServerName(request.fields.Destination):
fallthrough
case isLocalServerName == nil && destination != request.fields.Destination:
message := fmt.Sprintf("Unrecognised server name %q for Destination", request.fields.Destination)
util.GetLogger(req.Context()).Warn(message)
return nil, util.MessageResponse(400, message)
}
} else if request.fields.Destination == "" {
request.fields.Destination = destination
}
Expand Down
4 changes: 2 additions & 2 deletions request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func TestVerifyGetRequest(t *testing.T) {
t.Fatal(err)
}
request, jsonResp := VerifyHTTPRequest(
hr, time.Unix(1493142432, 96400), "localhost:44033", KeyRing{nil, &testKeyDatabase{}},
hr, time.Unix(1493142432, 96400), "localhost:44033", nil, KeyRing{nil, &testKeyDatabase{}},
)
if request == nil {
t.Fatalf("Wanted non-nil request got nil. (request was %#v, response was %#v)", hr, jsonResp)
Expand Down Expand Up @@ -137,7 +137,7 @@ func TestVerifyPutRequest(t *testing.T) {
t.Fatal(err)
}
request, jsonResp := VerifyHTTPRequest(
hr, time.Unix(1493142432, 96400), "localhost:44033", KeyRing{nil, &testKeyDatabase{}},
hr, time.Unix(1493142432, 96400), "localhost:44033", nil, KeyRing{nil, &testKeyDatabase{}},
)
if request == nil {
t.Fatalf("Wanted non-nil request got nil. (request was %#v, response was %#v)", hr, jsonResp)
Expand Down