Skip to content

Commit 123a133

Browse files
feat: [CI-19349]: Added oidc support for azure connector (#496)
* feat: [CI-19349]: Added oidc support for azure connector * feat: [CI-19349]: Added env variables * feat: [CI-19349]: Added tests * Update cmd/drone-acr/main.go * Update cmd/drone-acr/main.go * feat: [CI-19349]: Added Debug statements --------- Co-authored-by: OP (oppenheimer) <[email protected]>
1 parent 58bfad7 commit 123a133

File tree

4 files changed

+243
-9
lines changed

4 files changed

+243
-9
lines changed

cmd/drone-acr/main.go

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"github.com/sirupsen/logrus"
2121

2222
docker "github.com/drone-plugins/drone-docker"
23+
azureutil "github.com/drone-plugins/drone-docker/internal/azure"
2324
)
2425

2526
type subscriptionUrlResponse struct {
@@ -62,12 +63,14 @@ func main() {
6263
password = getenv("SERVICE_PRINCIPAL_CLIENT_SECRET")
6364

6465
// Service principal credentials
65-
clientId = getenv("CLIENT_ID")
66-
clientSecret = getenv("CLIENT_SECRET")
67-
clientCert = getenv("CLIENT_CERTIFICATE")
68-
tenantId = getenv("TENANT_ID")
69-
subscriptionId = getenv("SUBSCRIPTION_ID")
70-
publicUrl = getenv("DAEMON_REGISTRY")
66+
clientId = getenv("CLIENT_ID", "AZURE_CLIENT_ID", "AZURE_APP_ID", "PLUGIN_CLIENT_ID")
67+
clientSecret = getenv("CLIENT_SECRET", "PLUGIN_CLIENT_SECRET")
68+
clientCert = getenv("CLIENT_CERTIFICATE", "PLUGIN_CLIENT_CERTIFICATE")
69+
tenantId = getenv("TENANT_ID", "AZURE_TENANT_ID", "PLUGIN_TENANT_ID")
70+
subscriptionId = getenv("SUBSCRIPTION_ID", "PLUGIN_SUBSCRIPTION_ID")
71+
publicUrl = getenv("DAEMON_REGISTRY", "PLUGIN_DAEMON_REGISTRY")
72+
authorityHost = getenv("AZURE_AUTHORITY_HOST", "PLUGIN_AZURE_AUTHORITY_HOST")
73+
idToken = getenv("PLUGIN_OIDC_TOKEN_ID")
7174
)
7275

7376
// default registry value
@@ -80,9 +83,29 @@ func main() {
8083
// docker login credentials are not provided
8184
var err error
8285
username = defaultUsername
83-
password, publicUrl, err = getAuth(clientId, clientSecret, clientCert, tenantId, subscriptionId, registry)
84-
if err != nil {
85-
logrus.Fatal(err)
86+
if idToken != "" && clientId != "" && tenantId != "" {
87+
logrus.Debug("Using OIDC authentication flow")
88+
var aadToken string
89+
aadToken, err = azureutil.GetAADAccessTokenViaClientAssertion(context.Background(), tenantId, clientId, idToken, authorityHost)
90+
if err != nil {
91+
logrus.Fatal(err)
92+
}
93+
var p string
94+
p, err = getPublicUrl(aadToken, registry, subscriptionId)
95+
if err == nil {
96+
publicUrl = p
97+
} else {
98+
fmt.Fprintf(os.Stderr, "failed to get public url with error: %s\n", err)
99+
}
100+
password, err = fetchACRToken(tenantId, aadToken, registry)
101+
if err != nil {
102+
logrus.Fatal(err)
103+
}
104+
} else {
105+
password, publicUrl, err = getAuth(clientId, clientSecret, clientCert, tenantId, subscriptionId, registry)
106+
if err != nil {
107+
logrus.Fatal(err)
108+
}
86109
}
87110
}
88111

cmd/drone-acr/main_test.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package main
2+
3+
import (
4+
"os"
5+
"testing"
6+
)
7+
8+
func TestGetAuthInputValidation(t *testing.T) {
9+
// missing tenant
10+
if _, _, err := getAuth("client", "secret", "", "", "sub", "registry.azurecr.io"); err == nil {
11+
t.Fatalf("expected error for missing tenantId")
12+
}
13+
// missing clientId
14+
if _, _, err := getAuth("", "secret", "", "tenant", "sub", "registry.azurecr.io"); err == nil {
15+
t.Fatalf("expected error for missing clientId")
16+
}
17+
// missing both secret and cert
18+
if _, _, err := getAuth("client", "", "", "tenant", "sub", "registry.azurecr.io"); err == nil {
19+
t.Fatalf("expected error for missing credentials")
20+
}
21+
}
22+
23+
func TestGetenvAuthorityHost(t *testing.T) {
24+
os.Setenv("AZURE_AUTHORITY_HOST", "https://login.microsoftonline.us")
25+
defer os.Unsetenv("AZURE_AUTHORITY_HOST")
26+
27+
got := getenv("AZURE_AUTHORITY_HOST")
28+
if got != "https://login.microsoftonline.us" {
29+
t.Fatalf("expected AZURE_AUTHORITY_HOST to be returned, got %q", got)
30+
}
31+
}
32+

internal/azure/tokenutil.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package azure
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"net/http"
9+
"net/url"
10+
"strings"
11+
"time"
12+
)
13+
14+
const DefaultResource = "https://management.azure.com/"
15+
const defaultAuthorityHost = "https://login.microsoftonline.com"
16+
const defaultHTTPTimeout = 30 * time.Second
17+
18+
// GetAADAccessTokenViaClientAssertion exchanges an external OIDC ID token for an Azure AD access token
19+
20+
func GetAADAccessTokenViaClientAssertion(ctx context.Context, tenantID, clientID, oidcToken, authorityHost string) (string, error) {
21+
resource := DefaultResource
22+
23+
form := url.Values{
24+
"client_id": {clientID},
25+
"scope": {resource + ".default"},
26+
"grant_type": {"client_credentials"},
27+
"client_assertion_type": {"urn:ietf:params:oauth:client-assertion-type:jwt-bearer"},
28+
"client_assertion": {oidcToken},
29+
}
30+
31+
base := authorityHost
32+
if strings.TrimSpace(base) == "" {
33+
base = defaultAuthorityHost
34+
}
35+
base = strings.TrimRight(base, "/")
36+
endpoint := fmt.Sprintf("%s/%s/oauth2/v2.0/token", base, tenantID)
37+
38+
client := &http.Client{Timeout: defaultHTTPTimeout}
39+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(form.Encode()))
40+
if err != nil {
41+
return "", err
42+
}
43+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
44+
req.Header.Set("Accept", "application/json")
45+
46+
resp, err := client.Do(req)
47+
if err != nil {
48+
return "", err
49+
}
50+
defer resp.Body.Close()
51+
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
52+
var aadErr struct {
53+
Error string `json:"error"`
54+
ErrorDescription string `json:"error_description"`
55+
}
56+
limited := io.LimitedReader{R: resp.Body, N: 4096}
57+
_ = json.NewDecoder(&limited).Decode(&aadErr)
58+
if aadErr.Error != "" {
59+
return "", fmt.Errorf("AAD token request failed: status=%d, error=%s", resp.StatusCode, aadErr.Error)
60+
}
61+
return "", fmt.Errorf("AAD token request failed: status=%d", resp.StatusCode)
62+
}
63+
var payload struct {
64+
AccessToken string `json:"access_token"`
65+
TokenType string `json:"token_type"`
66+
ExpiresIn int `json:"expires_in"`
67+
}
68+
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
69+
return "", err
70+
}
71+
if payload.AccessToken == "" {
72+
return "", fmt.Errorf("AAD token response missing access_token")
73+
}
74+
return payload.AccessToken, nil
75+
}

