Skip to content

Implement client ID as config option #449

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/stackit_config_set.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ stackit config set [flags]
--dns-custom-endpoint string DNS API base URL, used in calls to this API
-h, --help Help for "stackit config set"
--iaas-custom-endpoint string IaaS API base URL, used in calls to this API
--identity-provider-custom-client-id string Identity Provider client ID, used for user authentication
--identity-provider-custom-endpoint string Identity Provider base URL, used for user authentication
--load-balancer-custom-endpoint string Load Balancer API base URL, used in calls to this API
--logme-custom-endpoint string LogMe API base URL, used in calls to this API
Expand Down
1 change: 1 addition & 0 deletions docs/stackit_config_unset.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ stackit config unset [flags]
--dns-custom-endpoint DNS API base URL. If unset, uses the default base URL
-h, --help Help for "stackit config unset"
--iaas-custom-endpoint IaaS API base URL. If unset, uses the default base URL
--identity-provider-custom-client-id Identity Provider client ID, used for user authentication
--identity-provider-custom-endpoint Identity Provider base URL. If unset, uses the default base URL
--load-balancer-custom-endpoint Load Balancer API base URL. If unset, uses the default base URL
--logme-custom-endpoint LogMe API base URL. If unset, uses the default base URL
Expand Down
4 changes: 4 additions & 0 deletions internal/cmd/config/set/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
const (
sessionTimeLimitFlag = "session-time-limit"
identityProviderCustomEndpointFlag = "identity-provider-custom-endpoint"
identityProviderCustomClientIdFlag = "identity-provider-custom-client-id"

argusCustomEndpointFlag = "argus-custom-endpoint"
authorizationCustomEndpointFlag = "authorization-custom-endpoint"
Expand Down Expand Up @@ -129,6 +130,7 @@ Use "{{.CommandPath}} [command] --help" for more information about a command.{{e
func configureFlags(cmd *cobra.Command) {
cmd.Flags().String(sessionTimeLimitFlag, "", "Maximum time before authentication is required again. After this time, you will be prompted to login again to execute commands that require authentication. Can't be larger than 24h. Requires authentication after being set to take effect. Examples: 3h, 5h30m40s (BETA: currently values greater than 2h have no effect)")
cmd.Flags().String(identityProviderCustomEndpointFlag, "", "Identity Provider base URL, used for user authentication")
cmd.Flags().String(identityProviderCustomClientIdFlag, "", "Identity Provider client ID, used for user authentication")
cmd.Flags().String(argusCustomEndpointFlag, "", "Argus API base URL, used in calls to this API")
cmd.Flags().String(authorizationCustomEndpointFlag, "", "Authorization API base URL, used in calls to this API")
cmd.Flags().String(dnsCustomEndpointFlag, "", "DNS API base URL, used in calls to this API")
Expand All @@ -155,6 +157,8 @@ func configureFlags(cmd *cobra.Command) {
cobra.CheckErr(err)
err = viper.BindPFlag(config.IdentityProviderCustomEndpointKey, cmd.Flags().Lookup(identityProviderCustomEndpointFlag))
cobra.CheckErr(err)
err = viper.BindPFlag(config.IdentityProviderCustomClientIdKey, cmd.Flags().Lookup(identityProviderCustomClientIdFlag))
cobra.CheckErr(err)

err = viper.BindPFlag(config.ArgusCustomEndpointKey, cmd.Flags().Lookup(argusCustomEndpointFlag))
cobra.CheckErr(err)
Expand Down
7 changes: 7 additions & 0 deletions internal/cmd/config/unset/unset.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ const (

sessionTimeLimitFlag = "session-time-limit"
identityProviderCustomEndpointFlag = "identity-provider-custom-endpoint"
identityProviderCustomClientIdFlag = "identity-provider-custom-client-id"

argusCustomEndpointFlag = "argus-custom-endpoint"
authorizationCustomEndpointFlag = "authorization-custom-endpoint"
Expand Down Expand Up @@ -54,6 +55,7 @@ type inputModel struct {

SessionTimeLimit bool
IdentityProviderCustomEndpoint bool
IdentityProviderCustomClientID bool

ArgusCustomEndpoint bool
AuthorizationCustomEndpoint bool
Expand Down Expand Up @@ -117,6 +119,9 @@ func NewCmd(p *print.Printer) *cobra.Command {
if model.IdentityProviderCustomEndpoint {
viper.Set(config.IdentityProviderCustomEndpointKey, "")
}
if model.IdentityProviderCustomClientID {
viper.Set(config.IdentityProviderCustomClientIdKey, "")
}

if model.ArgusCustomEndpoint {
viper.Set(config.ArgusCustomEndpointKey, "")
Expand Down Expand Up @@ -201,6 +206,7 @@ func configureFlags(cmd *cobra.Command) {

cmd.Flags().Bool(sessionTimeLimitFlag, false, fmt.Sprintf("Maximum time before authentication is required again. If unset, defaults to %s", config.SessionTimeLimitDefault))
cmd.Flags().Bool(identityProviderCustomEndpointFlag, false, "Identity Provider base URL. If unset, uses the default base URL")
cmd.Flags().Bool(identityProviderCustomClientIdFlag, false, "Identity Provider client ID, used for user authentication")

cmd.Flags().Bool(argusCustomEndpointFlag, false, "Argus API base URL. If unset, uses the default base URL")
cmd.Flags().Bool(authorizationCustomEndpointFlag, false, "Authorization API base URL. If unset, uses the default base URL")
Expand Down Expand Up @@ -234,6 +240,7 @@ func parseInput(p *print.Printer, cmd *cobra.Command) *inputModel {

SessionTimeLimit: flags.FlagToBoolValue(p, cmd, sessionTimeLimitFlag),
IdentityProviderCustomEndpoint: flags.FlagToBoolValue(p, cmd, identityProviderCustomEndpointFlag),
IdentityProviderCustomClientID: flags.FlagToBoolValue(p, cmd, identityProviderCustomClientIdFlag),

ArgusCustomEndpoint: flags.FlagToBoolValue(p, cmd, argusCustomEndpointFlag),
AuthorizationCustomEndpoint: flags.FlagToBoolValue(p, cmd, authorizationCustomEndpointFlag),
Expand Down
13 changes: 13 additions & 0 deletions internal/cmd/config/unset/unset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ func fixtureFlagValues(mods ...func(flagValues map[string]bool)) map[string]bool

sessionTimeLimitFlag: true,
identityProviderCustomEndpointFlag: true,
identityProviderCustomClientIdFlag: true,

argusCustomEndpointFlag: true,
authorizationCustomEndpointFlag: true,
Expand Down Expand Up @@ -53,6 +54,7 @@ func fixtureInputModel(mods ...func(model *inputModel)) *inputModel {

SessionTimeLimit: true,
IdentityProviderCustomEndpoint: true,
IdentityProviderCustomClientID: true,

ArgusCustomEndpoint: true,
AuthorizationCustomEndpoint: true,
Expand Down Expand Up @@ -104,6 +106,7 @@ func TestParseInput(t *testing.T) {

model.SessionTimeLimit = false
model.IdentityProviderCustomEndpoint = false
model.IdentityProviderCustomClientID = false

model.ArgusCustomEndpoint = false
model.AuthorizationCustomEndpoint = false
Expand Down Expand Up @@ -155,6 +158,16 @@ func TestParseInput(t *testing.T) {
model.IdentityProviderCustomEndpoint = false
}),
},
{
description: "identity provider custom client id empty",
flagValues: fixtureFlagValues(func(flagValues map[string]bool) {
flagValues[identityProviderCustomClientIdFlag] = false
}),
isValid: true,
expectedModel: fixtureInputModel(func(model *inputModel) {
model.IdentityProviderCustomClientID = false
}),
},
{
description: "argus custom endpoint empty",
flagValues: fixtureFlagValues(func(flagValues map[string]bool) {
Expand Down
19 changes: 16 additions & 3 deletions internal/pkg/auth/user_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (

const (
defaultIDPEndpoint = "https://accounts.stackit.cloud/oauth/v2"
cliClientID = "stackit-cli-0000-0000-000000000001"
defaultCLIClientID = "stackit-cli-0000-0000-000000000001"

loginSuccessPath = "/login-successful"
stackitLandingPage = "https://www.stackit.de"
Expand Down Expand Up @@ -58,6 +58,18 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
}
}

idpClientID, err := getIDPClientID()
if err != nil {
return err
}
if idpClientID != defaultCLIClientID {
p.Warn("You are using a custom client ID (%s) for authentication.\n", idpClientID)
err := p.PromptForEnter("Press Enter to proceed with the login...")
if err != nil {
return err
}
}

if isReauthentication {
err := p.PromptForEnter("Your session has expired, press Enter to login again...")
if err != nil {
Expand Down Expand Up @@ -86,7 +98,7 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
}

conf := &oauth2.Config{
ClientID: cliClientID,
ClientID: idpClientID,
Endpoint: oauth2.Endpoint{
AuthURL: fmt.Sprintf("%s/authorize", idpEndpoint),
},
Expand Down Expand Up @@ -131,7 +143,7 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {
p.Debug(print.DebugLevel, "trading authorization code for access and refresh tokens")

// Trade the authorization code and the code verifier for access and refresh tokens
accessToken, refreshToken, err := getUserAccessAndRefreshTokens(idpEndpoint, cliClientID, codeVerifier, code, redirectURL)
accessToken, refreshToken, err := getUserAccessAndRefreshTokens(idpEndpoint, idpClientID, codeVerifier, code, redirectURL)
if err != nil {
errServer = fmt.Errorf("retrieve tokens: %w", err)
return
Expand Down Expand Up @@ -207,6 +219,7 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error {

p.Debug(print.DebugLevel, "opening browser for authentication")
p.Debug(print.DebugLevel, "using authentication server on %s", idpEndpoint)
p.Debug(print.DebugLevel, "using client ID %s for authentication ", idpClientID)

// Open a browser window to the authorizationURL
err = openBrowser(authorizationURL)
Expand Down
7 changes: 6 additions & 1 deletion internal/pkg/auth/user_token_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,11 @@ func buildRequestToRefreshTokens(utf *userTokenFlow) (*http.Request, error) {
return nil, err
}

idpClientID, err := getIDPClientID()
if err != nil {
return nil, err
}

req, err := http.NewRequest(
http.MethodPost,
fmt.Sprintf("%s/token", idpEndpoint),
Expand All @@ -171,7 +176,7 @@ func buildRequestToRefreshTokens(utf *userTokenFlow) (*http.Request, error) {
}
reqQuery := url.Values{}
reqQuery.Set("grant_type", "refresh_token")
reqQuery.Set("client_id", cliClientID)
reqQuery.Set("client_id", idpClientID)
reqQuery.Set("refresh_token", utf.refreshToken)
reqQuery.Set("token_format", "jwt")
req.URL.RawQuery = reqQuery.Encode()
Expand Down
11 changes: 11 additions & 0 deletions internal/pkg/auth/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,14 @@ func getIDPEndpoint() (string, error) {

return idpEndpoint, nil
}

func getIDPClientID() (string, error) {
idpClientID := defaultCLIClientID

customIDPClientID := viper.GetString(config.IdentityProviderCustomClientIdKey)
if customIDPClientID != "" {
idpClientID = customIDPClientID
}

return idpClientID, nil
}
41 changes: 41 additions & 0 deletions internal/pkg/auth/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,44 @@ func TestGetIDPEndpoint(t *testing.T) {
})
}
}

func TestGetIDPClientID(t *testing.T) {
tests := []struct {
name string
idpCustomClientID string
isValid bool
expected string
}{
{
name: "custom client ID specified",
idpCustomClientID: "custom-client-id",
isValid: true,
expected: "custom-client-id",
},
{
name: "custom client ID not specified",
idpCustomClientID: "",
isValid: true,
expected: defaultCLIClientID,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
viper.Reset()
viper.Set(config.IdentityProviderCustomClientIdKey, tt.idpCustomClientID)

got, err := getIDPClientID()

if tt.isValid && err != nil {
t.Fatalf("expected no error, got %v", err)
}
if !tt.isValid && err == nil {
t.Fatalf("expected error, got none")
}

if got != tt.expected {
t.Fatalf("expected idp client ID %q, got %q", tt.expected, got)
}
})
}
}
3 changes: 3 additions & 0 deletions internal/pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ const (
VerbosityKey = "verbosity"

IdentityProviderCustomEndpointKey = "identity_provider_custom_endpoint"
IdentityProviderCustomClientIdKey = "identity_provider_custom_client_id"

ArgusCustomEndpointKey = "argus_custom_endpoint"
AuthorizationCustomEndpointKey = "authorization_custom_endpoint"
Expand Down Expand Up @@ -67,6 +68,7 @@ var ConfigKeys = []string{
VerbosityKey,

IdentityProviderCustomEndpointKey,
IdentityProviderCustomClientIdKey,

DNSCustomEndpointKey,
LoadBalancerCustomEndpointKey,
Expand Down Expand Up @@ -148,6 +150,7 @@ func setConfigDefaults() {
viper.SetDefault(ProjectIdKey, "")
viper.SetDefault(SessionTimeLimitKey, SessionTimeLimitDefault)
viper.SetDefault(IdentityProviderCustomEndpointKey, "")
viper.SetDefault(IdentityProviderCustomClientIdKey, "")
viper.SetDefault(DNSCustomEndpointKey, "")
viper.SetDefault(ArgusCustomEndpointKey, "")
viper.SetDefault(AuthorizationCustomEndpointKey, "")
Expand Down