Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 52 additions & 4 deletions internal/cmdopts/cmdsource.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,23 @@
"context"
"errors"
"fmt"
"net/url"

"github.com/cybertec-postgresql/pgwatch/v3/internal/sources"
)

type SourceCommand struct {
owner *Options
Ping SourcePingCommand `command:"ping" description:"Try to connect to configured sources, report errors if any and then exit"`
owner *Options
Ping SourcePingCommand `command:"ping" description:"Try to connect to configured sources, report errors if any and then exit"`
Resolve SourceResolveCommand `command:"resolve" description:"Resolve monitored connections for a given sources (all by default)"`
// PrintSQL SourcePrintCommand `command:"print" description:"Get and print SQL for a given Source"`
}

func NewSourceCommand(owner *Options) *SourceCommand {
return &SourceCommand{
owner: owner,
Ping: SourcePingCommand{owner: owner},
owner: owner,
Ping: SourcePingCommand{owner: owner},
Resolve: SourceResolveCommand{owner: owner},
}
}

Expand Down Expand Up @@ -70,3 +73,48 @@
cmd.owner.CompleteCommand(map[bool]int32{true: ExitCodeCmdError, false: ExitCodeOK}[err != nil])
return nil
}

type SourceResolveCommand struct {
owner *Options
}

func (cmd *SourceResolveCommand) Execute(args []string) error {
err := cmd.owner.InitSourceReader(context.Background())
if err != nil {
return err
}
srcs, err := cmd.owner.SourcesReaderWriter.GetSources()
if err != nil {
return err
}
var foundSources sources.Sources
if len(args) == 0 {
foundSources = srcs
} else {
for _, name := range args {
for _, s := range srcs {
if s.Name == name {
foundSources = append(foundSources, s)
}
}
}
}
conns, err := foundSources.ResolveDatabases()
if err != nil {
return err
}
var connstr url.URL
connstr.Scheme = "postgresql"
for _, s := range conns {
if s.ConnStr > "" {
fmt.Printf("%s=%s\n", s.Name, s.ConnStr)
} else {
connstr.Host = fmt.Sprintf("%s:%d", s.ConnConfig.ConnConfig.Host, s.ConnConfig.ConnConfig.Port)
connstr.User = url.UserPassword(s.ConnConfig.ConnConfig.User, s.ConnConfig.ConnConfig.Password)
connstr.Path = s.ConnConfig.ConnConfig.Database
fmt.Printf("%s=%s\n", s.Name, connstr.String())
}
}
cmd.owner.CompleteCommand(ExitCodeOK)
return nil
}
72 changes: 72 additions & 0 deletions internal/cmdopts/cmdsource_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
package cmdopts

