diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 62a5b0b6e..b35dd546e 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -56,18 +56,24 @@ func NewRootCmd(version, date string, p *print.Printer) *cobra.Command { configFilePath := viper.ConfigFileUsed() p.Debug(print.DebugLevel, "configuration is persisted and read from: %s", configFilePath) - activeProfile, err := config.GetProfile() + profileSet, activeProfile, configMethod, err := config.GetConfiguredProfile() if err != nil { - return fmt.Errorf("get profile: %w", err) + return fmt.Errorf("get configured profile: %w", err) } - profileExists, err := config.ProfileExists(activeProfile) - if err != nil { - return fmt.Errorf("check if profile exists: %w", err) - } - if !profileExists { - p.Warn("active profile does not exist, the default profile configuration will be used\n") + p.Debug(print.DebugLevel, "read configuration profile %q via %s", profileSet, configMethod) + + if activeProfile != profileSet { + if configMethod == "" { + p.Debug(print.DebugLevel, "no profile is configured in env var or profile file") + } else { + p.Debug(print.DebugLevel, "the configured profile %q does not exist: folder %q is missing", profileSet, config.GetProfileFolderPath(profileSet)) + } + p.Debug(print.DebugLevel, "the %q profile will be used", activeProfile) + + p.Warn("configured profile %q does not exist, the %q profile configuration will be used\n", profileSet, activeProfile) } + p.Debug(print.DebugLevel, "active configuration profile: %s", activeProfile) configKeys := viper.AllSettings() diff --git a/internal/pkg/auth/storage.go b/internal/pkg/auth/storage.go index 73bef1189..fb511dbdb 100644 --- a/internal/pkg/auth/storage.go +++ b/internal/pkg/auth/storage.go @@ -88,18 +88,7 @@ func setAuthFieldInEncodedTextFile(activeProfile string, key authFieldKey, value if err != nil { return err } - - configDir, err := os.UserConfigDir() - if err != nil { - return fmt.Errorf("get config dir: %w", err) - } - - profileTextFileFolderName := textFileFolderName - if activeProfile != "" { - profileTextFileFolderName = filepath.Join(textFileFolderName, activeProfile) - } - - textFileDir := filepath.Join(configDir, profileTextFileFolderName) + textFileDir := config.GetProfileFolderPath(activeProfile) textFilePath := filepath.Join(textFileDir, textFileName) contentEncoded, err := os.ReadFile(textFilePath) @@ -178,17 +167,7 @@ func getAuthFieldFromEncodedTextFile(activeProfile string, key authFieldKey) (st return "", err } - configDir, err := os.UserConfigDir() - if err != nil { - return "", fmt.Errorf("get config dir: %w", err) - } - - profileTextFileFolderName := textFileFolderName - if activeProfile != "" { - profileTextFileFolderName = filepath.Join(textFileFolderName, activeProfile) - } - - textFileDir := filepath.Join(configDir, profileTextFileFolderName) + textFileDir := config.GetProfileFolderPath(activeProfile) textFilePath := filepath.Join(textFileDir, textFileName) contentEncoded, err := os.ReadFile(textFilePath) @@ -215,20 +194,10 @@ func getAuthFieldFromEncodedTextFile(activeProfile string, key authFieldKey) (st // If it doesn't, creates it with the content "{}" encoded. // If it does, does nothing (and returns nil). func createEncodedTextFile(activeProfile string) error { - configDir, err := os.UserConfigDir() - if err != nil { - return fmt.Errorf("get config dir: %w", err) - } - - profileTextFileFolderName := textFileFolderName - if activeProfile != "" { - profileTextFileFolderName = filepath.Join(textFileFolderName, activeProfile) - } - - textFileDir := filepath.Join(configDir, profileTextFileFolderName) + textFileDir := config.GetProfileFolderPath(activeProfile) textFilePath := filepath.Join(textFileDir, textFileName) - err = os.MkdirAll(textFileDir, os.ModePerm) + err := os.MkdirAll(textFileDir, os.ModePerm) if err != nil { return fmt.Errorf("create file dir: %w", err) } diff --git a/internal/pkg/auth/storage_test.go b/internal/pkg/auth/storage_test.go index 4d1593020..f8e6b4dd6 100644 --- a/internal/pkg/auth/storage_test.go +++ b/internal/pkg/auth/storage_test.go @@ -12,6 +12,7 @@ import ( "github.com/zalando/go-keyring" "github.com/stackitcloud/stackit-cli/internal/pkg/config" + "github.com/stackitcloud/stackit-cli/internal/pkg/print" ) func TestSetGetAuthField(t *testing.T) { @@ -183,7 +184,7 @@ func TestSetGetAuthFieldKeyring(t *testing.T) { }{ { description: "simple assignments with default profile", - activeProfile: "", + activeProfile: config.DefaultProfileName, valueAssignments: []valueAssignment{ { key: testField1, @@ -201,7 +202,7 @@ func TestSetGetAuthFieldKeyring(t *testing.T) { }, { description: "overlapping assignments with default profile", - activeProfile: "", + activeProfile: config.DefaultProfileName, valueAssignments: []valueAssignment{ { key: testField1, @@ -267,6 +268,12 @@ func TestSetGetAuthFieldKeyring(t *testing.T) { t.Run(tt.description, func(t *testing.T) { keyring.MockInit() + // Make sure profile name is valid + err := config.ValidateProfile(tt.activeProfile) + if err != nil { + t.Fatalf("Profile name \"%s\" is invalid: %v", tt.activeProfile, err) + } + for _, assignment := range tt.valueAssignments { err := setAuthFieldInKeyring(tt.activeProfile, assignment.key, assignment.value) if err != nil { @@ -317,7 +324,7 @@ func TestSetGetAuthFieldEncodedTextFile(t *testing.T) { }{ { description: "simple assignments with default profile", - activeProfile: "", + activeProfile: config.DefaultProfileName, valueAssignments: []valueAssignment{ { key: testField1, @@ -335,7 +342,7 @@ func TestSetGetAuthFieldEncodedTextFile(t *testing.T) { }, { description: "overlapping assignments with default profile", - activeProfile: "", + activeProfile: config.DefaultProfileName, valueAssignments: []valueAssignment{ { key: testField1, @@ -399,6 +406,26 @@ func TestSetGetAuthFieldEncodedTextFile(t *testing.T) { for _, tt := range tests { t.Run(tt.description, func(t *testing.T) { + // Make sure profile name is valid + err := config.ValidateProfile(tt.activeProfile) + if err != nil { + t.Fatalf("Profile name \"%s\" is invalid: %v", tt.activeProfile, err) + } + + // Create profile if it does not exist + // Will be deleted at the end of the test + profileExists, err := config.ProfileExists(tt.activeProfile) + if err != nil { + t.Fatalf("Failed to check if profile exists: %v", err) + } + if !profileExists { + p := print.NewPrinter() + err := config.CreateProfile(p, tt.activeProfile, true, true) + if err != nil { + t.Fatalf("Failed to create profile: %v", err) + } + } + for _, assignment := range tt.valueAssignments { err := setAuthFieldInEncodedTextFile(tt.activeProfile, assignment.key, assignment.value) if err != nil { @@ -424,6 +451,11 @@ func TestSetGetAuthFieldEncodedTextFile(t *testing.T) { t.Errorf("Post-test cleanup failed: remove field \"%s\" from text file: %v. Please remove it manually", key, err) } } + + err = deleteAuthFieldProfile(tt.activeProfile, profileExists) + if err != nil { + t.Errorf("Post-test cleanup failed: remove profile \"%s\": %v. Please remove it manually", tt.activeProfile, err) + } }) } } @@ -443,17 +475,7 @@ func deleteAuthFieldInEncodedTextFile(activeProfile string, key authFieldKey) er return err } - configDir, err := os.UserConfigDir() - if err != nil { - return fmt.Errorf("get config dir: %w", err) - } - - profileTextFileFolderName := textFileFolderName - if activeProfile != "" { - profileTextFileFolderName = filepath.Join(textFileFolderName, activeProfile) - } - - textFileDir := filepath.Join(configDir, profileTextFileFolderName) + textFileDir := config.GetProfileFolderPath(activeProfile) textFilePath := filepath.Join(textFileDir, textFileName) contentEncoded, err := os.ReadFile(textFilePath) @@ -483,3 +505,15 @@ func deleteAuthFieldInEncodedTextFile(activeProfile string, key authFieldKey) er } return nil } + +func deleteAuthFieldProfile(activeProfile string, profileExisted bool) error { + textFileDir := config.GetProfileFolderPath(activeProfile) + if !profileExisted { + // Remove the entire directory if the profile does not exist + err := os.RemoveAll(textFileDir) + if err != nil { + return fmt.Errorf("remove directory: %w", err) + } + } + return nil +} diff --git a/internal/pkg/config/config.go b/internal/pkg/config/config.go index b6ab86ea3..eb1ed73a2 100644 --- a/internal/pkg/config/config.go +++ b/internal/pkg/config/config.go @@ -34,15 +34,15 @@ const ( ServiceAccountCustomEndpointKey = "service_account_custom_endpoint" SKECustomEndpointKey = "ske_custom_endpoint" - ProjectNameKey = "project_name" + ProjectNameKey = "project_name" + DefaultProfileName = "default" AsyncDefault = false SessionTimeLimitDefault = "2h" ) const ( - configFolder = "stackit" - defaultProfileName = "default" + configFolder = "stackit" configFileName = "cli-config" configFileExtension = "json" @@ -83,21 +83,15 @@ var configFolderPath string var profileFilePath string func InitConfig() { - configDir, err := os.UserConfigDir() - cobra.CheckErr(err) - - defaultConfigFolderPath = filepath.Join(configDir, configFolder) - profileFilePath = filepath.Join(defaultConfigFolderPath, fmt.Sprintf("%s.%s", profileFileName, profileFileExtension)) // Profile file path is in the default config folder + defaultConfigFolderPath = getInitialConfigDir() + profileFilePath = getInitialProfileFilePath() // Profile file path is in the default config folder configProfile, err := GetProfile() cobra.CheckErr(err) - configFolderPath = defaultConfigFolderPath - if configProfile != defaultProfileName { - configFolderPath = filepath.Join(configFolderPath, profileRootFolder, configProfile) // If a profile is set, use the profile config folder - } + configFolderPath = GetProfileFolderPath(configProfile) - configFilePath := filepath.Join(configFolderPath, fmt.Sprintf("%s.%s", configFileName, configFileExtension)) + configFilePath := getConfigFilePath(configFolderPath) // This hack is required to allow creating the config file with `viper.WriteConfig` // see https://github.com/spf13/viper/issues/851#issuecomment-789393451 @@ -151,3 +145,22 @@ func setConfigDefaults() { viper.SetDefault(ServiceAccountCustomEndpointKey, "") viper.SetDefault(SKECustomEndpointKey, "") } + +func getConfigFilePath(configFolder string) string { + return filepath.Join(configFolder, fmt.Sprintf("%s.%s", configFileName, configFileExtension)) +} + +func getInitialConfigDir() string { + configDir, err := os.UserConfigDir() + cobra.CheckErr(err) + + return filepath.Join(configDir, configFolder) +} + +func getInitialProfileFilePath() string { + configFolderPath := defaultConfigFolderPath + if configFolderPath == "" { + configFolderPath = getInitialConfigDir() + } + return filepath.Join(configFolderPath, fmt.Sprintf("%s.%s", profileFileName, profileFileExtension)) +} diff --git a/internal/pkg/config/config_test.go b/internal/pkg/config/config_test.go index b3483a635..7c9954117 100644 --- a/internal/pkg/config/config_test.go +++ b/internal/pkg/config/config_test.go @@ -1,6 +1,7 @@ package config import ( + "fmt" "os" "path/filepath" "testing" @@ -69,3 +70,58 @@ func TestWrite(t *testing.T) { }) } } + +func TestGetInitialConfigDir(t *testing.T) { + tests := []struct { + description string + }{ + { + description: "base", + }, + } + + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + actual := getInitialConfigDir() + + userConfig, err := os.UserConfigDir() + if err != nil { + t.Fatalf("expected error to be nil, got %v", err) + } + + expected := filepath.Join(userConfig, "stackit") + if actual != expected { + t.Fatalf("expected %s, got %s", expected, actual) + } + }) + } +} + +func TestGetInitialProfileFilePath(t *testing.T) { + tests := []struct { + description string + configFolderPath string + }{ + { + description: "base", + configFolderPath: getInitialConfigDir(), + }, + { + description: "empty config folder path", + configFolderPath: "", + }, + } + + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + configFolderPath = getInitialConfigDir() + + actual := getInitialProfileFilePath() + + expected := filepath.Join(configFolderPath, fmt.Sprintf("%s.%s", profileFileName, profileFileExtension)) + if actual != expected { + t.Fatalf("expected %s, got %s", expected, actual) + } + }) + } +} diff --git a/internal/pkg/config/profiles.go b/internal/pkg/config/profiles.go index 9723c3862..c96a42dc8 100644 --- a/internal/pkg/config/profiles.go +++ b/internal/pkg/config/profiles.go @@ -14,40 +14,58 @@ import ( const ProfileEnvVar = "STACKIT_CLI_PROFILE" // GetProfile returns the current profile to be used by the CLI. -// // The profile is determined by the value of the STACKIT_CLI_PROFILE environment variable, or, if not set, // by the contents of the profile file in the CLI config folder. -// // If the profile is not set (env var or profile file) or is set but does not exist, it falls back to the default profile. -// // If the profile is not valid, it returns an error. func GetProfile() (string, error) { + _, profile, _, err := GetConfiguredProfile() + if err != nil { + return "", err + } + + return profile, nil +} + +// GetConfiguredProfile returns the profile configured by the user, the profile to be used by the CLI and the method used to configure the profile. +// The profile is determined by the value of the STACKIT_CLI_PROFILE environment variable, or, if not set, +// by the contents of the profile file in the CLI config folder. +// If the configured profile is not set (env var or profile file) or is set but does not exist, it falls back to the default profile. +// The configuration method can be environment variable, profile file or empty if profile is not configured. +// If the profile is not valid, it returns an error. +func GetConfiguredProfile() (configuredProfile, activeProfile, configurationMethod string, err error) { + var configMethod string profile, profileSetInEnv := GetProfileFromEnv() if !profileSetInEnv { contents, exists, err := fileutils.ReadFileIfExists(profileFilePath) if err != nil { - return "", fmt.Errorf("read profile from file: %w", err) + return "", "", "", fmt.Errorf("read profile from file: %w", err) } if !exists { - return defaultProfileName, nil + // No profile set in env or file + return DefaultProfileName, DefaultProfileName, "", nil } profile = contents + configMethod = "profile file" + } else { + configMethod = "environment variable" } // Make sure the profile exists profileExists, err := ProfileExists(profile) if err != nil { - return "", fmt.Errorf("check if profile exists: %w", err) + return "", "", "", fmt.Errorf("check if profile exists: %w", err) } if !profileExists { - return defaultProfileName, nil + // Profile is configured but does not exist + return profile, DefaultProfileName, configMethod, nil } err = ValidateProfile(profile) if err != nil { - return "", fmt.Errorf("validate profile: %w", err) + return "", "", "", fmt.Errorf("validate profile: %w", err) } - return profile, nil + return profile, profile, configMethod, nil } // GetProfileFromEnv returns the profile from the environment variable. @@ -68,13 +86,13 @@ func CreateProfile(p *print.Printer, profile string, setProfile, emptyProfile bo } // Cannot create a profile with the default name - if profile == defaultProfileName { + if profile == DefaultProfileName { return &errors.InvalidProfileNameError{ Profile: profile, } } - configFolderPath = filepath.Join(defaultConfigFolderPath, profileRootFolder, profile) + configFolderPath = GetProfileFolderPath(profile) // Error if the profile already exists _, err = os.Stat(configFolderPath) @@ -124,20 +142,25 @@ func CreateProfile(p *print.Printer, profile string, setProfile, emptyProfile bo // DuplicateProfileConfiguration duplicates the current profile configuration to a new profile. // It copies the config file from the current profile to the new profile. -// If the current profile does not exist, it returns an error. +// If the current profile does not exist, it does nothing. // If the new profile already exists, it will be overwritten. func DuplicateProfileConfiguration(p *print.Printer, currentProfile, newProfile string) error { - var currentConfigFilePath string + currentProfileFolder := GetProfileFolderPath(currentProfile) + currentConfigFilePath := getConfigFilePath(currentProfileFolder) - if currentProfile == defaultProfileName { - currentConfigFilePath = filepath.Join(defaultConfigFolderPath, fmt.Sprintf("%s.%s", configFileName, configFileExtension)) - } else { - currentConfigFilePath = filepath.Join(defaultConfigFolderPath, profileRootFolder, currentProfile, fmt.Sprintf("%s.%s", configFileName, configFileExtension)) - } + newConfigFilePath := getConfigFilePath(configFolderPath) - newConfigFilePath := filepath.Join(configFolderPath, fmt.Sprintf("%s.%s", configFileName, configFileExtension)) + // If the source profile configuration does not exist, do nothing + _, err := os.Stat(currentConfigFilePath) + if err != nil { + if os.IsNotExist(err) { + p.Debug(print.DebugLevel, "current profile %q has no configuration, nothing to duplicate", currentProfile) + return nil + } + return fmt.Errorf("get current profile configuration: %w", err) + } - err := fileutils.CopyFile(currentConfigFilePath, newConfigFilePath) + err = fileutils.CopyFile(currentConfigFilePath, newConfigFilePath) if err != nil { return fmt.Errorf("copy config file: %w", err) } @@ -163,13 +186,17 @@ func SetProfile(p *print.Printer, profile string) error { return &errors.SetInexistentProfile{Profile: profile} } + if profileFilePath == "" { + profileFilePath = getInitialProfileFilePath() + } + err = os.WriteFile(profileFilePath, []byte(profile), os.ModePerm) if err != nil { return fmt.Errorf("write profile to file: %w", err) } p.Debug(print.DebugLevel, "persisted new active profile in: %s", profileFilePath) - configFolderPath = filepath.Join(defaultConfigFolderPath, profile) + configFolderPath = GetProfileFolderPath(profile) p.Debug(print.DebugLevel, "profile %q is now active", profile) return nil @@ -203,7 +230,7 @@ func ValidateProfile(profile string) error { } func ProfileExists(profile string) (bool, error) { - _, err := os.Stat(filepath.Join(defaultConfigFolderPath, profileRootFolder, profile)) + _, err := os.Stat(GetProfileFolderPath(profile)) if err != nil { if os.IsNotExist(err) { return false, nil @@ -212,3 +239,16 @@ func ProfileExists(profile string) (bool, error) { } return true, nil } + +// GetProfileFolderPath returns the path to the folder where the profile configuration is stored. +// If the profile is the default profile, it returns the default config folder path. +func GetProfileFolderPath(profile string) string { + if defaultConfigFolderPath == "" { + defaultConfigFolderPath = getInitialConfigDir() + } + + if profile == DefaultProfileName { + return defaultConfigFolderPath + } + return filepath.Join(defaultConfigFolderPath, profileRootFolder, profile) +} diff --git a/internal/pkg/config/profiles_test.go b/internal/pkg/config/profiles_test.go index e97451aeb..1250a16aa 100644 --- a/internal/pkg/config/profiles_test.go +++ b/internal/pkg/config/profiles_test.go @@ -1,6 +1,9 @@ package config -import "testing" +import ( + "path/filepath" + "testing" +) func TestValidateProfile(t *testing.T) { tests := []struct { @@ -57,3 +60,48 @@ func TestValidateProfile(t *testing.T) { }) } } + +func TestGetProfileFolderPath(t *testing.T) { + tests := []struct { + description string + defaultConfigFolderNotSet bool + profile string + expected string + }{ + { + description: "default profile", + profile: DefaultProfileName, + expected: getInitialConfigDir(), + }, + { + description: "default profile, default config folder not set", + defaultConfigFolderNotSet: true, + profile: DefaultProfileName, + expected: getInitialConfigDir(), + }, + { + description: "custom profile", + profile: "my-profile", + expected: filepath.Join(getInitialConfigDir(), profileRootFolder, "my-profile"), + }, + { + description: "custom profile, default config folder not set", + defaultConfigFolderNotSet: true, + profile: "my-profile", + expected: filepath.Join(getInitialConfigDir(), profileRootFolder, "my-profile"), + }, + } + + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + defaultConfigFolderPath = getInitialConfigDir() + if tt.defaultConfigFolderNotSet { + defaultConfigFolderPath = "" + } + actual := GetProfileFolderPath(tt.profile) + if actual != tt.expected { + t.Errorf("expected profile folder path to be %q but got %q", tt.expected, actual) + } + }) + } +}