Skip to content
Merged
24 changes: 19 additions & 5 deletions api_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,15 @@ func TestAzureFullURL(t *testing.T) {
Name string
BaseURL string
AzureModelMapper map[string]string
Suffix string
Model string
Expect string
}{
{
"AzureBaseURLWithSlashAutoStrip",
"https://httpbin.org/",
nil,
"/chat/completions",
"chatgpt-demo",
"https://httpbin.org/" +
"openai/deployments/chatgpt-demo" +
Expand All @@ -128,19 +130,28 @@ func TestAzureFullURL(t *testing.T) {
"AzureBaseURLWithoutSlashOK",
"https://httpbin.org",
nil,
"/chat/completions",
"chatgpt-demo",
"https://httpbin.org/" +
"openai/deployments/chatgpt-demo" +
"/chat/completions?api-version=2023-05-15",
},
{
"",
"https://httpbin.org",
nil,
"/assistants?limit=10",
"chatgpt-demo",
"https://httpbin.org/openai/assistants?api-version=2023-05-15&limit=10",
},
}

for _, c := range cases {
t.Run(c.Name, func(t *testing.T) {
az := DefaultAzureConfig("dummy", c.BaseURL)
cli := NewClientWithConfig(az)
// /openai/deployments/{engine}/chat/completions?api-version={api_version}
actual := cli.fullURL("/chat/completions", c.Model)
actual := cli.fullURL(c.Suffix, withModel(c.Model))
if actual != c.Expect {
t.Errorf("Expected %s, got %s", c.Expect, actual)
}
Expand All @@ -153,19 +164,22 @@ func TestCloudflareAzureFullURL(t *testing.T) {
cases := []struct {
Name string
BaseURL string
Suffix string
Expect string
}{
{
"CloudflareAzureBaseURLWithSlashAutoStrip",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/",
"/chat/completions",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" +
"chat/completions?api-version=2023-05-15",
},
{
"CloudflareAzureBaseURLWithoutSlashOK",
"",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" +
"chat/completions?api-version=2023-05-15",
"/assistants?limit=10",
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo" +
"/assistants?api-version=2023-05-15&limit=10",
},
}

Expand All @@ -176,7 +190,7 @@ func TestCloudflareAzureFullURL(t *testing.T) {

cli := NewClientWithConfig(az)

actual := cli.fullURL("/chat/completions")
actual := cli.fullURL(c.Suffix)
if actual != c.Expect {
t.Errorf("Expected %s, got %s", c.Expect, actual)
}
Expand Down
9 changes: 7 additions & 2 deletions audio.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,13 @@ func (c *Client) callAudioAPI(
}

urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix)
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model),
withBody(&formBody), withContentType(builder.FormDataContentType()))
req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL(urlSuffix, withModel(request.Model)),
withBody(&formBody),
withContentType(builder.FormDataContentType()),
)
if err != nil {
return AudioResponse{}, err
}
Expand Down
7 changes: 6 additions & 1 deletion chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,12 @@ func (c *Client) CreateChatCompletion(
return
}

req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request))
req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL(urlSuffix, withModel(request.Model)),
withBody(request),
)
if err != nil {
return
}
Expand Down
7 changes: 6 additions & 1 deletion chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,12 @@ func (c *Client) CreateChatCompletionStream(
}

request.Stream = true
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request))
req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL(urlSuffix, withModel(request.Model)),
withBody(request),
)
if err != nil {
return nil, err
}
Expand Down
84 changes: 54 additions & 30 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,42 +222,66 @@ func decodeString(body io.Reader, output *string) error {
return nil
}

type fullURLOptions struct {
model string
}

type fullURLOption func(*fullURLOptions)

func withModel(model string) fullURLOption {
return func(args *fullURLOptions) {
args.model = model
}
}

var azureDeploymentsEndpoints = []string{
"/completions",
"/embeddings",
"/chat/completions",
"/audio/transcriptions",
"/audio/translations",
"/audio/speech",
"/images/generations",
}

