Skip to content

Commit dd0f723

Browse files
vttablet: harden ExecuteHook RPC and backup engine flag inputs (#19486)
Signed-off-by: Tim Vaillancourt <tim@timvaillancourt.com>
1 parent 0400ddf commit dd0f723

5 files changed

Lines changed: 58 additions & 18 deletions

File tree

go/vt/hook/hook.go

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,12 @@ import (
2323
"io"
2424
"os"
2525
"os/exec"
26-
"path"
26+
"path/filepath"
2727
"strings"
2828
"syscall"
2929
"time"
3030

31+
"vitess.io/vitess/go/fileutil"
3132
vtenv "vitess.io/vitess/go/vt/env"
3233
"vitess.io/vitess/go/vt/log"
3334
)
@@ -98,19 +99,17 @@ func NewHookWithEnv(name string, params []string, env map[string]string) *Hook {
9899

99100
// findHook tries to locate the hook, and returns the exec.Cmd for it.
100101
func (hook *Hook) findHook(ctx context.Context) (*exec.Cmd, int, error) {
101-
// Check the hook path.
102-
if strings.Contains(hook.Name, "/") {
103-
return nil, HOOK_INVALID_NAME, errors.New("hook cannot contain '/'")
104-
}
105-
106102
// Find our root.
107103
root, err := vtenv.VtRoot()
108104
if err != nil {
109105
return nil, HOOK_VTROOT_ERROR, fmt.Errorf("cannot get VTROOT: %v", err)
110106
}
111107

112108
// See if the hook exists.
113-
vthook := path.Join(root, "vthook", hook.Name)
109+
vthook, err := fileutil.SafePathJoin(filepath.Join(root, "vthook"), hook.Name)
110+
if err != nil {
111+
return nil, HOOK_INVALID_NAME, fmt.Errorf("invalid hook name %q: %v", hook.Name, err)
112+
}
114113
_, err = os.Stat(vthook)
115114
if err != nil {
116115
if os.IsNotExist(err) {

go/vt/mysqlctl/mysqlshellbackupengine.go

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package mysqlctl
1818

1919
import (
2020
"bufio"
21+
"bytes"
2122
"context"
2223
"encoding/json"
2324
"errors"
@@ -32,6 +33,7 @@ import (
3233
"sync"
3334
"time"
3435

36+
"github.com/google/shlex"
3537
"github.com/spf13/pflag"
3638

3739
"vitess.io/vitess/go/fileutil"
@@ -41,6 +43,7 @@ import (
4143
"vitess.io/vitess/go/vt/log"
4244
"vitess.io/vitess/go/vt/mysqlctl/backupstorage"
4345
tabletmanagerdatapb "vitess.io/vitess/go/vt/proto/tabletmanagerdata"
46+
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
4447
"vitess.io/vitess/go/vt/servenv"
4548
"vitess.io/vitess/go/vt/vterrors"
4649
)
@@ -140,14 +143,20 @@ func (be *MySQLShellBackupEngine) ExecuteBackup(ctx context.Context, params Back
140143
return BackupUnusable, vterrors.Wrap(err, "can't get MySQL version")
141144
}
142145

143-
args := []string{}
144-
if mysqlShellFlags != "" {
145-
args = append(args, strings.Fields(mysqlShellFlags)...)
146+
args, err := shlex.Split(mysqlShellFlags)
147+
if err != nil {
148+
return BackupUnusable, vterrors.Wrap(err, "failed to parse --mysql-shell-flags")
149+
}
150+
151+
// compact and validate the json input from mysqlShellDumpFlags.
152+
var compactDumpFlags bytes.Buffer
153+
if err := json.Compact(&compactDumpFlags, []byte(mysqlShellDumpFlags)); err != nil {
154+
return BackupUnusable, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "failed to parse --mysql-shell-dump-flags as JSON: %v", err)
146155
}
147156

148157
args = append(args, "-e", fmt.Sprintf("util.dumpInstance(%q, %s)",
149158
location,
150-
mysqlShellDumpFlags,
159+
compactDumpFlags.String(),
151160
))
152161

153162
// to be able to get the consistent GTID sets, we will acquire a global read lock before starting mysql shell.
@@ -362,15 +371,20 @@ func (be *MySQLShellBackupEngine) ExecuteRestore(ctx context.Context, params Res
362371
}
363372
defer resetFunc()
364373

365-
args := []string{}
374+
args, err := shlex.Split(mysqlShellFlags)
375+
if err != nil {
376+
return nil, vterrors.Wrap(err, "failed to parse --mysql-shell-flags")
377+
}
366378

367-
if mysqlShellFlags != "" {
368-
args = append(args, strings.Fields(mysqlShellFlags)...)
379+
// compact and validate the json input from mysqlShellLoadFlags.
380+
var compactLoadFlags bytes.Buffer
381+
if err := json.Compact(&compactLoadFlags, []byte(mysqlShellLoadFlags)); err != nil {
382+
return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "failed to parse --mysql-shell-load-flags as JSON: %v", err)
369383
}
370384

371385
args = append(args, "-e", fmt.Sprintf("util.loadDump(%q, %s)",
372386
location,
373-
mysqlShellLoadFlags,
387+
compactLoadFlags.String(),
374388
))
375389

376390
cmd := exec.CommandContext(ctx, "mysqlsh", args...)

go/vt/mysqlctl/xtrabackupengine.go

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import (
3030
"sync"
3131
"time"
3232

33+
"github.com/google/shlex"
3334
"github.com/spf13/pflag"
3435

3536
"vitess.io/vitess/go/ioutil"
@@ -318,7 +319,11 @@ func (be *XtrabackupEngine) backupFiles(
318319
flagsToExec = append(flagsToExec, "--stream="+xtrabackupStreamMode)
319320
}
320321
if xtrabackupBackupFlags != "" {
321-
flagsToExec = append(flagsToExec, strings.Fields(xtrabackupBackupFlags)...)
322+
backupFlags, err := shlex.Split(xtrabackupBackupFlags)
323+
if err != nil {
324+
return replicationPosition, vterrors.Wrap(err, "failed to parse --xtrabackup-backup-flags")
325+
}
326+
flagsToExec = append(flagsToExec, backupFlags...)
322327
}
323328

324329
// Create a cancellable Context for calls to bh.AddFile().
@@ -545,7 +550,11 @@ func (be *XtrabackupEngine) restoreFromBackup(ctx context.Context, cnf *Mycnf, b
545550
"--target-dir=" + tempDir,
546551
}
547552
if xtrabackupPrepareFlags != "" {
548-
flagsToExec = append(flagsToExec, strings.Fields(xtrabackupPrepareFlags)...)
553+
prepareFlags, err := shlex.Split(xtrabackupPrepareFlags)
554+
if err != nil {
555+
return vterrors.Wrap(err, "failed to parse --xtrabackup-prepare-flags")
556+
}
557+
flagsToExec = append(flagsToExec, prepareFlags...)
549558
}
550559
prepareCmd := exec.CommandContext(ctx, restoreProgram, flagsToExec...)
551560
prepareOut, err := prepareCmd.StdoutPipe()
@@ -719,7 +728,11 @@ func (be *XtrabackupEngine) extractFiles(ctx context.Context, logger logutil.Log
719728
xbstreamProgram := path.Join(xtrabackupEnginePath, xbstream)
720729
flagsToExec := []string{"-C", tempDir, "-xv"}
721730
if xbstreamRestoreFlags != "" {
722-
flagsToExec = append(flagsToExec, strings.Fields(xbstreamRestoreFlags)...)
731+
restoreFlags, err := shlex.Split(xbstreamRestoreFlags)
732+
if err != nil {
733+
return vterrors.Wrap(err, "failed to parse --xbstream-restore-flags")
734+
}
735+
flagsToExec = append(flagsToExec, restoreFlags...)
723736
}
724737
xbstreamCmd := exec.CommandContext(ctx, xbstreamProgram, flagsToExec...)
725738
logger.Infof("Executing xbstream cmd: %v %v", xbstreamProgram, flagsToExec)

go/vt/vttablet/grpctmserver/server.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package grpctmserver
1818

1919
import (
2020
"context"
21+
"path/filepath"
2122
"time"
2223

2324
"google.golang.org/grpc"
@@ -35,6 +36,7 @@ import (
3536
querypb "vitess.io/vitess/go/vt/proto/query"
3637
tabletmanagerdatapb "vitess.io/vitess/go/vt/proto/tabletmanagerdata"
3738
tabletmanagerservicepb "vitess.io/vitess/go/vt/proto/tabletmanagerservice"
39+
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
3840
)
3941

4042
// server is the gRPC implementation of the RPC server
@@ -64,6 +66,9 @@ func (s *server) Sleep(ctx context.Context, request *tabletmanagerdatapb.SleepRe
6466
func (s *server) ExecuteHook(ctx context.Context, request *tabletmanagerdatapb.ExecuteHookRequest) (response *tabletmanagerdatapb.ExecuteHookResponse, err error) {
6567
defer s.tm.HandleRPCPanic(ctx, "ExecuteHook", request, response, true /*verbose*/, &err)
6668
ctx = callinfo.GRPCCallInfo(ctx)
69+
if request.Name == "" || filepath.Base(request.Name) != request.Name {
70+
return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "hook name must be a basename, got %q", request.Name)
71+
}
6772
response = &tabletmanagerdatapb.ExecuteHookResponse{}
6873
hr := s.tm.ExecuteHook(ctx, &hook.Hook{
6974
Name: request.Name,

go/vt/vttablet/tmrpctest/test_tm_rpc.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727
"testing"
2828
"time"
2929

30+
"github.com/stretchr/testify/assert"
3031
"github.com/stretchr/testify/require"
3132
"google.golang.org/protobuf/proto"
3233

@@ -594,6 +595,13 @@ func tmRPCTestExecuteHook(ctx context.Context, t *testing.T, client tmclient.Tab
594595
compareError(t, "ExecuteHook", err, hr, testExecuteHookHookResult)
595596
}
596597

598+
func tmRPCTestExecuteHookInvalidName(ctx context.Context, t *testing.T, client tmclient.TabletManagerClient, tablet *topodatapb.Tablet) {
599+
for _, name := range []string{"", "../etc/passwd", "/bin/ls"} {
600+
_, err := client.ExecuteHook(ctx, tablet, &hook.Hook{Name: name})
601+
assert.ErrorContains(t, err, "hook name must be a basename")
602+
}
603+
}
604+
597605
func tmRPCTestExecuteHookPanic(ctx context.Context, t *testing.T, client tmclient.TabletManagerClient, tablet *topodatapb.Tablet) {
598606
_, err := client.ExecuteHook(ctx, tablet, testExecuteHookHook)
599607
expectHandleRPCPanic(t, "ExecuteHook", true /*verbose*/, err)
@@ -1597,6 +1605,7 @@ func Run(t *testing.T, client tmclient.TabletManagerClient, tablet *topodatapb.T
15971605
tmRPCTestChangeType(ctx, t, client, tablet)
15981606
tmRPCTestSleep(ctx, t, client, tablet)
15991607
tmRPCTestExecuteHook(ctx, t, client, tablet)
1608+
tmRPCTestExecuteHookInvalidName(ctx, t, client, tablet)
16001609
tmRPCTestRefreshState(ctx, t, client, tablet)
16011610
tmRPCTestRunHealthCheck(ctx, t, client, tablet)
16021611
tmRPCTestReloadSchema(ctx, t, client, tablet)

0 commit comments

Comments
 (0)