Skip to content

Commit b616090

Browse files
authored
refactoring tests with mock servers (#30) (#356)
1 parent a243e73 commit b616090

20 files changed

+731
-1060
lines changed

api_test.go

+5-11
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"errors"
77
"io"
88
"net/http"
9-
"net/http/httptest"
109
"os"
1110
"testing"
1211

@@ -226,18 +225,13 @@ func TestAPIErrorUnmarshalJSONInvalidMessage(t *testing.T) {
226225
}
227226

228227
func TestRequestError(t *testing.T) {
229-
var err error
230-
231-
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
228+
client, server, teardown := setupOpenAITestServer()
229+
defer teardown()
230+
server.RegisterHandler("/v1/engines", func(w http.ResponseWriter, r *http.Request) {
232231
w.WriteHeader(http.StatusTeapot)
233-
}))
234-
defer ts.Close()
232+
})
235233

236-
config := DefaultConfig("dummy")
237-
config.BaseURL = ts.URL
238-
c := NewClientWithConfig(config)
239-
ctx := context.Background()
240-
_, err = c.ListEngines(ctx)
234+
_, err := client.ListEngines(context.Background())
241235
checks.HasError(t, err, "ListEngines did not fail")
242236

243237
var reqErr *RequestError

audio_api_test.go

+162
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
package openai_test
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"errors"
7+
"io"
8+
"mime"
9+
"mime/multipart"
10+
"net/http"
11+
"path/filepath"
12+
"strings"
13+
"testing"
14+
15+
. "github.com./sashabaranov/go-openai"
16+
"github.com./sashabaranov/go-openai/internal/test"
17+
"github.com./sashabaranov/go-openai/internal/test/checks"
18+
)
19+
20+
// TestAudio Tests the transcription and translation endpoints of the API using the mocked server.
21+
func TestAudio(t *testing.T) {
22+
client, server, teardown := setupOpenAITestServer()
23+
defer teardown()
24+
server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint)
25+
server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint)
26+
27+
testcases := []struct {
28+
name string
29+
createFn func(context.Context, AudioRequest) (AudioResponse, error)
30+
}{
31+
{
32+
"transcribe",
33+
client.CreateTranscription,
34+
},
35+
{
36+
"translate",
37+
client.CreateTranslation,
38+
},
39+
}
40+
41+
ctx := context.Background()
42+
43+
dir, cleanup := test.CreateTestDirectory(t)
44+
defer cleanup()
45+
46+
for _, tc := range testcases {
47+
t.Run(tc.name, func(t *testing.T) {
48+
path := filepath.Join(dir, "fake.mp3")
49+
test.CreateTestFile(t, path)
50+
51+
req := AudioRequest{
52+
FilePath: path,
53+
Model: "whisper-3",
54+
}
55+
_, err := tc.createFn(ctx, req)
56+
checks.NoError(t, err, "audio API error")
57+
})
58+
59+
t.Run(tc.name+" (with reader)", func(t *testing.T) {
60+
req := AudioRequest{
61+
FilePath: "fake.webm",
62+
Reader: bytes.NewBuffer([]byte(`some webm binary data`)),
63+
Model: "whisper-3",
64+
}
65+
_, err := tc.createFn(ctx, req)
66+
checks.NoError(t, err, "audio API error")
67+
})
68+
}
69+
}
70+
71+
func TestAudioWithOptionalArgs(t *testing.T) {
72+
client, server, teardown := setupOpenAITestServer()
73+
defer teardown()
74+
server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint)
75+
server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint)
76+
77+
testcases := []struct {
78+
name string
79+
createFn func(context.Context, AudioRequest) (AudioResponse, error)
80+
}{
81+
{
82+
"transcribe",
83+
client.CreateTranscription,
84+
},
85+
{
86+
"translate",
87+
client.CreateTranslation,
88+
},
89+
}
90+
91+
ctx := context.Background()
92+
93+
dir, cleanup := test.CreateTestDirectory(t)
94+
defer cleanup()
95+
96+
for _, tc := range testcases {
97+
t.Run(tc.name, func(t *testing.T) {
98+
path := filepath.Join(dir, "fake.mp3")
99+
test.CreateTestFile(t, path)
100+
101+
req := AudioRequest{
102+
FilePath: path,
103+
Model: "whisper-3",
104+
Prompt: "用简体中文",
105+
Temperature: 0.5,
106+
Language: "zh",
107+
Format: AudioResponseFormatSRT,
108+
}
109+
_, err := tc.createFn(ctx, req)
110+
checks.NoError(t, err, "audio API error")
111+
})
112+
}
113+
}
114+
115+
// handleAudioEndpoint Handles the completion endpoint by the test server.
116+
func handleAudioEndpoint(w http.ResponseWriter, r *http.Request) {
117+
var err error
118+
119+
// audio endpoints only accept POST requests
120+
if r.Method != "POST" {
121+
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
122+
}
123+
124+
mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
125+
if err != nil {
126+
http.Error(w, "failed to parse media type", http.StatusBadRequest)
127+
return
128+
}
129+
130+
if !strings.HasPrefix(mediaType, "multipart") {
131+
http.Error(w, "request is not multipart", http.StatusBadRequest)
132+
}
133+
134+
boundary, ok := params["boundary"]
135+
if !ok {
136+
http.Error(w, "no boundary in params", http.StatusBadRequest)
137+
return
138+
}
139+
140+
fileData := &bytes.Buffer{}
141+
mr := multipart.NewReader(r.Body, boundary)
142+
part, err := mr.NextPart()
143+
if err != nil && errors.Is(err, io.EOF) {
144+
http.Error(w, "error accessing file", http.StatusBadRequest)
145+
return
146+
}
147+
if _, err = io.Copy(fileData, part); err != nil {
148+
http.Error(w, "failed to copy file", http.StatusInternalServerError)
149+
return
150+
}
151+
152+
if len(fileData.Bytes()) == 0 {
153+
w.WriteHeader(http.StatusInternalServerError)
154+
http.Error(w, "received empty file data", http.StatusBadRequest)
155+
return
156+
}
157+
158+
if _, err = w.Write([]byte(`{"body": "hello"}`)); err != nil {
159+
http.Error(w, "failed to write body", http.StatusInternalServerError)
160+
return
161+
}
162+
}

