mirror of
https://github.com/jmorganca/ollama
synced 2025-10-06 00:32:49 +02:00
The format qwen3-coder uses is relatively unique, both in rendering and in parsing. To implement parsing, I wrote a custom parser in similar style to harmony. For the rendering, I found that the logic would be much more difficult to follow in a template, so I introduced the concept of a built-in renderer that uses go code, rather than a template to generate prompts. I set us up for future built-in parsers and renderers by making it so they can be specified in a Modelfile like so: ``` RENDERER "qwen3-coder" PARSER "qwen3-coder" ``` These need to be provided explicitly because the architecture alone is not enough to understand what format the model expects to receive, and what format we expect it to output (e.g., qwen3-coder is `qwen3moe`, which includes other qwen3-family models as well) I haven't converted harmony to be one of these "built-ins" yet, since some of it is in flux with the changes @ParthSareen has been making to move harmony to the runner. It is likely that many other built-ins will need to move to the runner as well, but I'm able to slightly defer that decision since qwen3-coder doesn't have thinking (and therefore doesn't need to be in the runner to make structured outputs work). I expect to unify harmony with this approach very soon. Whether a particular model supports tools or thinking was previously inferred from templates, but without a template we now also use the parser itself to declare what it supports. If we have future models that re-use the same parsing format, but have different capabilities, we'll want to parameterize them and give them different names to be specified as a `PARSER`. Misc changes: - I worked on the renderer by diffing outputs from the reference implementation and ours. To make it easier to do this, I extended <https://github.com/ollama/ollama/pull/11875> to also support returning the prompt via the openai compat layer
1073 lines
27 KiB
Go
1073 lines
27 KiB
Go
// openai package provides middleware for partial compatibility with the OpenAI REST API
|
|
package openai
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"math/rand"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
|
"github.com/ollama/ollama/api"
|
|
"github.com/ollama/ollama/types/model"
|
|
)
|
|
|
|
var finishReasonToolCalls = "tool_calls"
|
|
|
|
type Error struct {
|
|
Message string `json:"message"`
|
|
Type string `json:"type"`
|
|
Param any `json:"param"`
|
|
Code *string `json:"code"`
|
|
}
|
|
|
|
type ErrorResponse struct {
|
|
Error Error `json:"error"`
|
|
}
|
|
|
|
type Message struct {
|
|
Role string `json:"role"`
|
|
Content any `json:"content"`
|
|
Reasoning string `json:"reasoning,omitempty"`
|
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
|
Name string `json:"name,omitempty"`
|
|
ToolCallID string `json:"tool_call_id,omitempty"`
|
|
}
|
|
|
|
type Choice struct {
|
|
Index int `json:"index"`
|
|
Message Message `json:"message"`
|
|
FinishReason *string `json:"finish_reason"`
|
|
}
|
|
|
|
type ChunkChoice struct {
|
|
Index int `json:"index"`
|
|
Delta Message `json:"delta"`
|
|
FinishReason *string `json:"finish_reason"`
|
|
}
|
|
|
|
type CompleteChunkChoice struct {
|
|
Text string `json:"text"`
|
|
Index int `json:"index"`
|
|
FinishReason *string `json:"finish_reason"`
|
|
}
|
|
|
|
type Usage struct {
|
|
PromptTokens int `json:"prompt_tokens"`
|
|
CompletionTokens int `json:"completion_tokens"`
|
|
TotalTokens int `json:"total_tokens"`
|
|
}
|
|
|
|
type ResponseFormat struct {
|
|
Type string `json:"type"`
|
|
JsonSchema *JsonSchema `json:"json_schema,omitempty"`
|
|
}
|
|
|
|
type JsonSchema struct {
|
|
Schema json.RawMessage `json:"schema"`
|
|
}
|
|
|
|
type EmbedRequest struct {
|
|
Input any `json:"input"`
|
|
Model string `json:"model"`
|
|
Dimensions int `json:"dimensions,omitempty"`
|
|
}
|
|
|
|
type StreamOptions struct {
|
|
IncludeUsage bool `json:"include_usage"`
|
|
}
|
|
|
|
type Reasoning struct {
|
|
Effort *string `json:"effort,omitempty"`
|
|
}
|
|
|
|
type ChatCompletionRequest struct {
|
|
Model string `json:"model"`
|
|
Messages []Message `json:"messages"`
|
|
Stream bool `json:"stream"`
|
|
StreamOptions *StreamOptions `json:"stream_options"`
|
|
MaxTokens *int `json:"max_tokens"`
|
|
Seed *int `json:"seed"`
|
|
Stop any `json:"stop"`
|
|
Temperature *float64 `json:"temperature"`
|
|
FrequencyPenalty *float64 `json:"frequency_penalty"`
|
|
PresencePenalty *float64 `json:"presence_penalty"`
|
|
TopP *float64 `json:"top_p"`
|
|
ResponseFormat *ResponseFormat `json:"response_format"`
|
|
Tools []api.Tool `json:"tools"`
|
|
Reasoning *Reasoning `json:"reasoning,omitempty"`
|
|
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
|
|
DebugRenderOnly bool `json:"_debug_render_only"`
|
|
}
|
|
|
|
type ChatCompletion struct {
|
|
Id string `json:"id"`
|
|
Object string `json:"object"`
|
|
Created int64 `json:"created"`
|
|
Model string `json:"model"`
|
|
SystemFingerprint string `json:"system_fingerprint"`
|
|
Choices []Choice `json:"choices"`
|
|
Usage Usage `json:"usage,omitempty"`
|
|
DebugInfo *api.DebugInfo `json:"_debug_info,omitempty"`
|
|
}
|
|
|
|
type ChatCompletionChunk struct {
|
|
Id string `json:"id"`
|
|
Object string `json:"object"`
|
|
Created int64 `json:"created"`
|
|
Model string `json:"model"`
|
|
SystemFingerprint string `json:"system_fingerprint"`
|
|
Choices []ChunkChoice `json:"choices"`
|
|
Usage *Usage `json:"usage,omitempty"`
|
|
}
|
|
|
|
// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
|
|
type CompletionRequest struct {
|
|
Model string `json:"model"`
|
|
Prompt string `json:"prompt"`
|
|
FrequencyPenalty float32 `json:"frequency_penalty"`
|
|
MaxTokens *int `json:"max_tokens"`
|
|
PresencePenalty float32 `json:"presence_penalty"`
|
|
Seed *int `json:"seed"`
|
|
Stop any `json:"stop"`
|
|
Stream bool `json:"stream"`
|
|
StreamOptions *StreamOptions `json:"stream_options"`
|
|
Temperature *float32 `json:"temperature"`
|
|
TopP float32 `json:"top_p"`
|
|
Suffix string `json:"suffix"`
|
|
DebugRenderOnly bool `json:"_debug_render_only"`
|
|
}
|
|
|
|
type Completion struct {
|
|
Id string `json:"id"`
|
|
Object string `json:"object"`
|
|
Created int64 `json:"created"`
|
|
Model string `json:"model"`
|
|
SystemFingerprint string `json:"system_fingerprint"`
|
|
Choices []CompleteChunkChoice `json:"choices"`
|
|
Usage Usage `json:"usage,omitempty"`
|
|
}
|
|
|
|
type CompletionChunk struct {
|
|
Id string `json:"id"`
|
|
Object string `json:"object"`
|
|
Created int64 `json:"created"`
|
|
Choices []CompleteChunkChoice `json:"choices"`
|
|
Model string `json:"model"`
|
|
SystemFingerprint string `json:"system_fingerprint"`
|
|
Usage *Usage `json:"usage,omitempty"`
|
|
}
|
|
|
|
type ToolCall struct {
|
|
ID string `json:"id"`
|
|
Index int `json:"index"`
|
|
Type string `json:"type"`
|
|
Function struct {
|
|
Name string `json:"name"`
|
|
Arguments string `json:"arguments"`
|
|
} `json:"function"`
|
|
}
|
|
|
|
type Model struct {
|
|
Id string `json:"id"`
|
|
Object string `json:"object"`
|
|
Created int64 `json:"created"`
|
|
OwnedBy string `json:"owned_by"`
|
|
}
|
|
|
|
type Embedding struct {
|
|
Object string `json:"object"`
|
|
Embedding []float32 `json:"embedding"`
|
|
Index int `json:"index"`
|
|
}
|
|
|
|
type ListCompletion struct {
|
|
Object string `json:"object"`
|
|
Data []Model `json:"data"`
|
|
}
|
|
|
|
type EmbeddingList struct {
|
|
Object string `json:"object"`
|
|
Data []Embedding `json:"data"`
|
|
Model string `json:"model"`
|
|
Usage EmbeddingUsage `json:"usage,omitempty"`
|
|
}
|
|
|
|
type EmbeddingUsage struct {
|
|
PromptTokens int `json:"prompt_tokens"`
|
|
TotalTokens int `json:"total_tokens"`
|
|
}
|
|
|
|
func NewError(code int, message string) ErrorResponse {
|
|
var etype string
|
|
switch code {
|
|
case http.StatusBadRequest:
|
|
etype = "invalid_request_error"
|
|
case http.StatusNotFound:
|
|
etype = "not_found_error"
|
|
default:
|
|
etype = "api_error"
|
|
}
|
|
|
|
return ErrorResponse{Error{Type: etype, Message: message}}
|
|
}
|
|
|
|
func toUsage(r api.ChatResponse) Usage {
|
|
return Usage{
|
|
PromptTokens: r.PromptEvalCount,
|
|
CompletionTokens: r.EvalCount,
|
|
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
|
}
|
|
}
|
|
|
|
func toolCallId() string {
|
|
const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
|
|
b := make([]byte, 8)
|
|
for i := range b {
|
|
b[i] = letterBytes[rand.Intn(len(letterBytes))]
|
|
}
|
|
return "call_" + strings.ToLower(string(b))
|
|
}
|
|
|
|
func toToolCalls(tc []api.ToolCall) []ToolCall {
|
|
toolCalls := make([]ToolCall, len(tc))
|
|
for i, tc := range tc {
|
|
toolCalls[i].ID = toolCallId()
|
|
toolCalls[i].Type = "function"
|
|
toolCalls[i].Function.Name = tc.Function.Name
|
|
toolCalls[i].Index = tc.Function.Index
|
|
|
|
args, err := json.Marshal(tc.Function.Arguments)
|
|
if err != nil {
|
|
slog.Error("could not marshall function arguments to json", "error", err)
|
|
continue
|
|
}
|
|
|
|
toolCalls[i].Function.Arguments = string(args)
|
|
}
|
|
return toolCalls
|
|
}
|
|
|
|
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
|
toolCalls := toToolCalls(r.Message.ToolCalls)
|
|
return ChatCompletion{
|
|
Id: id,
|
|
Object: "chat.completion",
|
|
Created: r.CreatedAt.Unix(),
|
|
Model: r.Model,
|
|
SystemFingerprint: "fp_ollama",
|
|
Choices: []Choice{{
|
|
Index: 0,
|
|
Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls, Reasoning: r.Message.Thinking},
|
|
FinishReason: func(reason string) *string {
|
|
if len(toolCalls) > 0 {
|
|
reason = "tool_calls"
|
|
}
|
|
if len(reason) > 0 {
|
|
return &reason
|
|
}
|
|
return nil
|
|
}(r.DoneReason),
|
|
}}, Usage: toUsage(r),
|
|
DebugInfo: r.DebugInfo,
|
|
}
|
|
}
|
|
|
|
func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
|
|
toolCalls := toToolCalls(r.Message.ToolCalls)
|
|
return ChatCompletionChunk{
|
|
Id: id,
|
|
Object: "chat.completion.chunk",
|
|
Created: time.Now().Unix(),
|
|
Model: r.Model,
|
|
SystemFingerprint: "fp_ollama",
|
|
Choices: []ChunkChoice{{
|
|
Index: 0,
|
|
Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls, Reasoning: r.Message.Thinking},
|
|
FinishReason: func(reason string) *string {
|
|
if len(reason) > 0 {
|
|
if toolCallSent || len(toolCalls) > 0 {
|
|
return &finishReasonToolCalls
|
|
}
|
|
return &reason
|
|
}
|
|
return nil
|
|
}(r.DoneReason),
|
|
}},
|
|
}
|
|
}
|
|
|
|
func toUsageGenerate(r api.GenerateResponse) Usage {
|
|
return Usage{
|
|
PromptTokens: r.PromptEvalCount,
|
|
CompletionTokens: r.EvalCount,
|
|
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
|
}
|
|
}
|
|
|
|
func toCompletion(id string, r api.GenerateResponse) Completion {
|
|
return Completion{
|
|
Id: id,
|
|
Object: "text_completion",
|
|
Created: r.CreatedAt.Unix(),
|
|
Model: r.Model,
|
|
SystemFingerprint: "fp_ollama",
|
|
Choices: []CompleteChunkChoice{{
|
|
Text: r.Response,
|
|
Index: 0,
|
|
FinishReason: func(reason string) *string {
|
|
if len(reason) > 0 {
|
|
return &reason
|
|
}
|
|
return nil
|
|
}(r.DoneReason),
|
|
}},
|
|
Usage: toUsageGenerate(r),
|
|
}
|
|
}
|
|
|
|
func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
|
|
return CompletionChunk{
|
|
Id: id,
|
|
Object: "text_completion",
|
|
Created: time.Now().Unix(),
|
|
Model: r.Model,
|
|
SystemFingerprint: "fp_ollama",
|
|
Choices: []CompleteChunkChoice{{
|
|
Text: r.Response,
|
|
Index: 0,
|
|
FinishReason: func(reason string) *string {
|
|
if len(reason) > 0 {
|
|
return &reason
|
|
}
|
|
return nil
|
|
}(r.DoneReason),
|
|
}},
|
|
}
|
|
}
|
|
|
|
func toListCompletion(r api.ListResponse) ListCompletion {
|
|
var data []Model
|
|
for _, m := range r.Models {
|
|
data = append(data, Model{
|
|
Id: m.Name,
|
|
Object: "model",
|
|
Created: m.ModifiedAt.Unix(),
|
|
OwnedBy: model.ParseName(m.Name).Namespace,
|
|
})
|
|
}
|
|
|
|
return ListCompletion{
|
|
Object: "list",
|
|
Data: data,
|
|
}
|
|
}
|
|
|
|
func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
|
|
if r.Embeddings != nil {
|
|
var data []Embedding
|
|
for i, e := range r.Embeddings {
|
|
data = append(data, Embedding{
|
|
Object: "embedding",
|
|
Embedding: e,
|
|
Index: i,
|
|
})
|
|
}
|
|
|
|
return EmbeddingList{
|
|
Object: "list",
|
|
Data: data,
|
|
Model: model,
|
|
Usage: EmbeddingUsage{
|
|
PromptTokens: r.PromptEvalCount,
|
|
TotalTokens: r.PromptEvalCount,
|
|
},
|
|
}
|
|
}
|
|
|
|
return EmbeddingList{}
|
|
}
|
|
|
|
func toModel(r api.ShowResponse, m string) Model {
|
|
return Model{
|
|
Id: m,
|
|
Object: "model",
|
|
Created: r.ModifiedAt.Unix(),
|
|
OwnedBy: model.ParseName(m).Namespace,
|
|
}
|
|
}
|
|
|
|
func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
|
var messages []api.Message
|
|
for _, msg := range r.Messages {
|
|
toolName := ""
|
|
if strings.ToLower(msg.Role) == "tool" {
|
|
toolName = msg.Name
|
|
if toolName == "" && msg.ToolCallID != "" {
|
|
toolName = nameFromToolCallID(r.Messages, msg.ToolCallID)
|
|
}
|
|
}
|
|
switch content := msg.Content.(type) {
|
|
case string:
|
|
toolCalls, err := fromCompletionToolCall(msg.ToolCalls)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
messages = append(messages, api.Message{Role: msg.Role, Content: content, Thinking: msg.Reasoning, ToolCalls: toolCalls, ToolName: toolName})
|
|
case []any:
|
|
for _, c := range content {
|
|
data, ok := c.(map[string]any)
|
|
if !ok {
|
|
return nil, errors.New("invalid message format")
|
|
}
|
|
switch data["type"] {
|
|
case "text":
|
|
text, ok := data["text"].(string)
|
|
if !ok {
|
|
return nil, errors.New("invalid message format")
|
|
}
|
|
messages = append(messages, api.Message{Role: msg.Role, Content: text})
|
|
case "image_url":
|
|
var url string
|
|
if urlMap, ok := data["image_url"].(map[string]any); ok {
|
|
if url, ok = urlMap["url"].(string); !ok {
|
|
return nil, errors.New("invalid message format")
|
|
}
|
|
} else {
|
|
if url, ok = data["image_url"].(string); !ok {
|
|
return nil, errors.New("invalid message format")
|
|
}
|
|
}
|
|
|
|
types := []string{"jpeg", "jpg", "png", "webp"}
|
|
valid := false
|
|
for _, t := range types {
|
|
prefix := "data:image/" + t + ";base64,"
|
|
if strings.HasPrefix(url, prefix) {
|
|
url = strings.TrimPrefix(url, prefix)
|
|
valid = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !valid {
|
|
return nil, errors.New("invalid image input")
|
|
}
|
|
|
|
img, err := base64.StdEncoding.DecodeString(url)
|
|
if err != nil {
|
|
return nil, errors.New("invalid message format")
|
|
}
|
|
|
|
messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
|
|
default:
|
|
return nil, errors.New("invalid message format")
|
|
}
|
|
}
|
|
// since we might have added multiple messages above, if we have tools
|
|
// calls we'll add them to the last message
|
|
if len(messages) > 0 && len(msg.ToolCalls) > 0 {
|
|
toolCalls, err := fromCompletionToolCall(msg.ToolCalls)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
messages[len(messages)-1].ToolCalls = toolCalls
|
|
if toolName != "" {
|
|
messages[len(messages)-1].ToolName = toolName
|
|
}
|
|
messages[len(messages)-1].Thinking = msg.Reasoning
|
|
}
|
|
default:
|
|
// content is only optional if tool calls are present
|
|
if msg.ToolCalls == nil {
|
|
return nil, fmt.Errorf("invalid message content type: %T", content)
|
|
}
|
|
|
|
toolCalls := make([]api.ToolCall, len(msg.ToolCalls))
|
|
for i, tc := range msg.ToolCalls {
|
|
toolCalls[i].Function.Name = tc.Function.Name
|
|
err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
|
|
if err != nil {
|
|
return nil, errors.New("invalid tool call arguments")
|
|
}
|
|
}
|
|
messages = append(messages, api.Message{Role: msg.Role, Thinking: msg.Reasoning, ToolCalls: toolCalls})
|
|
}
|
|
}
|
|
|
|
options := make(map[string]any)
|
|
|
|
switch stop := r.Stop.(type) {
|
|
case string:
|
|
options["stop"] = []string{stop}
|
|
case []any:
|
|
var stops []string
|
|
for _, s := range stop {
|
|
if str, ok := s.(string); ok {
|
|
stops = append(stops, str)
|
|
}
|
|
}
|
|
options["stop"] = stops
|
|
}
|
|
|
|
if r.MaxTokens != nil {
|
|
options["num_predict"] = *r.MaxTokens
|
|
}
|
|
|
|
if r.Temperature != nil {
|
|
options["temperature"] = *r.Temperature
|
|
} else {
|
|
options["temperature"] = 1.0
|
|
}
|
|
|
|
if r.Seed != nil {
|
|
options["seed"] = *r.Seed
|
|
}
|
|
|
|
if r.FrequencyPenalty != nil {
|
|
options["frequency_penalty"] = *r.FrequencyPenalty
|
|
}
|
|
|
|
if r.PresencePenalty != nil {
|
|
options["presence_penalty"] = *r.PresencePenalty
|
|
}
|
|
|
|
if r.TopP != nil {
|
|
options["top_p"] = *r.TopP
|
|
} else {
|
|
options["top_p"] = 1.0
|
|
}
|
|
|
|
var format json.RawMessage
|
|
if r.ResponseFormat != nil {
|
|
switch strings.ToLower(strings.TrimSpace(r.ResponseFormat.Type)) {
|
|
// Support the old "json_object" type for OpenAI compatibility
|
|
case "json_object":
|
|
format = json.RawMessage(`"json"`)
|
|
case "json_schema":
|
|
if r.ResponseFormat.JsonSchema != nil {
|
|
format = r.ResponseFormat.JsonSchema.Schema
|
|
}
|
|
}
|
|
}
|
|
|
|
var think *api.ThinkValue
|
|
if r.Reasoning != nil {
|
|
think = &api.ThinkValue{
|
|
Value: *r.Reasoning.Effort,
|
|
}
|
|
} else if r.ReasoningEffort != nil {
|
|
think = &api.ThinkValue{
|
|
Value: *r.ReasoningEffort,
|
|
}
|
|
}
|
|
|
|
return &api.ChatRequest{
|
|
Model: r.Model,
|
|
Messages: messages,
|
|
Format: format,
|
|
Options: options,
|
|
Stream: &r.Stream,
|
|
Tools: r.Tools,
|
|
Think: think,
|
|
DebugRenderOnly: r.DebugRenderOnly,
|
|
}, nil
|
|
}
|
|
|
|
func nameFromToolCallID(messages []Message, toolCallID string) string {
|
|
// iterate backwards to be more resilient to duplicate tool call IDs (this
|
|
// follows "last one wins")
|
|
for i := len(messages) - 1; i >= 0; i-- {
|
|
msg := messages[i]
|
|
for _, tc := range msg.ToolCalls {
|
|
if tc.ID == toolCallID {
|
|
return tc.Function.Name
|
|
}
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func fromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) {
|
|
apiToolCalls := make([]api.ToolCall, len(toolCalls))
|
|
for i, tc := range toolCalls {
|
|
apiToolCalls[i].Function.Name = tc.Function.Name
|
|
err := json.Unmarshal([]byte(tc.Function.Arguments), &apiToolCalls[i].Function.Arguments)
|
|
if err != nil {
|
|
return nil, errors.New("invalid tool call arguments")
|
|
}
|
|
}
|
|
|
|
return apiToolCalls, nil
|
|
}
|
|
|
|
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
|
options := make(map[string]any)
|
|
|
|
switch stop := r.Stop.(type) {
|
|
case string:
|
|
options["stop"] = []string{stop}
|
|
case []any:
|
|
var stops []string
|
|
for _, s := range stop {
|
|
if str, ok := s.(string); ok {
|
|
stops = append(stops, str)
|
|
} else {
|
|
return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
|
|
}
|
|
}
|
|
options["stop"] = stops
|
|
}
|
|
|
|
if r.MaxTokens != nil {
|
|
options["num_predict"] = *r.MaxTokens
|
|
}
|
|
|
|
if r.Temperature != nil {
|
|
options["temperature"] = *r.Temperature
|
|
} else {
|
|
options["temperature"] = 1.0
|
|
}
|
|
|
|
if r.Seed != nil {
|
|
options["seed"] = *r.Seed
|
|
}
|
|
|
|
options["frequency_penalty"] = r.FrequencyPenalty
|
|
|
|
options["presence_penalty"] = r.PresencePenalty
|
|
|
|
if r.TopP != 0.0 {
|
|
options["top_p"] = r.TopP
|
|
} else {
|
|
options["top_p"] = 1.0
|
|
}
|
|
|
|
return api.GenerateRequest{
|
|
Model: r.Model,
|
|
Prompt: r.Prompt,
|
|
Options: options,
|
|
Stream: &r.Stream,
|
|
Suffix: r.Suffix,
|
|
DebugRenderOnly: r.DebugRenderOnly,
|
|
}, nil
|
|
}
|
|
|
|
type BaseWriter struct {
|
|
gin.ResponseWriter
|
|
}
|
|
|
|
type ChatWriter struct {
|
|
stream bool
|
|
streamOptions *StreamOptions
|
|
id string
|
|
toolCallSent bool
|
|
BaseWriter
|
|
}
|
|
|
|
type CompleteWriter struct {
|
|
stream bool
|
|
streamOptions *StreamOptions
|
|
id string
|
|
BaseWriter
|
|
}
|
|
|
|
type ListWriter struct {
|
|
BaseWriter
|
|
}
|
|
|
|
type RetrieveWriter struct {
|
|
BaseWriter
|
|
model string
|
|
}
|
|
|
|
type EmbedWriter struct {
|
|
BaseWriter
|
|
model string
|
|
}
|
|
|
|
func (w *BaseWriter) writeError(data []byte) (int, error) {
|
|
var serr api.StatusError
|
|
err := json.Unmarshal(data, &serr)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return len(data), nil
|
|
}
|
|
|
|
func (w *ChatWriter) writeResponse(data []byte) (int, error) {
|
|
var chatResponse api.ChatResponse
|
|
err := json.Unmarshal(data, &chatResponse)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
// chat chunk
|
|
if w.stream {
|
|
c := toChunk(w.id, chatResponse, w.toolCallSent)
|
|
d, err := json.Marshal(c)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 {
|
|
w.toolCallSent = true
|
|
}
|
|
|
|
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
|
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
if chatResponse.Done {
|
|
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
|
u := toUsage(chatResponse)
|
|
c.Usage = &u
|
|
c.Choices = []ChunkChoice{}
|
|
d, err := json.Marshal(c)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
}
|
|
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
}
|
|
|
|
return len(data), nil
|
|
}
|
|
|
|
// chat completion
|
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return len(data), nil
|
|
}
|
|
|
|
func (w *ChatWriter) Write(data []byte) (int, error) {
|
|
code := w.ResponseWriter.Status()
|
|
if code != http.StatusOK {
|
|
return w.writeError(data)
|
|
}
|
|
|
|
return w.writeResponse(data)
|
|
}
|
|
|
|
func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
|
|
var generateResponse api.GenerateResponse
|
|
err := json.Unmarshal(data, &generateResponse)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
// completion chunk
|
|
if w.stream {
|
|
c := toCompleteChunk(w.id, generateResponse)
|
|
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
|
c.Usage = &Usage{}
|
|
}
|
|
d, err := json.Marshal(c)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
|
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
if generateResponse.Done {
|
|
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
|
u := toUsageGenerate(generateResponse)
|
|
c.Usage = &u
|
|
c.Choices = []CompleteChunkChoice{}
|
|
d, err := json.Marshal(c)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
}
|
|
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
}
|
|
|
|
return len(data), nil
|
|
}
|
|
|
|
// completion
|
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return len(data), nil
|
|
}
|
|
|
|
func (w *CompleteWriter) Write(data []byte) (int, error) {
|
|
code := w.ResponseWriter.Status()
|
|
if code != http.StatusOK {
|
|
return w.writeError(data)
|
|
}
|
|
|
|
return w.writeResponse(data)
|
|
}
|
|
|
|
func (w *ListWriter) writeResponse(data []byte) (int, error) {
|
|
var listResponse api.ListResponse
|
|
err := json.Unmarshal(data, &listResponse)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return len(data), nil
|
|
}
|
|
|
|
func (w *ListWriter) Write(data []byte) (int, error) {
|
|
code := w.ResponseWriter.Status()
|
|
if code != http.StatusOK {
|
|
return w.writeError(data)
|
|
}
|
|
|
|
return w.writeResponse(data)
|
|
}
|
|
|
|
func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
|
|
var showResponse api.ShowResponse
|
|
err := json.Unmarshal(data, &showResponse)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
// retrieve completion
|
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return len(data), nil
|
|
}
|
|
|
|
func (w *RetrieveWriter) Write(data []byte) (int, error) {
|
|
code := w.ResponseWriter.Status()
|
|
if code != http.StatusOK {
|
|
return w.writeError(data)
|
|
}
|
|
|
|
return w.writeResponse(data)
|
|
}
|
|
|
|
func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
|
|
var embedResponse api.EmbedResponse
|
|
err := json.Unmarshal(data, &embedResponse)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return len(data), nil
|
|
}
|
|
|
|
func (w *EmbedWriter) Write(data []byte) (int, error) {
|
|
code := w.ResponseWriter.Status()
|
|
if code != http.StatusOK {
|
|
return w.writeError(data)
|
|
}
|
|
|
|
return w.writeResponse(data)
|
|
}
|
|
|
|
func ListMiddleware() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
w := &ListWriter{
|
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
|
}
|
|
|
|
c.Writer = w
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
func RetrieveMiddleware() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
var b bytes.Buffer
|
|
if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
|
|
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
|
return
|
|
}
|
|
|
|
c.Request.Body = io.NopCloser(&b)
|
|
|
|
// response writer
|
|
w := &RetrieveWriter{
|
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
|
model: c.Param("model"),
|
|
}
|
|
|
|
c.Writer = w
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
func CompletionsMiddleware() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
var req CompletionRequest
|
|
err := c.ShouldBindJSON(&req)
|
|
if err != nil {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
|
return
|
|
}
|
|
|
|
var b bytes.Buffer
|
|
genReq, err := fromCompleteRequest(req)
|
|
if err != nil {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
|
return
|
|
}
|
|
|
|
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
|
|
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
|
return
|
|
}
|
|
|
|
c.Request.Body = io.NopCloser(&b)
|
|
|
|
w := &CompleteWriter{
|
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
|
stream: req.Stream,
|
|
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
|
|
streamOptions: req.StreamOptions,
|
|
}
|
|
|
|
c.Writer = w
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
func EmbeddingsMiddleware() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
var req EmbedRequest
|
|
err := c.ShouldBindJSON(&req)
|
|
if err != nil {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
|
return
|
|
}
|
|
|
|
if req.Input == "" {
|
|
req.Input = []string{""}
|
|
}
|
|
|
|
if req.Input == nil {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
|
|
return
|
|
}
|
|
|
|
if v, ok := req.Input.([]any); ok && len(v) == 0 {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
|
|
return
|
|
}
|
|
|
|
var b bytes.Buffer
|
|
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input, Dimensions: req.Dimensions}); err != nil {
|
|
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
|
return
|
|
}
|
|
|
|
c.Request.Body = io.NopCloser(&b)
|
|
|
|
w := &EmbedWriter{
|
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
|
model: req.Model,
|
|
}
|
|
|
|
c.Writer = w
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
func ChatMiddleware() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
var req ChatCompletionRequest
|
|
err := c.ShouldBindJSON(&req)
|
|
if err != nil {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
|
return
|
|
}
|
|
|
|
if len(req.Messages) == 0 {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
|
|
return
|
|
}
|
|
|
|
var b bytes.Buffer
|
|
|
|
chatReq, err := fromChatRequest(req)
|
|
if err != nil {
|
|
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
|
return
|
|
}
|
|
|
|
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
|
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
|
return
|
|
}
|
|
|
|
c.Request.Body = io.NopCloser(&b)
|
|
|
|
w := &ChatWriter{
|
|
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
|
stream: req.Stream,
|
|
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
|
|
streamOptions: req.StreamOptions,
|
|
}
|
|
|
|
c.Writer = w
|
|
|
|
c.Next()
|
|
}
|
|
}
|