Skip to content

Commit 71577a5

Browse files
tonistiigicrazy-max
authored andcommitted
source: extract SafeFileName into shared pathutil package
Move safeFileName from source/http to source/util/pathutil and apply it to the containerblob source as well. Harden containerblob/pull.go to use os.OpenRoot for file writes, preventing path traversal via crafted filenames. Signed-off-by: Tonis Tiigi <tonistiigi@gmail.com> (cherry picked from commit 3d6e587655d72c343f6fdc7268480a900ba45b0c) (cherry picked from commit 45a6358b084f95ca715376935759c432840bc7bf)
1 parent df43783 commit 71577a5

File tree

6 files changed

+92
-70
lines changed

6 files changed

+92
-70
lines changed

source/containerblob/pull.go

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"fmt"
88
"io"
99
"os"
10-
"path/filepath"
1110
"time"
1211

1312
"github.com/containerd/containerd/v2/core/remotes"
@@ -18,6 +17,7 @@ import (
1817
"github.com/moby/buildkit/snapshot"
1918
"github.com/moby/buildkit/solver"
2019
srctypes "github.com/moby/buildkit/source/types"
20+
"github.com/moby/buildkit/source/util/pathutil"
2121
"github.com/moby/buildkit/util/contentutil"
2222
"github.com/moby/buildkit/util/iohelper"
2323
"github.com/moby/buildkit/util/resolver"
@@ -225,10 +225,17 @@ func (p *puller) Snapshot(ctx context.Context, jobCtx solver.JobContext) (ir cac
225225
fn := p.id.Filename
226226
if fn == "" {
227227
fn = p.dgst.Hex()
228+
} else {
229+
fn = pathutil.SafeFileName(fn)
230+
}
231+
232+
root, err := os.OpenRoot(dir)
233+
if err != nil {
234+
return nil, err
228235
}
236+
defer root.Close()
229237

230-
fp := filepath.Join(dir, fn)
231-
f, err := os.OpenFile(fp, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.FileMode(perm))
238+
f, err := root.OpenFile(fn, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.FileMode(perm))
232239
if err != nil {
233240
return nil, err
234241
}
@@ -257,13 +264,13 @@ func (p *puller) Snapshot(ctx context.Context, jobCtx solver.JobContext) (ir cac
257264
}
258265
}
259266
if gid != 0 || uid != 0 {
260-
if err := os.Chown(fp, uid, gid); err != nil {
267+
if err := root.Chown(fn, uid, gid); err != nil {
261268
return nil, err
262269
}
263270
}
264271

265272
mTime := time.Unix(0, 0)
266-
if err := os.Chtimes(fp, mTime, mTime); err != nil {
273+
if err := root.Chtimes(fn, mTime, mTime); err != nil {
267274
return nil, err
268275
}
269276

source/containerblob/source.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"github.com/moby/buildkit/solver/pb"
1414
"github.com/moby/buildkit/source"
1515
srctypes "github.com/moby/buildkit/source/types"
16+
"github.com/moby/buildkit/source/util/pathutil"
1617
"github.com/pkg/errors"
1718
)
1819

@@ -86,7 +87,7 @@ func (is *Source) parseIdentifierAttrs(id *ImageBlobIdentifier, attrs map[string
8687
for k, v := range attrs {
8788
switch k {
8889
case pb.AttrHTTPFilename:
89-
id.Filename = v
90+
id.Filename = pathutil.SafeFileName(v)
9091
case pb.AttrHTTPPerm:
9192
i, err := strconv.ParseInt(v, 0, 64)
9293
if err != nil {

source/http/source.go

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ import (
1818
"strconv"
1919
"strings"
2020
"time"
21-
"unicode"
2221

2322
"github.com/moby/buildkit/cache"
2423
"github.com/moby/buildkit/session"
@@ -28,6 +27,7 @@ import (
2827
"github.com/moby/buildkit/solver/pb"
2928
"github.com/moby/buildkit/source"
3029
srctypes "github.com/moby/buildkit/source/types"
30+
"github.com/moby/buildkit/source/util/pathutil"
3131
"github.com/moby/buildkit/util/bklog"
3232
"github.com/moby/buildkit/util/cachedigest"
3333
"github.com/moby/buildkit/util/pgpsign"
@@ -963,30 +963,16 @@ func (hs *httpSourceHandler) newHTTPRequest(ctx context.Context, g session.Group
963963
return req, nil
964964
}
965965

966-
func safeFileName(s string) string {
967-
defaultName := "download"
968-
name := filepath.Base(filepath.FromSlash(strings.TrimSpace(s)))
969-
if name == "" || name == "." || name == ".." {
970-
return defaultName
971-
}
972-
for _, r := range name {
973-
if r == 0 || unicode.IsControl(r) {
974-
return defaultName
975-
}
976-
}
977-
return name
978-
}
979-
980966
func getFileName(urlStr, manualFilename string, resp *http.Response) string {
981967
if manualFilename != "" {
982-
return safeFileName(manualFilename)
968+
return pathutil.SafeFileName(manualFilename)
983969
}
984970
if resp != nil {
985971
if contentDisposition := resp.Header.Get("Content-Disposition"); contentDisposition != "" {
986972
if _, params, err := mime.ParseMediaType(contentDisposition); err == nil {
987973
if params["filename"] != "" && !strings.HasSuffix(params["filename"], "/") {
988974
if filename := filepath.Base(filepath.FromSlash(params["filename"])); filename != "" {
989-
return safeFileName(filename)
975+
return pathutil.SafeFileName(filename)
990976
}
991977
}
992978
}
@@ -995,10 +981,10 @@ func getFileName(urlStr, manualFilename string, resp *http.Response) string {
995981
u, err := url.Parse(urlStr)
996982
if err == nil {
997983
if base := path.Base(u.Path); base != "." && base != "/" {
998-
return safeFileName(base)
984+
return pathutil.SafeFileName(base)
999985
}
1000986
}
1001-
return safeFileName("")
987+
return pathutil.SafeFileName("")
1002988
}
1003989

1004990
func searchHTTPURLDigest(ctx context.Context, store cache.MetadataStore, dgst digest.Digest) ([]cacheRefMetadata, error) {

source/http/source_test.go

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"context"
55
"os"
66
"path/filepath"
7-
"runtime"
87
"testing"
98

109
"github.com/containerd/containerd/v2/core/diff/apply"
@@ -32,50 +31,6 @@ import (
3231

3332
const signFixturesPathEnv = "BUILDKIT_TEST_SIGN_FIXTURES"
3433

35-
func TestSafeFileName(t *testing.T) {
36-
t.Parallel()
37-
38-
type testCase struct {
39-
name string
40-
in string
41-
want string
42-
}
43-
44-
tests := []testCase{
45-
{name: "simple", in: "foo", want: "foo"},
46-
{name: "simple_ext", in: "foo.txt", want: "foo.txt"},
47-
{name: "unicode_cjk", in: "資料.txt", want: "資料.txt"},
48-
{name: "unicode_cyrillic", in: "тест-файл", want: "тест-файл"},
49-
{name: "spaces_allowed", in: "name with spaces.txt", want: "name with spaces.txt"},
50-
{name: "trim_outer_whitespace", in: " foo.txt ", want: "foo.txt"},
51-
{name: "unix_path", in: "a/b/c.txt", want: "c.txt"},
52-
{name: "empty", in: "", want: "download"},
53-
{name: "dot", in: ".", want: "download"},
54-
{name: "dot_dot", in: "..", want: "download"},
55-
{name: "traversal_unix", in: "../", want: "download"},
56-
{name: "nul_byte", in: "a\x00b", want: "download"},
57-
{name: "control", in: "a\nb", want: "download"},
58-
}
59-
if runtime.GOOS == "windows" {
60-
tests = append(tests,
61-
testCase{name: "windows_traversal", in: "..\\", want: "download"},
62-
testCase{name: "windows_path_basename", in: "a\\b\\c.txt", want: "c.txt"},
63-
)
64-
} else {
65-
tests = append(tests,
66-
testCase{name: "windows_traversal_literal", in: "..\\", want: "..\\"},
67-
testCase{name: "windows_path_literal", in: "a\\b\\c.txt", want: "a\\b\\c.txt"},
68-
)
69-
}
70-
71-
for _, tt := range tests {
72-
t.Run(tt.name, func(t *testing.T) {
73-
t.Parallel()
74-
require.Equal(t, tt.want, safeFileName(tt.in))
75-
})
76-
}
77-
}
78-
7934
func TestHTTPSource(t *testing.T) {
8035
t.Parallel()
8136
ctx := context.TODO()

source/util/pathutil/pathutil.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package pathutil
2+
3+
import (
4+
"path/filepath"
5+
"strings"
6+
"unicode"
7+
)
8+
9+
func SafeFileName(s string) string {
10+
defaultName := "download"
11+
name := filepath.Base(filepath.FromSlash(strings.TrimSpace(s)))
12+
if name == "" || name == "." || name == ".." {
13+
return defaultName
14+
}
15+
for _, r := range name {
16+
if r == 0 || unicode.IsControl(r) {
17+
return defaultName
18+
}
19+
}
20+
return name
21+
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package pathutil
2+
3+
import (
4+
"runtime"
5+
"testing"
6+
7+
"github.com/stretchr/testify/require"
8+
)
9+
10+
func TestSafeFileName(t *testing.T) {
11+
t.Parallel()
12+
13+
type testCase struct {
14+
name string
15+
in string
16+
want string
17+
}
18+
19+
tests := []testCase{
20+
{name: "simple", in: "foo", want: "foo"},
21+
{name: "simple_ext", in: "foo.txt", want: "foo.txt"},
22+
{name: "unicode_cjk", in: "資料.txt", want: "資料.txt"},
23+
{name: "unicode_cyrillic", in: "тест-файл", want: "тест-файл"},
24+
{name: "spaces_allowed", in: "name with spaces.txt", want: "name with spaces.txt"},
25+
{name: "trim_outer_whitespace", in: " foo.txt ", want: "foo.txt"},
26+
{name: "unix_path", in: "a/b/c.txt", want: "c.txt"},
27+
{name: "empty", in: "", want: "download"},
28+
{name: "dot", in: ".", want: "download"},
29+
{name: "dot_dot", in: "..", want: "download"},
30+
{name: "traversal_unix", in: "../", want: "download"},
31+
{name: "nul_byte", in: "a\x00b", want: "download"},
32+
{name: "control", in: "a\nb", want: "download"},
33+
}
34+
if runtime.GOOS == "windows" {
35+
tests = append(tests,
36+
testCase{name: "windows_traversal", in: "..\\", want: "download"},
37+
testCase{name: "windows_path_basename", in: "a\\b\\c.txt", want: "c.txt"},
38+
)
39+
} else {
40+
tests = append(tests,
41+
testCase{name: "windows_traversal_literal", in: "..\\", want: "..\\"},
42+
testCase{name: "windows_path_literal", in: "a\\b\\c.txt", want: "a\\b\\c.txt"},
43+
)
44+
}
45+
46+
for _, tt := range tests {
47+
t.Run(tt.name, func(t *testing.T) {
48+
t.Parallel()
49+
require.Equal(t, tt.want, SafeFileName(tt.in))
50+
})
51+
}
52+
}

0 commit comments

Comments
 (0)