audio_test.go

-166
Original file line numberDiff line numberDiff line change
@@ -2,182 +2,16 @@ package openai //nolint:testpackage // testing private field
22

33
import (
44
"bytes"
5-
"context"
6-
"errors"
75
"fmt"
86
"io"
9-
"mime"
10-
"mime/multipart"
11-
"net/http"
127
"os"
138
"path/filepath"
14-
"strings"
159
"testing"
1610

1711
"github.com./sashabaranov/go-openai/internal/test"
1812
"github.com./sashabaranov/go-openai/internal/test/checks"
1913
)
2014

21-
// TestAudio Tests the transcription and translation endpoints of the API using the mocked server.
22-
func TestAudio(t *testing.T) {
23-
server := test.NewTestServer()
24-
server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint)
25-
server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint)
26-
// create the test server
27-
var err error
28-
ts := server.OpenAITestServer()
29-
ts.Start()
30-
defer ts.Close()
31-
32-
config := DefaultConfig(test.GetTestToken())
33-
config.BaseURL = ts.URL + "/v1"
34-
client := NewClientWithConfig(config)
35-
36-
testcases := []struct {
37-
name string
38-
createFn func(context.Context, AudioRequest) (AudioResponse, error)
39-
}{
40-
{
41-
"transcribe",
42-
client.CreateTranscription,
43-
},
44-
{
45-
"translate",
46-
client.CreateTranslation,
47-
},
48-
}
49-
50-
ctx := context.Background()
51-
52-
dir, cleanup := test.CreateTestDirectory(t)
53-
defer cleanup()
54-
55-
for _, tc := range testcases {
56-
t.Run(tc.name, func(t *testing.T) {
57-
path := filepath.Join(dir, "fake.mp3")
58-
test.CreateTestFile(t, path)
59-
60-
req := AudioRequest{
61-
FilePath: path,
62-
Model: "whisper-3",
63-
}
64-
_, err = tc.createFn(ctx, req)
65-
checks.NoError(t, err, "audio API error")
66-
})
67-
68-
t.Run(tc.name+" (with reader)", func(t *testing.T) {
69-
req := AudioRequest{
70-
FilePath: "fake.webm",
71-
Reader: bytes.NewBuffer([]byte(`some webm binary data`)),
72-
Model: "whisper-3",
73-
}
74-
_, err = tc.createFn(ctx, req)
75-
checks.NoError(t, err, "audio API error")
76-
})
77-
}
78-
}
79-
80-
func TestAudioWithOptionalArgs(t *testing.T) {
81-
server := test.NewTestServer()
82-
server.RegisterHandler("/v1/audio/transcriptions", handleAudioEndpoint)
83-
server.RegisterHandler("/v1/audio/translations", handleAudioEndpoint)
84-
// create the test server
85-
var err error
86-
ts := server.OpenAITestServer()
87-
ts.Start()
88-
defer ts.Close()
89-
90-
config := DefaultConfig(test.GetTestToken())
91-
config.BaseURL = ts.URL + "/v1"
92-
client := NewClientWithConfig(config)
93-
94-
testcases := []struct {
95-
name string
96-
createFn func(context.Context, AudioRequest) (AudioResponse, error)
97-
}{
98-
{
99-
"transcribe",
100-
client.CreateTranscription,
101-
},
102-
{
103-
"translate",
104-
client.CreateTranslation,
105-
},
106-
}
107-
108-
ctx := context.Background()
109-
110-
dir, cleanup := test.CreateTestDirectory(t)
111-
defer cleanup()
112-
113-
for _, tc := range testcases {
114-
t.Run(tc.name, func(t *testing.T) {
115-
path := filepath.Join(dir, "fake.mp3")
116-
test.CreateTestFile(t, path)
117-
118-
req := AudioRequest{
119-
FilePath: path,
120-
Model: "whisper-3",
121-
Prompt: "用简体中文",
122-
Temperature: 0.5,
123-
Language: "zh",
124-
Format: AudioResponseFormatSRT,
125-
}
126-
_, err = tc.createFn(ctx, req)
127-
checks.NoError(t, err, "audio API error")
128-
})
129-
}
130-
}
131-
132-
// handleAudioEndpoint Handles the completion endpoint by the test server.
133-
func handleAudioEndpoint(w http.ResponseWriter, r *http.Request) {
134-
var err error
135-
136-
// audio endpoints only accept POST requests
137-
if r.Method != "POST" {
138-
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
139-
}
140-
141-
mediaType, params, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
142-
if err != nil {
143-
http.Error(w, "failed to parse media type", http.StatusBadRequest)
144-
return
145-
}
146-
147-
if !strings.HasPrefix(mediaType, "multipart") {
148-
http.Error(w, "request is not multipart", http.StatusBadRequest)
149-
}
150-
151-
boundary, ok := params["boundary"]
152-
if !ok {
153-
http.Error(w, "no boundary in params", http.StatusBadRequest)
154-
return
155-
}
156-
157-
fileData := &bytes.Buffer{}
158-
mr := multipart.NewReader(r.Body, boundary)
159-
part, err := mr.NextPart()
160-
if err != nil && errors.Is(err, io.EOF) {
161-
http.Error(w, "error accessing file", http.StatusBadRequest)
162-
return
163-
}
164-
if _, err = io.Copy(fileData, part); err != nil {
165-
http.Error(w, "failed to copy file", http.StatusInternalServerError)
166-
return
167-
}
168-
169-
if len(fileData.Bytes()) == 0 {
170-
w.WriteHeader(http.StatusInternalServerError)
171-
http.Error(w, "received empty file data", http.StatusBadRequest)
172-
return
173-
}
174-
175-
if _, err = w.Write([]byte(`{"body": "hello"}`)); err != nil {
176-
http.Error(w, "failed to write body", http.StatusInternalServerError)
177-
return
178-
}
179-
}
180-
18115
func TestAudioWithFailingFormBuilder(t *testing.T) {
18216
dir, cleanup := test.CreateTestDirectory(t)
18317
defer cleanup()

0 commit comments

Comments
 (0)