internal/azure/tokenutil_test.go

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
package azure
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"net/http/httptest"
7+
"strings"
8+
"testing"
9+
)
10+
11+
func TestGetAADAccessTokenViaClientAssertion_Success(t *testing.T) {
12+
13+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
14+
if r.Method != http.MethodPost {
15+
t.Fatalf("expected POST, got %s", r.Method)
16+
}
17+
if ct := r.Header.Get("Content-Type"); !strings.Contains(ct, "application/x-www-form-urlencoded") {
18+
t.Fatalf("expected form content-type, got %s", ct)
19+
}
20+
if err := r.ParseForm(); err != nil {
21+
t.Fatalf("failed parsing form: %v", err)
22+
}
23+
assertEq(t, r.Form.Get("client_id"), "client")
24+
assertEq(t, r.Form.Get("grant_type"), "client_credentials")
25+
assertEq(t, r.Form.Get("client_assertion_type"), "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
26+
assertEq(t, r.Form.Get("client_assertion"), "idtoken")
27+
assertEq(t, r.Form.Get("scope"), DefaultResource+".default")
28+
29+
w.Header().Set("Content-Type", "application/json")
30+
w.WriteHeader(http.StatusOK)
31+
_, _ = w.Write([]byte(`{"access_token":"AT","token_type":"Bearer","expires_in":3600}`))
32+
}))
33+
defer ts.Close()
34+
35+
tok, err := GetAADAccessTokenViaClientAssertion(context.Background(), "tenant", "client", "idtoken", ts.URL)
36+
if err != nil {
37+
t.Fatalf("unexpected error: %v", err)
38+
}
39+
if tok != "AT" {
40+
t.Fatalf("expected access token AT, got %q", tok)
41+
}
42+
}
43+
44+
func TestGetAADAccessTokenViaClientAssertion_400WithErrorField(t *testing.T) {
45+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
46+
w.Header().Set("Content-Type", "application/json")
47+
w.WriteHeader(http.StatusBadRequest)
48+
_, _ = w.Write([]byte(`{"error":"invalid_client","error_description":"bad"}`))
49+
}))
50+
defer ts.Close()
51+
52+
_, err := GetAADAccessTokenViaClientAssertion(context.Background(), "tenant", "client", "idtoken", ts.URL)
53+
if err == nil || !strings.Contains(err.Error(), "status=400") || !strings.Contains(err.Error(), "invalid_client") {
54+
t.Fatalf("expected 400 with invalid_client error, got %v", err)
55+
}
56+
}
57+
58+
func TestGetAADAccessTokenViaClientAssertion_400WithoutErrorField(t *testing.T) {
59+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
60+
w.WriteHeader(http.StatusBadRequest)
61+
_, _ = w.Write([]byte("{}"))
62+
}))
63+
defer ts.Close()
64+
65+
_, err := GetAADAccessTokenViaClientAssertion(context.Background(), "tenant", "client", "idtoken", ts.URL)
66+
if err == nil || !strings.Contains(err.Error(), "status=400") {
67+
t.Fatalf("expected 400 error, got %v", err)
68+
}
69+
}
70+
71+
func TestGetAADAccessTokenViaClientAssertion_MalformedJSON(t *testing.T) {
72+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
73+
w.WriteHeader(http.StatusOK)
74+
_, _ = w.Write([]byte("not-json"))
75+
}))
76+
defer ts.Close()
77+
78+
_, err := GetAADAccessTokenViaClientAssertion(context.Background(), "tenant", "client", "idtoken", ts.URL)
79+
if err == nil {
80+
t.Fatalf("expected JSON decode error, got nil")
81+
}
82+
}
83+
84+
func TestGetAADAccessTokenViaClientAssertion_MissingAccessToken(t *testing.T) {
85+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
86+
w.Header().Set("Content-Type", "application/json")
87+
w.WriteHeader(http.StatusOK)
88+
_, _ = w.Write([]byte(`{"token_type":"Bearer","expires_in":3600}`))
89+
}))
90+
defer ts.Close()
91+
92+
_, err := GetAADAccessTokenViaClientAssertion(context.Background(), "tenant", "client", "idtoken", ts.URL)
93+
if err == nil || !strings.Contains(err.Error(), "missing access_token") {
94+
t.Fatalf("expected missing access_token error, got %v", err)
95+
}
96+
}
97+
98+
func assertEq(t *testing.T, got, want string) {
99+
t.Helper()
100+
if got != want {
101+
t.Fatalf("mismatch: got=%q want=%q", got, want)
102+
}
103+
}
104+

0 commit comments

Comments
 (0)