From 65c47e48a50e3ca0a48e441387233348b6e4765d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Go=CC=88kc=CC=A7e=20Go=CC=88k=20Klingel?= Date: Thu, 15 Aug 2024 13:24:45 +0200 Subject: [PATCH 1/2] implement client ID as config option --- docs/stackit_config_set.md | 1 + docs/stackit_config_unset.md | 1 + internal/cmd/config/set/set.go | 4 +++ internal/cmd/config/unset/unset.go | 7 +++++ internal/cmd/config/unset/unset_test.go | 13 ++++++++ internal/pkg/auth/user_login.go | 21 +++++++++++-- internal/pkg/auth/user_token_flow.go | 7 ++++- internal/pkg/auth/utils.go | 11 +++++++ internal/pkg/auth/utils_test.go | 41 +++++++++++++++++++++++++ internal/pkg/config/config.go | 3 ++ 10 files changed, 105 insertions(+), 4 deletions(-) diff --git a/docs/stackit_config_set.md b/docs/stackit_config_set.md index 0ca92d5a9..be39928f8 100644 --- a/docs/stackit_config_set.md +++ b/docs/stackit_config_set.md @@ -34,6 +34,7 @@ stackit config set [flags] --dns-custom-endpoint string DNS API base URL, used in calls to this API -h, --help Help for "stackit config set" --iaas-custom-endpoint string IaaS API base URL, used in calls to this API + --identity-provider-custom-client-id string Identity Provider client ID, used for user authentication --identity-provider-custom-endpoint string Identity Provider base URL, used for user authentication --load-balancer-custom-endpoint string Load Balancer API base URL, used in calls to this API --logme-custom-endpoint string LogMe API base URL, used in calls to this API diff --git a/docs/stackit_config_unset.md b/docs/stackit_config_unset.md index 0a05047a3..11c9c459a 100644 --- a/docs/stackit_config_unset.md +++ b/docs/stackit_config_unset.md @@ -32,6 +32,7 @@ stackit config unset [flags] --dns-custom-endpoint DNS API base URL. If unset, uses the default base URL -h, --help Help for "stackit config unset" --iaas-custom-endpoint IaaS API base URL. If unset, uses the default base URL + --identity-provider-custom-client-id Identity Provider client ID, used for user authentication --identity-provider-custom-endpoint Identity Provider base URL. If unset, uses the default base URL --load-balancer-custom-endpoint Load Balancer API base URL. If unset, uses the default base URL --logme-custom-endpoint LogMe API base URL. If unset, uses the default base URL diff --git a/internal/cmd/config/set/set.go b/internal/cmd/config/set/set.go index e0b9ada9e..e44372455 100644 --- a/internal/cmd/config/set/set.go +++ b/internal/cmd/config/set/set.go @@ -19,6 +19,7 @@ import ( const ( sessionTimeLimitFlag = "session-time-limit" identityProviderCustomEndpointFlag = "identity-provider-custom-endpoint" + identityProviderCustomClientIdFlag = "identity-provider-custom-client-id" argusCustomEndpointFlag = "argus-custom-endpoint" authorizationCustomEndpointFlag = "authorization-custom-endpoint" @@ -129,6 +130,7 @@ Use "{{.CommandPath}} [command] --help" for more information about a command.{{e func configureFlags(cmd *cobra.Command) { cmd.Flags().String(sessionTimeLimitFlag, "", "Maximum time before authentication is required again. After this time, you will be prompted to login again to execute commands that require authentication. Can't be larger than 24h. Requires authentication after being set to take effect. Examples: 3h, 5h30m40s (BETA: currently values greater than 2h have no effect)") cmd.Flags().String(identityProviderCustomEndpointFlag, "", "Identity Provider base URL, used for user authentication") + cmd.Flags().String(identityProviderCustomClientIdFlag, "", "Identity Provider client ID, used for user authentication") cmd.Flags().String(argusCustomEndpointFlag, "", "Argus API base URL, used in calls to this API") cmd.Flags().String(authorizationCustomEndpointFlag, "", "Authorization API base URL, used in calls to this API") cmd.Flags().String(dnsCustomEndpointFlag, "", "DNS API base URL, used in calls to this API") @@ -155,6 +157,8 @@ func configureFlags(cmd *cobra.Command) { cobra.CheckErr(err) err = viper.BindPFlag(config.IdentityProviderCustomEndpointKey, cmd.Flags().Lookup(identityProviderCustomEndpointFlag)) cobra.CheckErr(err) + err = viper.BindPFlag(config.IdentityProviderCustomClientIdKey, cmd.Flags().Lookup(identityProviderCustomClientIdFlag)) + cobra.CheckErr(err) err = viper.BindPFlag(config.ArgusCustomEndpointKey, cmd.Flags().Lookup(argusCustomEndpointFlag)) cobra.CheckErr(err) diff --git a/internal/cmd/config/unset/unset.go b/internal/cmd/config/unset/unset.go index 9c6d95328..a879dd5b0 100644 --- a/internal/cmd/config/unset/unset.go +++ b/internal/cmd/config/unset/unset.go @@ -22,6 +22,7 @@ const ( sessionTimeLimitFlag = "session-time-limit" identityProviderCustomEndpointFlag = "identity-provider-custom-endpoint" + identityProviderCustomClientIdFlag = "identity-provider-custom-client-id" argusCustomEndpointFlag = "argus-custom-endpoint" authorizationCustomEndpointFlag = "authorization-custom-endpoint" @@ -54,6 +55,7 @@ type inputModel struct { SessionTimeLimit bool IdentityProviderCustomEndpoint bool + IdentityProviderCustomClientID bool ArgusCustomEndpoint bool AuthorizationCustomEndpoint bool @@ -117,6 +119,9 @@ func NewCmd(p *print.Printer) *cobra.Command { if model.IdentityProviderCustomEndpoint { viper.Set(config.IdentityProviderCustomEndpointKey, "") } + if model.IdentityProviderCustomClientID { + viper.Set(config.IdentityProviderCustomClientIdKey, "") + } if model.ArgusCustomEndpoint { viper.Set(config.ArgusCustomEndpointKey, "") @@ -201,6 +206,7 @@ func configureFlags(cmd *cobra.Command) { cmd.Flags().Bool(sessionTimeLimitFlag, false, fmt.Sprintf("Maximum time before authentication is required again. If unset, defaults to %s", config.SessionTimeLimitDefault)) cmd.Flags().Bool(identityProviderCustomEndpointFlag, false, "Identity Provider base URL. If unset, uses the default base URL") + cmd.Flags().Bool(identityProviderCustomClientIdFlag, false, "Identity Provider client ID, used for user authentication") cmd.Flags().Bool(argusCustomEndpointFlag, false, "Argus API base URL. If unset, uses the default base URL") cmd.Flags().Bool(authorizationCustomEndpointFlag, false, "Authorization API base URL. If unset, uses the default base URL") @@ -234,6 +240,7 @@ func parseInput(p *print.Printer, cmd *cobra.Command) *inputModel { SessionTimeLimit: flags.FlagToBoolValue(p, cmd, sessionTimeLimitFlag), IdentityProviderCustomEndpoint: flags.FlagToBoolValue(p, cmd, identityProviderCustomEndpointFlag), + IdentityProviderCustomClientID: flags.FlagToBoolValue(p, cmd, identityProviderCustomClientIdFlag), ArgusCustomEndpoint: flags.FlagToBoolValue(p, cmd, argusCustomEndpointFlag), AuthorizationCustomEndpoint: flags.FlagToBoolValue(p, cmd, authorizationCustomEndpointFlag), diff --git a/internal/cmd/config/unset/unset_test.go b/internal/cmd/config/unset/unset_test.go index ad3d75199..9f64136f5 100644 --- a/internal/cmd/config/unset/unset_test.go +++ b/internal/cmd/config/unset/unset_test.go @@ -18,6 +18,7 @@ func fixtureFlagValues(mods ...func(flagValues map[string]bool)) map[string]bool sessionTimeLimitFlag: true, identityProviderCustomEndpointFlag: true, + identityProviderCustomClientIdFlag: true, argusCustomEndpointFlag: true, authorizationCustomEndpointFlag: true, @@ -53,6 +54,7 @@ func fixtureInputModel(mods ...func(model *inputModel)) *inputModel { SessionTimeLimit: true, IdentityProviderCustomEndpoint: true, + IdentityProviderCustomClientID: true, ArgusCustomEndpoint: true, AuthorizationCustomEndpoint: true, @@ -104,6 +106,7 @@ func TestParseInput(t *testing.T) { model.SessionTimeLimit = false model.IdentityProviderCustomEndpoint = false + model.IdentityProviderCustomClientID = false model.ArgusCustomEndpoint = false model.AuthorizationCustomEndpoint = false @@ -155,6 +158,16 @@ func TestParseInput(t *testing.T) { model.IdentityProviderCustomEndpoint = false }), }, + { + description: "identity provider custom client id empty", + flagValues: fixtureFlagValues(func(flagValues map[string]bool) { + flagValues[identityProviderCustomClientIdFlag] = false + }), + isValid: true, + expectedModel: fixtureInputModel(func(model *inputModel) { + model.IdentityProviderCustomClientID = false + }), + }, { description: "argus custom endpoint empty", flagValues: fixtureFlagValues(func(flagValues map[string]bool) { diff --git a/internal/pkg/auth/user_login.go b/internal/pkg/auth/user_login.go index ad519f628..f09b776d8 100644 --- a/internal/pkg/auth/user_login.go +++ b/internal/pkg/auth/user_login.go @@ -24,7 +24,7 @@ import ( const ( defaultIDPEndpoint = "https://accounts.stackit.cloud/oauth/v2" - cliClientID = "stackit-cli-0000-0000-000000000001" + defaultCLIClientID = "stackit-cli-0000-0000-000000000001" loginSuccessPath = "/login-successful" stackitLandingPage = "https://www.stackit.de" @@ -58,6 +58,20 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error { } } + idpClientID, err := getIDPClientID() + if err != nil { + return err + } + if idpClientID != defaultCLIClientID { + p.Warn("You are using a custom client ID (%s) for authentication.\n", idpClientID) + err := p.PromptForEnter("Press Enter to proceed with the login...") + if err != nil { + return err + } + } + + // TODO here implement the client id get client id usage like above + if isReauthentication { err := p.PromptForEnter("Your session has expired, press Enter to login again...") if err != nil { @@ -86,7 +100,7 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error { } conf := &oauth2.Config{ - ClientID: cliClientID, + ClientID: idpClientID, Endpoint: oauth2.Endpoint{ AuthURL: fmt.Sprintf("%s/authorize", idpEndpoint), }, @@ -131,7 +145,7 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error { p.Debug(print.DebugLevel, "trading authorization code for access and refresh tokens") // Trade the authorization code and the code verifier for access and refresh tokens - accessToken, refreshToken, err := getUserAccessAndRefreshTokens(idpEndpoint, cliClientID, codeVerifier, code, redirectURL) + accessToken, refreshToken, err := getUserAccessAndRefreshTokens(idpEndpoint, idpClientID, codeVerifier, code, redirectURL) if err != nil { errServer = fmt.Errorf("retrieve tokens: %w", err) return @@ -207,6 +221,7 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error { p.Debug(print.DebugLevel, "opening browser for authentication") p.Debug(print.DebugLevel, "using authentication server on %s", idpEndpoint) + p.Debug(print.DebugLevel, "using client ID %s for authentication ", idpClientID) // Open a browser window to the authorizationURL err = openBrowser(authorizationURL) diff --git a/internal/pkg/auth/user_token_flow.go b/internal/pkg/auth/user_token_flow.go index a15c33c17..530672882 100644 --- a/internal/pkg/auth/user_token_flow.go +++ b/internal/pkg/auth/user_token_flow.go @@ -161,6 +161,11 @@ func buildRequestToRefreshTokens(utf *userTokenFlow) (*http.Request, error) { return nil, err } + idpClientID, err := getIDPClientID() + if err != nil { + return nil, err + } + req, err := http.NewRequest( http.MethodPost, fmt.Sprintf("%s/token", idpEndpoint), @@ -171,7 +176,7 @@ func buildRequestToRefreshTokens(utf *userTokenFlow) (*http.Request, error) { } reqQuery := url.Values{} reqQuery.Set("grant_type", "refresh_token") - reqQuery.Set("client_id", cliClientID) + reqQuery.Set("client_id", idpClientID) reqQuery.Set("refresh_token", utf.refreshToken) reqQuery.Set("token_format", "jwt") req.URL.RawQuery = reqQuery.Encode() diff --git a/internal/pkg/auth/utils.go b/internal/pkg/auth/utils.go index cdb38c137..a60e25a5c 100644 --- a/internal/pkg/auth/utils.go +++ b/internal/pkg/auth/utils.go @@ -22,3 +22,14 @@ func getIDPEndpoint() (string, error) { return idpEndpoint, nil } + +func getIDPClientID() (string, error) { + idpClientID := defaultCLIClientID + + customIDPClientID := viper.GetString(config.IdentityProviderCustomClientIdKey) + if customIDPClientID != "" { + idpClientID = customIDPClientID + } + + return idpClientID, nil +} diff --git a/internal/pkg/auth/utils_test.go b/internal/pkg/auth/utils_test.go index 375bd8462..c345c4d82 100644 --- a/internal/pkg/auth/utils_test.go +++ b/internal/pkg/auth/utils_test.go @@ -52,3 +52,44 @@ func TestGetIDPEndpoint(t *testing.T) { }) } } + +func TestGetIDPClientID(t *testing.T) { + tests := []struct { + name string + idpCustomClientID string + isValid bool + expected string + }{ + { + name: "custom client ID specified", + idpCustomClientID: "custom-client-id", + isValid: true, + expected: "custom-client-id", + }, + { + name: "custom client ID not specified", + idpCustomClientID: "", + isValid: true, + expected: defaultCLIClientID, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + viper.Reset() + viper.Set(config.IdentityProviderCustomClientIdKey, tt.idpCustomClientID) + + got, err := getIDPClientID() + + if tt.isValid && err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !tt.isValid && err == nil { + t.Fatalf("expected error, got none") + } + + if got != tt.expected { + t.Fatalf("expected idp client ID %q, got %q", tt.expected, got) + } + }) + } +} diff --git a/internal/pkg/config/config.go b/internal/pkg/config/config.go index a898d957d..548827ddb 100644 --- a/internal/pkg/config/config.go +++ b/internal/pkg/config/config.go @@ -18,6 +18,7 @@ const ( VerbosityKey = "verbosity" IdentityProviderCustomEndpointKey = "identity_provider_custom_endpoint" + IdentityProviderCustomClientIdKey = "identity_provider_custom_client_id" ArgusCustomEndpointKey = "argus_custom_endpoint" AuthorizationCustomEndpointKey = "authorization_custom_endpoint" @@ -67,6 +68,7 @@ var ConfigKeys = []string{ VerbosityKey, IdentityProviderCustomEndpointKey, + IdentityProviderCustomClientIdKey, DNSCustomEndpointKey, LoadBalancerCustomEndpointKey, @@ -148,6 +150,7 @@ func setConfigDefaults() { viper.SetDefault(ProjectIdKey, "") viper.SetDefault(SessionTimeLimitKey, SessionTimeLimitDefault) viper.SetDefault(IdentityProviderCustomEndpointKey, "") + viper.SetDefault(IdentityProviderCustomClientIdKey, "") viper.SetDefault(DNSCustomEndpointKey, "") viper.SetDefault(ArgusCustomEndpointKey, "") viper.SetDefault(AuthorizationCustomEndpointKey, "") From 091bde25c934fbc8763b343a593f4dd8f13630cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Go=CC=88kc=CC=A7e=20Go=CC=88k=20Klingel?= Date: Fri, 16 Aug 2024 12:29:08 +0200 Subject: [PATCH 2/2] clean up comment --- internal/pkg/auth/user_login.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/internal/pkg/auth/user_login.go b/internal/pkg/auth/user_login.go index f09b776d8..d2d767fd1 100644 --- a/internal/pkg/auth/user_login.go +++ b/internal/pkg/auth/user_login.go @@ -70,8 +70,6 @@ func AuthorizeUser(p *print.Printer, isReauthentication bool) error { } } - // TODO here implement the client id get client id usage like above - if isReauthentication { err := p.PromptForEnter("Your session has expired, press Enter to login again...") if err != nil {