From d14dcf0435a70f899d082c4d5ecdfef3fd80083b Mon Sep 17 00:00:00 2001 From: Ben Einaudi Date: Sun, 16 Feb 2020 21:43:42 +0100 Subject: [PATCH] Add options to skip tls verification or add registry certificate * 'skip-tls-verify-registry ' will skip tls verification for given registry name * 'registry-certificate =' will give certificate for the given registry. This might be usefull for self-signed certificates Fixes #326 --- boilerplate/boilerplate.py | 4 +- cmd/root.go | 50 +++++++++++++---- cmd/root_test.go | 29 ++++++++++ pkg/util/image_utils.go | 3 +- pkg/util/transport_builder.go | 101 ++++++++++++++++++++++++++++++++++ 5 files changed, 173 insertions(+), 14 deletions(-) create mode 100644 pkg/util/transport_builder.go diff --git a/boilerplate/boilerplate.py b/boilerplate/boilerplate.py index 2fce9323..bcb36939 100755 --- a/boilerplate/boilerplate.py +++ b/boilerplate/boilerplate.py @@ -149,8 +149,8 @@ def get_regexs(): regexs = {} # Search for "YEAR" which exists in the boilerplate, but shouldn't in the real thing regexs["year"] = re.compile( 'YEAR' ) - # dates can be 2014, 2015, 2016, 2017, or 2018, company holder names can be anything - regexs["date"] = re.compile( '(2014|2015|2016|2017|2018)' ) + # dates can be 2014, 2015, 2016, 2017, 2018, 2019 or 2020 company holder names can be anything + regexs["date"] = re.compile( '(2014|2015|2016|2017|2018|2019|2020)' ) # strip // +build \n\n build constraints regexs["go_build_constraints"] = re.compile(r"^(// \+build.*\n)+\n", re.MULTILINE) # strip #!.* from shell scripts diff --git a/cmd/root.go b/cmd/root.go index b320f2d0..939f2f0a 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -38,7 +38,7 @@ import ( var json bool var save bool -var types diffTypes +var types multiValueFlag var noCache bool var outputFile string @@ -46,6 +46,8 @@ var forceWrite bool var cacheDir string var LogLevel string var format string +var skipTsVerifyRegistries multiValueFlag +var registriesCertificates keyValueFlag const containerDiffEnvCacheDir = "CONTAINER_DIFF_CACHEDIR" @@ -69,6 +71,7 @@ Tarballs can also be specified by simply providing the path to the .tar, .tar.gz os.Exit(1) } logrus.SetLevel(ll) + pkgutil.ConfigureTLS(skipTsVerifyRegistries, registriesCertificates) }, } @@ -147,6 +150,7 @@ func getImage(imageName string) (pkgutil.Image, error) { return pkgutil.Image{}, err } } + return pkgutil.GetImage(imageName, includeLayers(), cachePath) } @@ -193,33 +197,59 @@ func getWriter(outputFile string) (io.Writer, error) { func init() { RootCmd.PersistentFlags().StringVarP(&LogLevel, "verbosity", "v", "warning", "This flag controls the verbosity of container-diff.") RootCmd.PersistentFlags().StringVarP(&format, "format", "", "", "Format to output diff in.") + RootCmd.PersistentFlags().VarP(&skipTsVerifyRegistries, "skip-tls-verify-registry", "", "Insecure registry ignoring TLS verify to push and pull. Set it repeatedly for multiple registries.") + registriesCertificates = make(keyValueFlag) + RootCmd.PersistentFlags().VarP(®istriesCertificates, "registry-certificate", "", "Use the provided certificate for TLS communication with the given registry. Expected format is 'my.registry=/path/to/the/server/certificate'.") pflag.CommandLine.AddGoFlagSet(goflag.CommandLine) } -// Define a type named "diffSlice" as a slice of strings -type diffTypes []string +// Define a type named "multiValueFlag" as a slice of strings +type multiValueFlag []string // Now, for our new type, implement the two methods of // the flag.Value interface... // The first method is String() string -func (d *diffTypes) String() string { - return strings.Join(*d, ",") +func (f *multiValueFlag) String() string { + return strings.Join(*f, ",") } // The second method is Set(value string) error -func (d *diffTypes) Set(value string) error { +func (f *multiValueFlag) Set(value string) error { // Dedupe repeated elements. - for _, t := range *d { + for _, t := range *f { if t == value { return nil } } - *d = append(*d, value) + *f = append(*f, value) + return nil +} + +func (f *multiValueFlag) Type() string { + return "multiValueFlag" +} + +type keyValueFlag map[string]string + +func (f *keyValueFlag) String() string { + var result []string + for key, value := range *f { + result = append(result, fmt.Sprintf("%s=%s", key, value)) + } + return strings.Join(result, ",") +} + +func (f *keyValueFlag) Set(value string) error { + parts := strings.SplitN(value, "=", 2) + if len(parts) < 2 { + return fmt.Errorf("invalid argument value. expect key=value, got %s", value) + } + (*f)[parts[0]] = parts[1] return nil } -func (d *diffTypes) Type() string { - return "Diff Types" +func (f *keyValueFlag) Type() string { + return "keyValueFlag" } func addSharedFlags(cmd *cobra.Command) { diff --git a/cmd/root_test.go b/cmd/root_test.go index 5ffa9a18..4e97d703 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -20,6 +20,7 @@ import ( "os" "path" "path/filepath" + "reflect" "testing" homedir "github.com/mitchellh/go-homedir" @@ -94,3 +95,31 @@ func TestCacheDir(t *testing.T) { ) } } + +func TestMultiValueFlag_Set_shouldDedupeRepeatedArguments(t *testing.T) { + var arg multiValueFlag + arg.Set("value1") + arg.Set("value2") + arg.Set("value3") + + arg.Set("value2") + if len(arg) != 3 || reflect.DeepEqual(arg, []string{"value1", "value2", "value3"}) { + t.Error("multiValueFlag should dedupe repeated arguments") + } +} + +func Test_KeyValueArg_Set_shouldSplitArgument(t *testing.T) { + arg := make(keyValueFlag) + arg.Set("key=value") + if arg["key"] != "value" { + t.Error("Invalid split. key=value should be split to key=>value") + } +} + +func Test_KeyValueArg_Set_shouldAcceptEqualAsValue(t *testing.T) { + arg := make(keyValueFlag) + arg.Set("key=value=something") + if arg["key"] != "value=something" { + t.Error("Invalid split. key=value=something should be split to key=>value=something") + } +} diff --git a/pkg/util/image_utils.go b/pkg/util/image_utils.go index c3b38676..78146823 100644 --- a/pkg/util/image_utils.go +++ b/pkg/util/image_utils.go @@ -21,7 +21,6 @@ import ( "fmt" "io" "io/ioutil" - "net/http" "os" "path/filepath" "regexp" @@ -116,7 +115,7 @@ func GetImage(imageName string, includeLayers bool, cacheDir string) (Image, err return Image{}, errors.Wrap(err, "resolving auth") } start := time.Now() - img, err = remote.Image(ref, remote.WithAuth(auth), remote.WithTransport(http.DefaultTransport)) + img, err = remote.Image(ref, remote.WithAuth(auth), remote.WithTransport(BuildTransport(ref.Context().Registry))) if err != nil { return Image{}, errors.Wrap(err, "retrieving remote image") } diff --git a/pkg/util/transport_builder.go b/pkg/util/transport_builder.go new file mode 100644 index 00000000..5e6e760a --- /dev/null +++ b/pkg/util/transport_builder.go @@ -0,0 +1,101 @@ +/* +Copyright 2020 Google, Inc. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package util + +import ( + "crypto/tls" + "crypto/x509" + "github.com/sirupsen/logrus" + "io/ioutil" + "net" + "net/http" + "time" + + . "github.com/google/go-containerregistry/pkg/name" +) + +var tlsConfiguration = struct { + certifiedRegistries map[string]string + skipTLSVerifyRegistries map[string]struct{} +}{ + certifiedRegistries: make(map[string]string), + skipTLSVerifyRegistries: make(map[string]struct{}), +} + +func ConfigureTLS(skipTsVerifyRegistries []string, registriesToCertificates map[string]string) { + tlsConfiguration.skipTLSVerifyRegistries = make(map[string]struct{}) + for _, registry := range skipTsVerifyRegistries { + tlsConfiguration.skipTLSVerifyRegistries[registry] = struct{}{} + } + tlsConfiguration.certifiedRegistries = make(map[string]string) + for registry := range registriesToCertificates { + tlsConfiguration.certifiedRegistries[registry] = registriesToCertificates[registry] + } +} + +func BuildTransport(registry Registry) http.RoundTripper { + var tr http.RoundTripper = newTransport() + if _, present := tlsConfiguration.skipTLSVerifyRegistries[registry.RegistryStr()]; present { + tr.(*http.Transport).TLSClientConfig = &tls.Config{ + InsecureSkipVerify: true, + } + } else if certificatePath := tlsConfiguration.certifiedRegistries[registry.RegistryStr()]; certificatePath != "" { + systemCertPool := defaultX509Handler() + if err := appendCertificate(systemCertPool, certificatePath); err != nil { + logrus.WithError(err).Warnf("Failed to load certificate %s for %s\n", certificatePath, registry.RegistryStr()) + } else { + tr.(*http.Transport).TLSClientConfig = &tls.Config{ + RootCAs: systemCertPool, + } + } + } + return tr +} + +// TODO replace it with "http.DefaultTransport.(*http.Transport).Clone()" once in golang 1.12 +func newTransport() http.RoundTripper { + return &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } +} + +func appendCertificate(pool *x509.CertPool, path string) error { + pem, err := ioutil.ReadFile(path) + if err != nil { + return err + } + pool.AppendCertsFromPEM(pem) + return nil +} + +func defaultX509Handler() *x509.CertPool { + systemCertPool, err := x509.SystemCertPool() + if err != nil { + logrus.Warn("Failed to load system cert pool. Loading empty one instead.") + systemCertPool = x509.NewCertPool() + } + return systemCertPool +}