Skip to content

Commit 0a071b5

Browse files
committed
Add support for more finegrained restriction on who is allowed to open tunnels
1 parent db96428 commit 0a071b5

File tree

7 files changed

+198
-62
lines changed

7 files changed

+198
-62
lines changed

Makefile

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ check: .check-fmt .check-vet .check-lint .check-ineffassign .check-static .check
5151

5252
.PHONY: .check-vendor
5353
.check-vendor:
54-
@dep ensure -no-vendor -dry-run
54+
@go mod vendor
5555

5656
.PHONY: test
5757
test:
@@ -61,12 +61,11 @@ test:
6161
.PHONY: get-deps
6262
get-deps:
6363
@echo "==> Installing dependencies..."
64-
@dep ensure
64+
@go mod init
6565

6666
.PHONY: get-tools
6767
get-tools:
6868
@echo "==> Installing tools..."
69-
@go get -u github.com/golang/dep/cmd/dep
7069
@go get -u golang.org/x/lint/golint
7170
@go get -u github.com/golang/mock/gomock
7271

client.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ func (c *Client) dial() (net.Conn, error) {
177177
if err == nil {
178178
c.logger.Log(
179179
"level", 1,
180-
"msg", fmt.Sprintf("Setting up keep alive using config: %v", c.config.KeepAlive.String()),
180+
"msg", fmt.Sprintf("setting up keep alive using config: %v", c.config.KeepAlive.String()),
181181
)
182182
err = c.config.KeepAlive.Set(conn)
183183
}

cmd/tunneld/tunneld.go

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,14 @@ import (
1212
"net"
1313
"net/http"
1414
"os"
15-
"strings"
1615
"time"
1716

1817
"golang.org/x/net/http2"
1918

2019
"github.com/bep/debounce"
2120
tunnel "github.com/hons82/go-http-tunnel"
2221
"github.com/hons82/go-http-tunnel/connection"
23-
"github.com/hons82/go-http-tunnel/id"
22+
"github.com/hons82/go-http-tunnel/fileutil"
2423
"github.com/hons82/go-http-tunnel/log"
2524
)
2625

@@ -71,16 +70,15 @@ func main() {
7170
}
7271

7372
if !autoSubscribe {
74-
for _, c := range strings.Split(opts.clients, ",") {
75-
if c == "" {
76-
fatal("empty client id")
77-
}
78-
identifier := id.ID{}
79-
err := identifier.UnmarshalText([]byte(c))
80-
if err != nil {
81-
fatal("invalid identifier %q: %s", c, err)
73+
clients, err := fileutil.ReadPropertiesFile(opts.clients)
74+
if err != nil {
75+
fatal("failed to load clients: %s", err)
76+
}
77+
78+
for host, value := range clients {
79+
if err := server.RegisterTunnel(host, value); err != nil {
80+
fatal("failed to load tunnel: %s with error %s", host, err)
8281
}
83-
server.Subscribe(identifier)
8482
}
8583
}
8684

@@ -103,6 +101,9 @@ func main() {
103101
// start HTTP
104102
if opts.httpAddr != "" {
105103
go func() {
104+
s := &http.Server{
105+
Addr: opts.httpAddr,
106+
}
106107
if opts.httpsAddr != "" {
107108
logger.Log(
108109
"level", 1,
@@ -114,29 +115,27 @@ func main() {
114115
if err != nil {
115116
fatal("failed to get https port: %s", err)
116117
}
117-
fatal("failed to start HTTP: %s",
118-
http.ListenAndServe(opts.httpAddr, http.HandlerFunc(
119-
func(w http.ResponseWriter, r *http.Request) {
120-
host, _, err := net.SplitHostPort(r.Host)
121-
if err != nil {
122-
host = r.Host
123-
}
124-
u := r.URL
125-
u.Host = net.JoinHostPort(host, tlsPort)
126-
u.Scheme = "https"
127-
http.Redirect(w, r, u.String(), http.StatusMovedPermanently)
128-
},
129-
)),
118+
s.Handler = http.HandlerFunc(
119+
func(w http.ResponseWriter, r *http.Request) {
120+
host, _, err := net.SplitHostPort(r.Host)
121+
if err != nil {
122+
host = r.Host
123+
}
124+
u := r.URL
125+
u.Host = net.JoinHostPort(host, tlsPort)
126+
u.Scheme = "https"
127+
http.Redirect(w, r, u.String(), http.StatusMovedPermanently)
128+
},
130129
)
131130
} else {
132131
logger.Log(
133132
"level", 1,
134133
"action", "start http",
135134
"addr", opts.httpAddr,
136135
)
137-
138-
fatal("failed to start HTTP: %s", http.ListenAndServe(opts.httpAddr, server))
136+
s.Handler = server
139137
}
138+
fatal("failed to start HTTP: %s", s.ListenAndServe())
140139
}()
141140
}
142141

fileutil/file.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package fileutil
2+
3+
import (
4+
"bufio"
5+
"log"
6+
"os"
7+
"strings"
8+
)
9+
10+
type AppConfigProperties map[string]string
11+
12+
func ReadPropertiesFile(filename string) (AppConfigProperties, error) {
13+
config := AppConfigProperties{}
14+
15+
if len(filename) == 0 {
16+
return config, nil
17+
}
18+
file, err := os.Open(filename)
19+
if err != nil {
20+
log.Fatal(err)
21+
return nil, err
22+
}
23+
defer file.Close()
24+
25+
scanner := bufio.NewScanner(file)
26+
for scanner.Scan() {
27+
line := scanner.Text()
28+
if equal := strings.Index(line, "="); equal >= 0 {
29+
if key := strings.TrimSpace(line[:equal]); len(key) > 0 {
30+
value := ""
31+
if len(line) > equal {
32+
value = strings.TrimSpace(line[equal+1:])
33+
}
34+
config[key] = value
35+
}
36+
}
37+
}
38+
39+
if err := scanner.Err(); err != nil {
40+
log.Fatal(err)
41+
return nil, err
42+
}
43+
44+
return config, nil
45+
}

id/ptls.go

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,34 +6,49 @@ package id
66

77
import (
88
"crypto/tls"
9+
"crypto/x509"
910
"fmt"
1011
)
1112

13+
type IDInfo struct {
14+
Client string
15+
}
16+
1217
var emptyID [32]byte
18+
var emptyIDInfo = &IDInfo{}
1319

1420
// PeerID is modified https://github.com/andrew-d/ptls/blob/b89c7dcc94630a77f225a48befd3710144c7c10e/ptls.go#L81
15-
func PeerID(conn *tls.Conn) (ID, error) {
21+
func PeerID(conn *tls.Conn) (ID, IDInfo, error) {
1622
// Try a TLS connection over the given connection. We explicitly perform
1723
// the handshake, since we want to maintain the invariant that, if this
1824
// function returns successfully, then the connection should be valid
1925
// and verified.
2026
if err := conn.Handshake(); err != nil {
21-
return emptyID, err
27+
return emptyID, *emptyIDInfo, err
2228
}
2329

2430
cs := conn.ConnectionState()
2531

2632
// We should have exactly one peer certificate.
2733
certs := cs.PeerCertificates
2834
if cl := len(certs); cl != 1 {
29-
return emptyID, ImproperCertsNumberError{cl}
35+
return emptyID, *emptyIDInfo, ImproperCertsNumberError{cl}
3036
}
3137

3238
// Get remote cert's ID.
3339
remoteCert := certs[0]
34-
remoteID := New(remoteCert.Raw)
40+
remoteID := New(remoteID(*remoteCert))
41+
remoteIDInfo := &IDInfo{
42+
Client: remoteCert.Issuer.SerialNumber,
43+
}
44+
return remoteID, *remoteIDInfo, nil
45+
}
3546

36-
return remoteID, nil
47+
func remoteID(c x509.Certificate) []byte {
48+
if c.Issuer.SerialNumber != "" {
49+
return []byte(c.Issuer.SerialNumber)
50+
}
51+
return c.Raw
3752
}
3853

3954
// ImproperCertsNumberError is returned from Server/Client whenever the remote

registry.go

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
// RegistryItem holds information about hosts and listeners associated with a
1717
// client.
1818
type RegistryItem struct {
19+
*id.IDInfo
1920
Hosts []*HostAuth
2021
Listeners []net.Listener
2122
}
@@ -27,6 +28,7 @@ type HostAuth struct {
2728
}
2829

2930
type hostInfo struct {
31+
*id.IDInfo
3032
identifier id.ID
3133
auth *Auth
3234
}
@@ -91,6 +93,15 @@ func (r *registry) Subscriber(hostPort string) (id.ID, *Auth, bool) {
9193
return h.identifier, h.auth, ok
9294
}
9395

96+
func (r *registry) HasTunnel(hostPort string, identifier id.ID) bool {
97+
r.mu.RLock()
98+
defer r.mu.RUnlock()
99+
100+
h, ok := r.hosts[trimPort(hostPort)]
101+
102+
return ok && h.identifier.Equals(identifier)
103+
}
104+
94105
// Unsubscribe removes client from registry and returns it's RegistryItem.
95106
func (r *registry) Unsubscribe(identifier id.ID) *RegistryItem {
96107
r.mu.Lock()
@@ -141,7 +152,7 @@ func (r *registry) set(i *RegistryItem, identifier id.ID) error {
141152
if h.Auth != nil && h.Auth.User == "" {
142153
return fmt.Errorf("missing auth user")
143154
}
144-
if _, ok := r.hosts[trimPort(h.Host)]; ok {
155+
if hi, ok := r.hosts[trimPort(h.Host)]; ok && !hi.identifier.Equals(identifier) {
145156
return fmt.Errorf("host %q is occupied", h.Host)
146157
}
147158
}
@@ -159,6 +170,35 @@ func (r *registry) set(i *RegistryItem, identifier id.ID) error {
159170
return nil
160171
}
161172

173+
func (r *registry) RegisterTunnel(host string, client string) error {
174+
identifier := id.New([]byte(client))
175+
176+
r.logger.Log(
177+
"level", 2,
178+
"action", "add tunnel",
179+
"host", host,
180+
"identifier", identifier,
181+
)
182+
183+
r.Subscribe(identifier)
184+
185+
r.mu.Lock()
186+
defer r.mu.Unlock()
187+
188+
if _, ok := r.hosts[trimPort(host)]; ok {
189+
return fmt.Errorf("host %q is occupied", host)
190+
}
191+
192+
r.hosts[trimPort(host)] = &hostInfo{
193+
identifier: identifier,
194+
IDInfo: &id.IDInfo{
195+
Client: client,
196+
},
197+
}
198+
199+
return nil
200+
}
201+
162202
func (r *registry) clear(identifier id.ID) *RegistryItem {
163203
r.logger.Log(
164204
"level", 2,
@@ -174,12 +214,6 @@ func (r *registry) clear(identifier id.ID) *RegistryItem {
174214
return nil
175215
}
176216

177-
if i.Hosts != nil {
178-
for _, h := range i.Hosts {
179-
delete(r.hosts, trimPort(h.Host))
180-
}
181-
}
182-
183217
r.items[identifier] = voidRegistryItem
184218

185219
return i

0 commit comments

Comments
 (0)