Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions v2/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
- Add tasks endpoints to v2
- Add missing endpoints from collections to v2
- Add missing endpoints from query to v2
- Add SSO auth token implementation

## [2.1.3](https://github.com/arangodb/go-driver/tree/v2.1.3) (2025-02-21)
- Switch to Go 1.22.11
Expand Down
96 changes: 89 additions & 7 deletions v2/connection/auth_jwt_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,19 @@ package connection

import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
)

func NewJWTAuthWrapper(username, password string) Wrapper {
return WrapAuthentication(func(ctx context.Context, conn Connection) (authentication Authentication, err error) {
var token string
var expiry time.Time

refresh := func(ctx context.Context, conn Connection) error {
url := NewUrl("_open", "auth")

var data jwtOpenResponse
Expand All @@ -40,15 +48,26 @@ func NewJWTAuthWrapper(username, password string) Wrapper {

resp, err := CallPost(ctx, conn, url, &data, j)
if err != nil {
return nil, err
return err
}
if resp.Code() != http.StatusOK {
return NewError(resp.Code(), "unexpected code")
}

switch resp.Code() {
case http.StatusOK:
return NewHeaderAuth("Authorization", "bearer %s", data.Token), nil
default:
return nil, NewError(resp.Code(), "unexpected code")
token = data.Token
expiry, _ = parseJWTExpiry(token) // ignore error, just fallback to immediate refresh next time
return nil
}

return WrapAuthentication(func(ctx context.Context, conn Connection) (Authentication, error) {
// First time fetch
if token == "" || time.Now().After(expiry) {
if err := refresh(ctx, conn); err != nil {
return nil, err
}
}

return NewHeaderAuth("Authorization", "bearer %s", token), nil
})
}

Expand All @@ -59,5 +78,68 @@ type jwtOpenRequest struct {

type jwtOpenResponse struct {
Token string `json:"jwt"`
ExpiresIn int `json:"expires_in,omitempty"`
MustChangePassword bool `json:"must_change_password,omitempty"`
}

func parseJWTExpiry(token string) (time.Time, error) {
parts := strings.Split(token, ".")
if len(parts) < 2 {
return time.Time{}, fmt.Errorf("invalid JWT format")
}

payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return time.Time{}, err
}

var claims struct {
Exp int64 `json:"exp"`
}
if err := json.Unmarshal(payload, &claims); err != nil {
return time.Time{}, err
}

return time.Unix(claims.Exp, 0), nil
}

func NewSSOAuthWrapper(initialToken string) Wrapper {
var token = initialToken
var expiry time.Time

setToken := func(newToken string) {
token = newToken
expiry, _ = parseJWTExpiry(newToken)
}

// If we already have a token (from an SSO login), parse expiry now
if token != "" {
setToken(token)
}

return WrapAuthentication(func(ctx context.Context, conn Connection) (Authentication, error) {
// No token yet or expired — let caller know they must login via SSO
if token == "" || time.Now().After(expiry) {
// Try a call to _open/auth just to see if server sends 307
url := NewUrl("_open", "auth")
var data jwtOpenResponse

resp, err := CallPost(ctx, conn, url, &data, nil)
if err != nil {
return nil, err
}

switch resp.Code() {
case http.StatusOK:
setToken(data.Token)
case http.StatusTemporaryRedirect:
loc := resp.Header("Location")
return nil, fmt.Errorf("SSO redirect: please authenticate via browser at %s", loc)
default:
return nil, NewError(resp.Code(), "unexpected code")
}
}

return NewHeaderAuth("Authorization", "bearer %s", token), nil
})
}