@@ -3,6 +3,7 @@ package ssh
3
3
4
4
import (
5
5
"fmt"
6
+ "reflect"
6
7
7
8
"gopkg.in/src-d/go-git.v4/plumbing/transport"
8
9
"gopkg.in/src-d/go-git.v4/plumbing/transport/internal/common"
@@ -11,7 +12,12 @@ import (
11
12
)
12
13
13
14
// DefaultClient is the default SSH client.
14
- var DefaultClient = common .NewClient (& runner {})
15
+ var DefaultClient = NewClient (nil )
16
+
17
+ // NewClient creates a new SSH client with an optional *ssh.ClientConfig.
18
+ func NewClient (config * ssh.ClientConfig ) transport.Transport {
19
+ return common .NewClient (& runner {config : config })
20
+ }
15
21
16
22
// DefaultAuthBuilder is the function used to create a default AuthMethod, when
17
23
// the user doesn't provide any.
@@ -21,10 +27,12 @@ var DefaultAuthBuilder = func(user string) (AuthMethod, error) {
21
27
22
28
const DefaultPort = 22
23
29
24
- type runner struct {}
30
+ type runner struct {
31
+ config * ssh.ClientConfig
32
+ }
25
33
26
34
func (r * runner ) Command (cmd string , ep transport.Endpoint , auth transport.AuthMethod ) (common.Command , error ) {
27
- c := & command {command : cmd , endpoint : ep }
35
+ c := & command {command : cmd , endpoint : ep , config : r . config }
28
36
if auth != nil {
29
37
c .setAuth (auth )
30
38
}
@@ -42,6 +50,7 @@ type command struct {
42
50
endpoint transport.Endpoint
43
51
client * ssh.Client
44
52
auth AuthMethod
53
+ config * ssh.ClientConfig
45
54
}
46
55
47
56
func (c * command ) setAuth (auth transport.AuthMethod ) error {
@@ -95,6 +104,8 @@ func (c *command) connect() error {
95
104
return err
96
105
}
97
106
107
+ overrideConfig (c .config , config )
108
+
98
109
c .client , err = ssh .Dial ("tcp" , c .getHostWithPort (), config )
99
110
if err != nil {
100
111
return err
@@ -129,3 +140,25 @@ func (c *command) setAuthFromEndpoint() error {
129
140
func endpointToCommand (cmd string , ep transport.Endpoint ) string {
130
141
return fmt .Sprintf ("%s '%s'" , cmd , ep .Path ())
131
142
}
143
+
144
+ func overrideConfig (overrides * ssh.ClientConfig , c * ssh.ClientConfig ) {
145
+ if overrides == nil {
146
+ return
147
+ }
148
+
149
+ vo := reflect .ValueOf (* overrides )
150
+ vc := reflect .ValueOf (* c )
151
+ for i := 0 ; i < vc .Type ().NumField (); i ++ {
152
+ vcf := vc .Field (i )
153
+ vof := vo .Field (i )
154
+ if isZeroValue (vcf ) {
155
+ vcf .Set (vof )
156
+ }
157
+ }
158
+
159
+ * c = vc .Interface ().(ssh.ClientConfig )
160
+ }
161
+
162
+ func isZeroValue (v reflect.Value ) bool {
163
+ return reflect .DeepEqual (v .Interface (), reflect .Zero (v .Type ()).Interface ())
164
+ }
0 commit comments