import (
"context"
"os"
"testing"

"github.com/cybertec-postgresql/pgwatch/v3/internal/db"
"github.com/cybertec-postgresql/pgwatch/v3/internal/sources"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/pashagolub/pgxmock/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand All @@ -17,10 +22,26 @@ func TestSourcePingCommand_Execute(t *testing.T) {
_, err = f.WriteString(`
- name: test1
kind: postgres
is_enabled: true
conn_str: postgresql://foo@bar/baz
- name: test2
kind: postgres-continuous-discovery
is_enabled: true
conn_str: postgresql://foo@bar/baz
- name: test3
kind: patroni-namespace-discovery
is_enabled: true
conn_str: postgresql://foo@bar/baz`)

require.NoError(t, err)

sources.NewConn = func(_ context.Context, _ string, _ ...db.ConnConfigCallback) (db.PgxPoolIface, error) {
return nil, assert.AnError
}
sources.NewConnWithConfig = func(_ context.Context, _ *pgxpool.Config, _ ...db.ConnConfigCallback) (db.PgxPoolIface, error) {
return nil, assert.AnError
}

os.Args = []string{0: "config_test", "--sources=" + f.Name(), "source", "ping"}
_, err = New(nil)
assert.NoError(t, err)
Expand All @@ -29,3 +50,54 @@ func TestSourcePingCommand_Execute(t *testing.T) {
_, err = New(nil)
assert.NoError(t, err)
}

func TestSourceResolveCommand_Execute(t *testing.T) {
f, err := os.CreateTemp(t.TempDir(), "sample.config.yaml")
require.NoError(t, err)
defer f.Close()

_, err = f.WriteString(`
- name: test0
kind: postgres
is_enabled: true
conn_str: postgresql://foo@bar/baz
- name: test1
kind: postgres-continuous-discovery
is_enabled: true
conn_str: postgresql://foo@bar/baz`)

require.NoError(t, err)

mock, err := pgxmock.NewPool()
require.NoError(t, err)
sources.NewConn = func(_ context.Context, _ string, _ ...db.ConnConfigCallback) (db.PgxPoolIface, error) {
return mock, nil
}

t.Run("ResolveSuccess", func(t *testing.T) {
r := func(args []string) {
mock.ExpectQuery("select.+datname").
WithArgs(pgxmock.AnyArg(), pgxmock.AnyArg()).
WillReturnRows(
pgxmock.NewRows([]string{"datname"}).
AddRow("foo"))
os.Args = args
_, err = New(nil)
assert.NoError(t, err)
assert.NoError(t, mock.ExpectationsWereMet())
}
r([]string{0: "config_test", "--sources=" + f.Name(), "source", "resolve"})
r([]string{0: "config_test", "--sources=" + f.Name(), "source", "resolve", "test1"})
})

t.Run("ResolveError", func(t *testing.T) {
mock.ExpectQuery("select.+datname").
WithArgs(pgxmock.AnyArg(), pgxmock.AnyArg()).
WillReturnError(assert.AnError)
os.Args = []string{0: "config_test", "--sources=" + f.Name(), "source", "resolve"}
_, err = New(nil)
assert.ErrorIs(t, err, assert.AnError)
assert.NoError(t, mock.ExpectationsWereMet())
})

}
11 changes: 7 additions & 4 deletions internal/sources/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ import (
"github.com/jackc/pgx/v5/pgxpool"
)

// NewConn and NewConnWithConfig are wrappers to allow testing
var (
NewConn = db.New
NewConnWithConfig = db.NewWithConfig
)

// SourceConn represents a single connection to monitor. Unlike source, it contains a database connection.
// Continuous discovery sources (postgres-continuous-discovery, patroni-continuous-discovery, patroni-namespace-discovery)
// will produce multiple monitored databases structs based on the discovered databases.
Expand All @@ -32,9 +38,6 @@ func (md *SourceConn) Ping(ctx context.Context) (err error) {
return md.Conn.Ping(ctx)
}

// NewWithConfig is a function that creates a new connection pool with the given config.
var NewWithConfig = db.NewWithConfig

// Connect will establish a connection to the database if it's not already connected.
// If the connection is already established, it pings the server to ensure it's still alive.
func (md *SourceConn) Connect(ctx context.Context, opts CmdOpts) (err error) {
Expand All @@ -48,7 +51,7 @@ func (md *SourceConn) Connect(ctx context.Context, opts CmdOpts) (err error) {
if opts.MaxParallelConnectionsPerDb > 0 {
md.ConnConfig.MaxConns = int32(opts.MaxParallelConnectionsPerDb)
}
md.Conn, err = NewWithConfig(ctx, md.ConnConfig)
md.Conn, err = NewConnWithConfig(ctx, md.ConnConfig)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions internal/sources/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func TestSourceConn_Connect(t *testing.T) {

t.Run("failed connection", func(t *testing.T) {
md := &sources.SourceConn{}
sources.NewWithConfig = func(_ context.Context, _ *pgxpool.Config, _ ...db.ConnConfigCallback) (db.PgxPoolIface, error) {
sources.NewConnWithConfig = func(_ context.Context, _ *pgxpool.Config, _ ...db.ConnConfigCallback) (db.PgxPoolIface, error) {
return nil, assert.AnError
}
err := md.Connect(ctx, sources.CmdOpts{})
Expand All @@ -37,7 +37,7 @@ func TestSourceConn_Connect(t *testing.T) {
t.Run("successful connection to pgbouncer", func(t *testing.T) {
mock, err := pgxmock.NewPool()
require.NoError(t, err)
sources.NewWithConfig = func(_ context.Context, _ *pgxpool.Config, _ ...db.ConnConfigCallback) (db.PgxPoolIface, error) {
sources.NewConnWithConfig = func(_ context.Context, _ *pgxpool.Config, _ ...db.ConnConfigCallback) (db.PgxPoolIface, error) {
return mock, nil
}

Expand Down
4 changes: 2 additions & 2 deletions internal/sources/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (srcs Sources) ResolveDatabases() (_ SourceConns, err error) {
err = errors.Join(err, e)
resolvedDbs = append(resolvedDbs, dbs...)
}
return resolvedDbs, nil
return resolvedDbs, err
}

// ResolveDatabases() return a slice of found databases for continuous monitoring sources, e.g. patroni
Expand Down Expand Up @@ -352,7 +352,7 @@ func ResolveDatabasesFromPostgres(s Source) (resolvedDbs SourceConns, err error)
dbname string
rows pgx.Rows
)
c, err = db.New(context.TODO(), s.ConnStr)
c, err = NewConn(context.TODO(), s.ConnStr)
if err != nil {
return
}
Expand Down
Loading