Commit 2054db0

Add support for O3-mini (#930)
* Add support for O3-mini

  - Add support for the o3-mini set of models, including tests that match the constraints in OpenAI's API docs (https://platform.openai.com/docs/models#o3-mini).

* Deprecate and refactor

  - Deprecate `ErrO1BetaLimitationsLogprobs` and `ErrO1BetaLimitationsOther`.
  - Implement `validateRequestForReasoningModels`, which works on both o1 and o3 models and applies per-model-type restrictions on functionality (e.g., the o3 class allows function calls and system messages, while o1 does not).

* Move reasoning validation to `reasoning_validator.go`

  - Add a `NewReasoningValidator`, which exposes a `Validate()` method for a given request.
  - Also add a test for chat streams.

* Final nits
1 parent 45aa996 commit 2054db0
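
Note: `reasoning_validator.go` is among the six changed files, but its hunk is not included in this excerpt. From the identifiers that do appear in the commit message and the tests below (`NewReasoningValidator`, `Validate()`, `ErrReasoningModelMaxTokensDeprecated`), the validator plausibly has a shape like the following sketch. The struct definition and the model-prefix checks are assumptions for illustration, not the actual implementation.

// Sketch only: inferred from identifiers visible in this commit, not the
// real reasoning_validator.go (whose hunk is not shown in this excerpt).
package openai

import "strings"

// ReasoningValidator validates chat requests against the constraints that
// apply to reasoning (o1/o3) models.
type ReasoningValidator struct{}

// NewReasoningValidator returns a validator for reasoning models.
func NewReasoningValidator() *ReasoningValidator {
	return &ReasoningValidator{}
}

// Validate returns an error if the request targets a reasoning model but
// uses an option that class of model does not support.
func (v *ReasoningValidator) Validate(request ChatCompletionRequest) error {
	o1Series := strings.HasPrefix(request.Model, "o1")
	o3Series := strings.HasPrefix(request.Model, "o3")
	if !o1Series && !o3Series {
		return nil // non-reasoning models carry no extra restrictions
	}
	if request.MaxTokens > 0 {
		// Reasoning models take MaxCompletionTokens instead of MaxTokens,
		// as exercised by TestCreateChatCompletionStreamReasoningValidatorFails.
		return ErrReasoningModelMaxTokensDeprecated
	}
	// Per the commit message, further per-model-type checks would live here
	// (e.g., o1 rejects system messages and function calls, o3 allows them).
	return nil
}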

File tree

6 files changed, +431 -92 lines

chat.go (+2 -1)

@@ -392,7 +392,8 @@ func (c *Client) CreateChatCompletion(
 		return
 	}
 
-	if err = validateRequestForO1Models(request); err != nil {
+	reasoningValidator := NewReasoningValidator()
+	if err = reasoningValidator.Validate(request); err != nil {
 		return
 	}
 
chat_stream.go (+2 -1)

@@ -80,7 +80,8 @@ func (c *Client) CreateChatCompletionStream(
 	}
 
 	request.Stream = true
-	if err = validateRequestForO1Models(request); err != nil {
+	reasoningValidator := NewReasoningValidator()
+	if err = reasoningValidator.Validate(request); err != nil {
 		return
 	}
 
chat_stream_test.go (+167)

@@ -792,6 +792,173 @@ func compareChatResponses(r1, r2 openai.ChatCompletionStreamResponse) bool {
 	return true
 }
 
+func TestCreateChatCompletionStreamWithReasoningModel(t *testing.T) {
+	client, server, teardown := setupOpenAITestServer()
+	defer teardown()
+	server.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, _ *http.Request) {
+		w.Header().Set("Content-Type", "text/event-stream")
+
+		dataBytes := []byte{}
+
+		//nolint:lll
+		dataBytes = append(dataBytes, []byte(`data: {"id":"1","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}`)...)
+		dataBytes = append(dataBytes, []byte("\n\n")...)
+
+		//nolint:lll
+		dataBytes = append(dataBytes, []byte(`data: {"id":"2","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}`)...)
+		dataBytes = append(dataBytes, []byte("\n\n")...)
+
+		//nolint:lll
+		dataBytes = append(dataBytes, []byte(`data: {"id":"3","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":" from"},"finish_reason":null}]}`)...)
+		dataBytes = append(dataBytes, []byte("\n\n")...)
+
+		//nolint:lll
+		dataBytes = append(dataBytes, []byte(`data: {"id":"4","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{"content":" O3Mini"},"finish_reason":null}]}`)...)
+		dataBytes = append(dataBytes, []byte("\n\n")...)
+
+		//nolint:lll
+		dataBytes = append(dataBytes, []byte(`data: {"id":"5","object":"chat.completion.chunk","created":1729585728,"model":"o3-mini-2025-01-31","system_fingerprint":"fp_mini","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`)...)
+		dataBytes = append(dataBytes, []byte("\n\n")...)
+
+		dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
+
+		_, err := w.Write(dataBytes)
+		checks.NoError(t, err, "Write error")
+	})
+
+	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
+		MaxCompletionTokens: 2000,
+		Model:               openai.O3Mini20250131,
+		Messages: []openai.ChatCompletionMessage{
+			{
+				Role:    openai.ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+		Stream: true,
+	})
+	checks.NoError(t, err, "CreateCompletionStream returned error")
+	defer stream.Close()
+
+	expectedResponses := []openai.ChatCompletionStreamResponse{
+		{
+			ID:                "1",
+			Object:            "chat.completion.chunk",
+			Created:           1729585728,
+			Model:             openai.O3Mini20250131,
+			SystemFingerprint: "fp_mini",
+			Choices: []openai.ChatCompletionStreamChoice{
+				{
+					Index: 0,
+					Delta: openai.ChatCompletionStreamChoiceDelta{
+						Role: "assistant",
+					},
+				},
+			},
+		},
+		{
+			ID:                "2",
+			Object:            "chat.completion.chunk",
+			Created:           1729585728,
+			Model:             openai.O3Mini20250131,
+			SystemFingerprint: "fp_mini",
+			Choices: []openai.ChatCompletionStreamChoice{
+				{
+					Index: 0,
+					Delta: openai.ChatCompletionStreamChoiceDelta{
+						Content: "Hello",
+					},
+				},
+			},
+		},
+		{
+			ID:                "3",
+			Object:            "chat.completion.chunk",
+			Created:           1729585728,
+			Model:             openai.O3Mini20250131,
+			SystemFingerprint: "fp_mini",
+			Choices: []openai.ChatCompletionStreamChoice{
+				{
+					Index: 0,
+					Delta: openai.ChatCompletionStreamChoiceDelta{
+						Content: " from",
+					},
+				},
+			},
+		},
+		{
+			ID:                "4",
+			Object:            "chat.completion.chunk",
+			Created:           1729585728,
+			Model:             openai.O3Mini20250131,
+			SystemFingerprint: "fp_mini",
+			Choices: []openai.ChatCompletionStreamChoice{
+				{
+					Index: 0,
+					Delta: openai.ChatCompletionStreamChoiceDelta{
+						Content: " O3Mini",
+					},
+				},
+			},
+		},
+		{
+			ID:                "5",
+			Object:            "chat.completion.chunk",
+			Created:           1729585728,
+			Model:             openai.O3Mini20250131,
+			SystemFingerprint: "fp_mini",
+			Choices: []openai.ChatCompletionStreamChoice{
+				{
+					Index:        0,
+					Delta:        openai.ChatCompletionStreamChoiceDelta{},
+					FinishReason: "stop",
+				},
+			},
+		},
+	}
+
+	for ix, expectedResponse := range expectedResponses {
+		b, _ := json.Marshal(expectedResponse)
+		t.Logf("%d: %s", ix, string(b))
+
+		receivedResponse, streamErr := stream.Recv()
+		checks.NoError(t, streamErr, "stream.Recv() failed")
+		if !compareChatResponses(expectedResponse, receivedResponse) {
+			t.Errorf("Stream response %v is %v, expected %v", ix, receivedResponse, expectedResponse)
+		}
+	}
+
+	_, streamErr := stream.Recv()
+	if !errors.Is(streamErr, io.EOF) {
+		t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
+	}
+}
+
+func TestCreateChatCompletionStreamReasoningValidatorFails(t *testing.T) {
+	client, _, _ := setupOpenAITestServer()
+
+	stream, err := client.CreateChatCompletionStream(context.Background(), openai.ChatCompletionRequest{
+		MaxTokens: 100, // This will trigger the validator to fail
+		Model:     openai.O3Mini,
+		Messages: []openai.ChatCompletionMessage{
+			{
+				Role:    openai.ChatMessageRoleUser,
+				Content: "Hello!",
+			},
+		},
+		Stream: true,
+	})
+
+	if stream != nil {
+		t.Error("Expected nil stream when validation fails")
+		stream.Close()
+	}
+
+	if !errors.Is(err, openai.ErrReasoningModelMaxTokensDeprecated) {
+		t.Errorf("Expected ErrReasoningModelMaxTokensDeprecated, got: %v", err)
+	}
+}
+
 func compareChatStreamResponseChoices(c1, c2 openai.ChatCompletionStreamChoice) bool {
 	if c1.Index != c2.Index {
 		return false
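
For callers, the practical upshot visible in these tests is that reasoning models reject the deprecated `MaxTokens` field and expect `MaxCompletionTokens` instead. A minimal caller-side sketch, assuming the usual go-openai module path `github.com/sashabaranov/go-openai` and an API key in the `OPENAI_API_KEY` environment variable (neither appears in this excerpt):

package main

import (
	"context"
	"fmt"
	"os"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	client := openai.NewClient(os.Getenv("OPENAI_API_KEY"))

	resp, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
		Model: openai.O3Mini,
		// Setting MaxTokens here instead would fail validation with
		// ErrReasoningModelMaxTokensDeprecated before any request is sent.
		MaxCompletionTokens: 1000,
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "Hello!"},
		},
	})
	if err != nil {
		fmt.Println("chat completion error:", err)
		return
	}
	fmt.Println(resp.Choices[0].Message.Content)
}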

Comments (0)