diff --git a/chat_stream.go b/chat_stream.go index 9378c7124..625d436cb 100644 --- a/chat_stream.go +++ b/chat_stream.go @@ -3,7 +3,6 @@ package openai import ( "bufio" "context" - "net/http" utils "github.com/sashabaranov/go-openai/internal" ) @@ -57,7 +56,7 @@ func (c *Client) CreateChatCompletionStream( if err != nil { return } - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest { + if isFailureStatusCode(resp) { return nil, c.handleErrorResp(resp) } diff --git a/client.go b/client.go index 2486e36b6..f38c1dfc3 100644 --- a/client.go +++ b/client.go @@ -47,13 +47,6 @@ func NewOrgClient(authToken, org string) *Client { func (c *Client) sendRequest(req *http.Request, v any) error { req.Header.Set("Accept", "application/json; charset=utf-8") - // Azure API Key authentication - if c.config.APIType == APITypeAzure { - req.Header.Set(AzureAPIKeyHeader, c.config.authToken) - } else { - // OpenAI or Azure AD authentication - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) - } // Check whether Content-Type is already set, Upload Files API requires // Content-Type == multipart/form-data @@ -62,9 +55,7 @@ func (c *Client) sendRequest(req *http.Request, v any) error { req.Header.Set("Content-Type", "application/json; charset=utf-8") } - if len(c.config.OrgID) > 0 { - req.Header.Set("OpenAI-Organization", c.config.OrgID) - } + c.setCommonHeaders(req) res, err := c.config.HTTPClient.Do(req) if err != nil { @@ -73,13 +64,31 @@ func (c *Client) sendRequest(req *http.Request, v any) error { defer res.Body.Close() - if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest { + if isFailureStatusCode(res) { return c.handleErrorResp(res) } return decodeResponse(res.Body, v) } +func (c *Client) setCommonHeaders(req *http.Request) { + // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication + // Azure API Key authentication + if c.config.APIType == APITypeAzure { + req.Header.Set(AzureAPIKeyHeader, c.config.authToken) + } else { + // OpenAI or Azure AD authentication + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) + } + if c.config.OrgID != "" { + req.Header.Set("OpenAI-Organization", c.config.OrgID) + } +} + +func isFailureStatusCode(resp *http.Response) bool { + return resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest +} + func decodeResponse(body io.Reader, v any) error { if v == nil { return nil @@ -145,17 +154,7 @@ func (c *Client) newStreamRequest( req.Header.Set("Cache-Control", "no-cache") req.Header.Set("Connection", "keep-alive") - // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication - // Azure API Key authentication - if c.config.APIType == APITypeAzure { - req.Header.Set(AzureAPIKeyHeader, c.config.authToken) - } else { - // OpenAI or Azure AD authentication - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken)) - } - if c.config.OrgID != "" { - req.Header.Set("OpenAI-Organization", c.config.OrgID) - } + c.setCommonHeaders(req) return req, nil } diff --git a/client_test.go b/client_test.go index 5e63539df..81ed33259 100644 --- a/client_test.go +++ b/client_test.go @@ -224,6 +224,9 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) { {"GetFile", func() (any, error) { return client.GetFile(ctx, "") }}, + {"GetFileContent", func() (any, error) { + return client.GetFileContent(ctx, "") + }}, {"ListFiles", func() (any, error) { return client.ListFiles(ctx) }}, diff --git a/files.go b/files.go index 36c024365..fb9937bea 100644 --- a/files.go +++ b/files.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "io" "net/http" "os" ) @@ -103,3 +104,26 @@ func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err err err = c.sendRequest(req, &file) return } + +func (c *Client) GetFileContent(ctx context.Context, fileID string) (content io.ReadCloser, err error) { + urlSuffix := fmt.Sprintf("/files/%s/content", fileID) + req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil) + if err != nil { + return + } + + c.setCommonHeaders(req) + + res, err := c.config.HTTPClient.Do(req) + if err != nil { + return + } + + if isFailureStatusCode(res) { + err = c.handleErrorResp(res) + return + } + + content = res.Body + return +} diff --git a/files_test.go b/files_test.go index ffdcfa798..8e8934935 100644 --- a/files_test.go +++ b/files_test.go @@ -7,6 +7,7 @@ import ( "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -141,3 +142,168 @@ func TestFileUploadWithNonExistentPath(t *testing.T) { _, err := client.CreateFile(ctx, req) checks.ErrorIs(t, err, os.ErrNotExist, "CreateFile should return error if file does not exist") } + +func TestDeleteFile(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) { + + }) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + err = client.DeleteFile(ctx, "deadbeef") + checks.NoError(t, err, "DeleteFile error") +} + +func TestListFile(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/files", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "{}") + }) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + _, err = client.ListFiles(ctx) + checks.NoError(t, err, "ListFiles error") +} + +func TestGetFile(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "{}") + }) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + _, err = client.GetFile(ctx, "deadbeef") + checks.NoError(t, err, "GetFile error") +} + +func TestGetFileContent(t *testing.T) { + wantRespJsonl := `{"prompt": "foo", "completion": "foo"} +{"prompt": "bar", "completion": "bar"} +{"prompt": "baz", "completion": "baz"} +` + server := test.NewTestServer() + server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) { + // edits only accepts GET requests + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + fmt.Fprint(w, wantRespJsonl) + }) + // create the test server + var err error + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + content, err := client.GetFileContent(ctx, "deadbeef") + checks.NoError(t, err, "GetFileContent error") + defer content.Close() + + actual, _ := io.ReadAll(content) + if string(actual) != wantRespJsonl { + t.Errorf("Expected %s, got %s", wantRespJsonl, string(actual)) + } +} + +func TestGetFileContentReturnError(t *testing.T) { + wantMessage := "To help mitigate abuse, downloading of fine-tune training files is disabled for free accounts." + wantType := "invalid_request_error" + wantErrorResp := `{ + "error": { + "message": "` + wantMessage + `", + "type": "` + wantType + `", + "param": null, + "code": null + } +}` + server := test.NewTestServer() + server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprint(w, wantErrorResp) + }) + // create the test server + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + + _, err := client.GetFileContent(ctx, "deadbeef") + if err == nil { + t.Fatal("Did not return error") + } + + apiErr := &APIError{} + if !errors.As(err, &apiErr) { + t.Fatalf("Did not return APIError: %+v\n", apiErr) + } + if apiErr.Message != wantMessage { + t.Fatalf("Expected %s Message, got = %s\n", wantMessage, apiErr.Message) + return + } + if apiErr.Type != wantType { + t.Fatalf("Expected %s Type, got = %s\n", wantType, apiErr.Type) + return + } +} + +func TestGetFileContentReturnTimeoutError(t *testing.T) { + server := test.NewTestServer() + server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) { + time.Sleep(10 * time.Nanosecond) + }) + // create the test server + ts := server.OpenAITestServer() + ts.Start() + defer ts.Close() + + config := DefaultConfig(test.GetTestToken()) + config.BaseURL = ts.URL + "/v1" + client := NewClientWithConfig(config) + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Nanosecond) + defer cancel() + + _, err := client.GetFileContent(ctx, "deadbeef") + if err == nil { + t.Fatal("Did not return error") + } + if !os.IsTimeout(err) { + t.Fatal("Did not return timeout error") + } +} diff --git a/stream.go b/stream.go index d4e352314..94cc0a0a2 100644 --- a/stream.go +++ b/stream.go @@ -4,7 +4,6 @@ import ( "bufio" "context" "errors" - "net/http" utils "github.com/sashabaranov/go-openai/internal" ) @@ -46,7 +45,7 @@ func (c *Client) CreateCompletionStream( if err != nil { return } - if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest { + if isFailureStatusCode(resp) { return nil, c.handleErrorResp(resp) }