// fullURL returns full URL for request.
// args[0] is model name, if API type is Azure, model name is required to get deployment name.
func (c *Client) fullURL(suffix string, args ...any) string {
// /openai/deployments/{model}/chat/completions?api-version={api_version}
func (c *Client) fullURL(suffix string, setters ...fullURLOption) string {
baseURL := strings.TrimRight(c.config.BaseURL, "/")
args := fullURLOptions{}
for _, setter := range setters {
setter(&args)
}

if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD {
baseURL := c.config.BaseURL
baseURL = strings.TrimRight(baseURL, "/")
parseURL, _ := url.Parse(baseURL)
query := parseURL.Query()
query.Add("api-version", c.config.APIVersion)
// if suffix is /models change to {endpoint}/openai/models?api-version=2022-12-01
// https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP
if containsSubstr([]string{"/models", "/assistants", "/threads", "/files"}, suffix) {
return fmt.Sprintf("%s/%s%s?%s", baseURL, azureAPIPrefix, suffix, query.Encode())
}
azureDeploymentName := "UNKNOWN"
if len(args) > 0 {
model, ok := args[0].(string)
if ok {
azureDeploymentName = c.config.GetAzureDeploymentByModel(model)
}
}
return fmt.Sprintf("%s/%s/%s/%s%s?%s",
baseURL, azureAPIPrefix, azureDeploymentsPrefix,
azureDeploymentName, suffix, query.Encode(),
)
baseURL = c.baseURLWithAzureDeployment(baseURL, suffix, args.model)
}

if c.config.APIVersion != "" {
suffix = c.suffixWithAPIVersion(suffix)
}
return fmt.Sprintf("%s%s", baseURL, suffix)
}

// https://developers.cloudflare.com/ai-gateway/providers/azureopenai/
if c.config.APIType == APITypeCloudflareAzure {
baseURL := c.config.BaseURL
baseURL = strings.TrimRight(baseURL, "/")
return fmt.Sprintf("%s%s?api-version=%s", baseURL, suffix, c.config.APIVersion)
func (c *Client) suffixWithAPIVersion(suffix string) string {
parsedSuffix, err := url.Parse(suffix)
if err != nil {
panic("failed to parse url suffix")
}
query := parsedSuffix.Query()
query.Add("api-version", c.config.APIVersion)
return fmt.Sprintf("%s?%s", parsedSuffix.Path, query.Encode())
}

return fmt.Sprintf("%s%s", c.config.BaseURL, suffix)
func (c *Client) baseURLWithAzureDeployment(baseURL, suffix, model string) (newBaseURL string) {
baseURL = fmt.Sprintf("%s/%s", strings.TrimRight(baseURL, "/"), azureAPIPrefix)
if containsSubstr(azureDeploymentsEndpoints, suffix) {
azureDeploymentName := c.config.GetAzureDeploymentByModel(model)
if azureDeploymentName == "" {
azureDeploymentName = "UNKNOWN"
}
baseURL = fmt.Sprintf("%s/%s/%s", baseURL, azureDeploymentsPrefix, azureDeploymentName)
}
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part is pretty hard to figure about 🤔 Could I suggest breaking it down into more functions and keeping fullURLOptions closer to the code that uses it?

Something like this:

func (c *Client) suffixWithAPIVersion(previousSuffix string) (newSuffix string) {
	parsedSuffix, err := url.Parse(previousSuffix)
	if err != nil {
		panic("failed to parse url suffix")
	}

	query := parsedSuffix.Query()
	query.Add("api-version", c.config.APIVersion)
	return fmt.Sprintf("%s?%s", parsedSuffix.Path, query.Encode())
}

func (c *Client) baseURLWithAzureDeployment(baseURL, suffix, model string) (newBaseURL string) {
	azureDeploymentName := c.config.GetAzureDeploymentByModel(model)
	if azureDeploymentName == "" {
		azureDeploymentName = "UNKNOWN"
	}
	baseURL = fmt.Sprintf("%s/%s", baseURL, azureAPIPrefix)
	if containsSubstr(azureDeploymentsEndpoints, suffix) {
		baseURL = fmt.Sprintf("%s/%s/%s", baseURL, azureDeploymentsPrefix, azureDeploymentName)
	}
	return baseURL
}


// fullURL returns full URL for request.
func (c *Client) fullURL(suffix string, setters ...fullURLOption) string {
	baseURL := strings.TrimRight(c.config.BaseURL, "/")
	urlOptions := fullURLOptions{}
	for _, setter := range setters {
		setter(&args)
	}

	if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD {
		baseURL = baseURLWithAzureDeployment(baseURL, suffix, urlOptions.model)
	}

	if c.config.APIVersion != "" {
		suffix = c.suffixWithAPIVersion(suffix)
	}

	return fmt.Sprintf("%s%s", baseURL, suffix)
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've completed the suggested changes. Could you please review it again?

return baseURL
}

func (c *Client) handleErrorResp(resp *http.Response) error {
Expand Down
96 changes: 96 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,3 +431,99 @@ func TestClientReturnsRequestBuilderErrorsAddition(t *testing.T) {
t.Fatalf("Did not return error when request builder failed: %v", err)
}
}

func TestClient_suffixWithAPIVersion(t *testing.T) {
type fields struct {
apiVersion string
}
type args struct {
suffix string
}
tests := []struct {
name string
fields fields
args args
want string
wantPanic string
}{
{
"",
fields{apiVersion: "2023-05"},
args{suffix: "/assistants"},
"/assistants?api-version=2023-05",
"",
},
{
"",
fields{apiVersion: "2023-05"},
args{suffix: "/assistants?limit=5"},
"/assistants?api-version=2023-05&limit=5",
"",
},
{
"",
fields{apiVersion: "2023-05"},
args{suffix: "123:assistants?limit=5"},
"/assistants?api-version=2023-05&limit=5",
"failed to parse url suffix",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Client{
config: ClientConfig{APIVersion: tt.fields.apiVersion},
}
defer func() {
if r := recover(); r != nil {
if r.(string) != tt.wantPanic {
t.Errorf("suffixWithAPIVersion() = %v, want %v", r, tt.wantPanic)
}
}
}()
if got := c.suffixWithAPIVersion(tt.args.suffix); got != tt.want {
t.Errorf("suffixWithAPIVersion() = %v, want %v", got, tt.want)
}
})
}
}

func TestClient_baseURLWithAzureDeployment(t *testing.T) {
type args struct {
baseURL string
suffix string
model string
}
tests := []struct {
name string
args args
wantNewBaseURL string
}{
{
"",
args{baseURL: "https://test.openai.azure.com/", suffix: assistantsSuffix, model: GPT4oMini},
"https://test.openai.azure.com/openai",
},
{
"",
args{baseURL: "https://test.openai.azure.com/", suffix: chatCompletionsSuffix, model: GPT4oMini},
"https://test.openai.azure.com/openai/deployments/gpt-4o-mini",
},
{
"",
args{baseURL: "https://test.openai.azure.com/", suffix: chatCompletionsSuffix, model: ""},
"https://test.openai.azure.com/openai/deployments/UNKNOWN",
},
}
client := NewClient("")
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if gotNewBaseURL := client.baseURLWithAzureDeployment(
tt.args.baseURL,
tt.args.suffix,
tt.args.model,
); gotNewBaseURL != tt.wantNewBaseURL {
t.Errorf("baseURLWithAzureDeployment() = %v, want %v", gotNewBaseURL, tt.wantNewBaseURL)
}
})
}
}
7 changes: 6 additions & 1 deletion completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,12 @@ func (c *Client) CreateCompletion(
return
}

req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request))
req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL(urlSuffix, withModel(request.Model)),
withBody(request),
)
if err != nil {
return
}
Expand Down
7 changes: 6 additions & 1 deletion edits.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@ will need to migrate to GPT-3.5 Turbo by January 4, 2024.
You can use CreateChatCompletion or CreateChatCompletionStream instead.
*/
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), withBody(request))
req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL("/edits", withModel(fmt.Sprint(request.Model))),
withBody(request),
)
if err != nil {
return
}
Expand Down
7 changes: 6 additions & 1 deletion embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,12 @@ func (c *Client) CreateEmbeddings(
conv EmbeddingRequestConverter,
) (res EmbeddingResponse, err error) {
baseReq := conv.Convert()
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", string(baseReq.Model)), withBody(baseReq))
req, err := c.newRequest(
ctx,
http.MethodPost,
c.fullURL("/embeddings", withModel(string(baseReq.Model))),
withBody(baseReq),
)
if err != nil {
return
}
Expand Down
2 changes: 1 addition & 1 deletion example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func ExampleClient_CreateChatCompletionStream() {
return
}

fmt.Printf(response.Choices[0].Delta.Content)
fmt.Println(response.Choices[0].Delta.Content)
}
}

Expand Down
Loading
Loading