Skip to content

Support Retrieve file content API (#347) #348

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package openai
import (
"bufio"
"context"
"net/http"

utils "github.com./sashabaranov/go-openai/internal"
)
Expand Down Expand Up @@ -57,7 +56,7 @@ func (c *Client) CreateChatCompletionStream(
if err != nil {
return
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest {
if isFailureStatusCode(resp) {
return nil, c.handleErrorResp(resp)
}

Expand Down
43 changes: 21 additions & 22 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,6 @@ func NewOrgClient(authToken, org string) *Client {

func (c *Client) sendRequest(req *http.Request, v any) error {
req.Header.Set("Accept", "application/json; charset=utf-8")
// Azure API Key authentication
if c.config.APIType == APITypeAzure {
req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
} else {
// OpenAI or Azure AD authentication
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
}

// Check whether Content-Type is already set, Upload Files API requires
// Content-Type == multipart/form-data
Expand All @@ -62,9 +55,7 @@ func (c *Client) sendRequest(req *http.Request, v any) error {
req.Header.Set("Content-Type", "application/json; charset=utf-8")
}

if len(c.config.OrgID) > 0 {
req.Header.Set("OpenAI-Organization", c.config.OrgID)
}
c.setCommonHeaders(req)

res, err := c.config.HTTPClient.Do(req)
if err != nil {
Expand All @@ -73,13 +64,31 @@ func (c *Client) sendRequest(req *http.Request, v any) error {

defer res.Body.Close()

if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest {
if isFailureStatusCode(res) {
return c.handleErrorResp(res)
}

return decodeResponse(res.Body, v)
}

func (c *Client) setCommonHeaders(req *http.Request) {
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
// Azure API Key authentication
if c.config.APIType == APITypeAzure {
req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
} else {
// OpenAI or Azure AD authentication
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
}
if c.config.OrgID != "" {
req.Header.Set("OpenAI-Organization", c.config.OrgID)
}
}

func isFailureStatusCode(resp *http.Response) bool {
return resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest
}

func decodeResponse(body io.Reader, v any) error {
if v == nil {
return nil
Expand Down Expand Up @@ -145,17 +154,7 @@ func (c *Client) newStreamRequest(
req.Header.Set("Cache-Control", "no-cache")
req.Header.Set("Connection", "keep-alive")

// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/reference#authentication
// Azure API Key authentication
if c.config.APIType == APITypeAzure {
req.Header.Set(AzureAPIKeyHeader, c.config.authToken)
} else {
// OpenAI or Azure AD authentication
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.config.authToken))
}
if c.config.OrgID != "" {
req.Header.Set("OpenAI-Organization", c.config.OrgID)
}
c.setCommonHeaders(req)
return req, nil
}

Expand Down
3 changes: 3 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ func TestClientReturnsRequestBuilderErrors(t *testing.T) {
{"GetFile", func() (any, error) {
return client.GetFile(ctx, "")
}},
{"GetFileContent", func() (any, error) {
return client.GetFileContent(ctx, "")
}},
{"ListFiles", func() (any, error) {
return client.ListFiles(ctx)
}},
Expand Down
24 changes: 24 additions & 0 deletions files.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"os"
)
Expand Down Expand Up @@ -103,3 +104,26 @@ func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err err
err = c.sendRequest(req, &file)
return
}

func (c *Client) GetFileContent(ctx context.Context, fileID string) (content io.ReadCloser, err error) {
urlSuffix := fmt.Sprintf("/files/%s/content", fileID)
req, err := c.requestBuilder.Build(ctx, http.MethodGet, c.fullURL(urlSuffix), nil)
if err != nil {
return
}

c.setCommonHeaders(req)

res, err := c.config.HTTPClient.Do(req)
if err != nil {
return
}

if isFailureStatusCode(res) {
err = c.handleErrorResp(res)
return
}

content = res.Body
return
}
166 changes: 166 additions & 0 deletions files_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -141,3 +142,168 @@ func TestFileUploadWithNonExistentPath(t *testing.T) {
_, err := client.CreateFile(ctx, req)
checks.ErrorIs(t, err, os.ErrNotExist, "CreateFile should return error if file does not exist")
}

func TestDeleteFile(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) {

})
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

err = client.DeleteFile(ctx, "deadbeef")
checks.NoError(t, err, "DeleteFile error")
}

func TestListFile(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/files", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "{}")
})
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

_, err = client.ListFiles(ctx)
checks.NoError(t, err, "ListFiles error")
}

func TestGetFile(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/files/deadbeef", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "{}")
})
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

_, err = client.GetFile(ctx, "deadbeef")
checks.NoError(t, err, "GetFile error")
}

func TestGetFileContent(t *testing.T) {
wantRespJsonl := `{"prompt": "foo", "completion": "foo"}
{"prompt": "bar", "completion": "bar"}
{"prompt": "baz", "completion": "baz"}
`
server := test.NewTestServer()
server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) {
// edits only accepts GET requests
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
fmt.Fprint(w, wantRespJsonl)
})
// create the test server
var err error
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

content, err := client.GetFileContent(ctx, "deadbeef")
checks.NoError(t, err, "GetFileContent error")
defer content.Close()

actual, _ := io.ReadAll(content)
if string(actual) != wantRespJsonl {
t.Errorf("Expected %s, got %s", wantRespJsonl, string(actual))
}
}

func TestGetFileContentReturnError(t *testing.T) {
wantMessage := "To help mitigate abuse, downloading of fine-tune training files is disabled for free accounts."
wantType := "invalid_request_error"
wantErrorResp := `{
"error": {
"message": "` + wantMessage + `",
"type": "` + wantType + `",
"param": null,
"code": null
}
}`
server := test.NewTestServer()
server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
fmt.Fprint(w, wantErrorResp)
})
// create the test server
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()

_, err := client.GetFileContent(ctx, "deadbeef")
if err == nil {
t.Fatal("Did not return error")
}

apiErr := &APIError{}
if !errors.As(err, &apiErr) {
t.Fatalf("Did not return APIError: %+v\n", apiErr)
}
if apiErr.Message != wantMessage {
t.Fatalf("Expected %s Message, got = %s\n", wantMessage, apiErr.Message)
return
}
if apiErr.Type != wantType {
t.Fatalf("Expected %s Type, got = %s\n", wantType, apiErr.Type)
return
}
}

func TestGetFileContentReturnTimeoutError(t *testing.T) {
server := test.NewTestServer()
server.RegisterHandler("/v1/files/deadbeef/content", func(w http.ResponseWriter, r *http.Request) {
time.Sleep(10 * time.Nanosecond)
})
// create the test server
ts := server.OpenAITestServer()
ts.Start()
defer ts.Close()

config := DefaultConfig(test.GetTestToken())
config.BaseURL = ts.URL + "/v1"
client := NewClientWithConfig(config)
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, time.Nanosecond)
defer cancel()

_, err := client.GetFileContent(ctx, "deadbeef")
if err == nil {
t.Fatal("Did not return error")
}
if !os.IsTimeout(err) {
t.Fatal("Did not return timeout error")
}
}
3 changes: 1 addition & 2 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bufio"
"context"
"errors"
"net/http"

utils "github.com./sashabaranov/go-openai/internal"
)
Expand Down Expand Up @@ -46,7 +45,7 @@ func (c *Client) CreateCompletionStream(
if err != nil {
return
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest {
if isFailureStatusCode(resp) {
return nil, c.handleErrorResp(resp)
}

Expand Down