@@ -23,6 +23,7 @@ import (
2323 "github.com/stretchr/testify/assert"
2424 "github.com/stretchr/testify/require"
2525
26+ "vitess.io/vitess/go/fileutil"
2627 tabletmanagerdatapb "vitess.io/vitess/go/vt/proto/tabletmanagerdata"
2728)
2829
@@ -83,7 +84,7 @@ func TestFileEntryFullPath(t *testing.T) {
8384 name string
8485 entry FileEntry
8586 wantPath string
86- wantError string
87+ wantError error
8788 }{
8889 {
8990 name : "valid relative path in DataDir" ,
@@ -113,41 +114,43 @@ func TestFileEntryFullPath(t *testing.T) {
113114 {
114115 name : "path traversal escapes base directory" ,
115116 entry : FileEntry {Base : backupData , Name : "../../etc/passwd" },
116- wantError : "path traversal not allowed" ,
117+ wantError : fileutil . ErrInvalidJoinedPath ,
117118 },
118119 {
119120 name : "path traversal with deeper nesting" ,
120121 entry : FileEntry {Base : backupData , Name : "mydb/../../../etc/shadow" },
121- wantError : "path traversal not allowed" ,
122+ wantError : fileutil . ErrInvalidJoinedPath ,
122123 },
123124 {
124125 name : "path traversal to root" ,
125126 entry : FileEntry {Base : backupData , Name : "../../../../../etc/crontab" },
126- wantError : "path traversal not allowed" ,
127+ wantError : fileutil . ErrInvalidJoinedPath ,
127128 },
128129 {
129130 name : "path traversal escapes ParentPath" ,
130131 entry : FileEntry {Base : backupData , Name : "../../../../etc/passwd" , ParentPath : "/tmp/restore" },
131- wantError : "path traversal not allowed" ,
132+ wantError : fileutil . ErrInvalidJoinedPath ,
132133 },
133134 {
134135 name : "relative path with dot-dot that stays within base" ,
135136 entry : FileEntry {Base : backupData , Name : "mydb/../mydb/table1.ibd" },
136137 wantPath : "/vt/data/mydb/table1.ibd" ,
137138 },
138- {
139- name : "unknown base" ,
140- entry : FileEntry {Base : "unknown" , Name : "file" },
141- wantError : "unknown base" ,
142- },
143139 }
144140
141+ // Test unknown base separately since it returns a different error type.
142+ t .Run ("unknown base" , func (t * testing.T ) {
143+ entry := FileEntry {Base : "unknown" , Name : "file" }
144+ _ , err := entry .fullPath (cnf )
145+ require .Error (t , err )
146+ assert .Contains (t , err .Error (), "unknown base" )
147+ })
148+
145149 for _ , tt := range tests {
146150 t .Run (tt .name , func (t * testing.T ) {
147151 got , err := tt .entry .fullPath (cnf )
148- if tt .wantError != "" {
149- require .Error (t , err )
150- assert .Contains (t , err .Error (), tt .wantError )
152+ if tt .wantError != nil {
153+ require .ErrorIs (t , err , tt .wantError )
151154 } else {
152155 require .NoError (t , err )
153156 assert .Equal (t , tt .wantPath , got )
0 commit comments