mirror of
https://github.com/jmorganca/ollama
synced 2025-10-06 00:32:49 +02:00
550 lines
14 KiB
Go
550 lines
14 KiB
Go
package template
|
||
|
||
import (
|
||
"bufio"
|
||
"bytes"
|
||
"encoding/json"
|
||
"io"
|
||
"os"
|
||
"path/filepath"
|
||
"slices"
|
||
"strings"
|
||
"testing"
|
||
|
||
"github.com/google/go-cmp/cmp"
|
||
|
||
"github.com/ollama/ollama/api"
|
||
"github.com/ollama/ollama/fs/ggml"
|
||
)
|
||
|
||
func TestNamed(t *testing.T) {
|
||
f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
defer f.Close()
|
||
|
||
scanner := bufio.NewScanner(f)
|
||
for scanner.Scan() {
|
||
var ss map[string]string
|
||
if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
for k, v := range ss {
|
||
t.Run(k, func(t *testing.T) {
|
||
kv := ggml.KV{"tokenizer.chat_template": v}
|
||
s := kv.ChatTemplate()
|
||
r, err := Named(s)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
if r.Name != k {
|
||
t.Errorf("expected %q, got %q", k, r.Name)
|
||
}
|
||
|
||
var b bytes.Buffer
|
||
if _, err := io.Copy(&b, r.Reader()); err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
tmpl, err := Parse(b.String())
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
if tmpl.Tree.Root.String() == "" {
|
||
t.Errorf("empty %s template", k)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
}
|
||
|
||
func TestTemplate(t *testing.T) {
|
||
cases := make(map[string][]api.Message)
|
||
for _, mm := range [][]api.Message{
|
||
{
|
||
{Role: "user", Content: "Hello, how are you?"},
|
||
},
|
||
{
|
||
{Role: "user", Content: "Hello, how are you?"},
|
||
{Role: "assistant", Content: "I'm doing great. How can I help you today?"},
|
||
{Role: "user", Content: "I'd like to show off how chat templating works!"},
|
||
},
|
||
{
|
||
{Role: "system", Content: "You are a helpful assistant."},
|
||
{Role: "user", Content: "Hello, how are you?"},
|
||
{Role: "assistant", Content: "I'm doing great. How can I help you today?"},
|
||
{Role: "user", Content: "I'd like to show off how chat templating works!"},
|
||
},
|
||
} {
|
||
var roles []string
|
||
for _, m := range mm {
|
||
roles = append(roles, m.Role)
|
||
}
|
||
|
||
cases[strings.Join(roles, "-")] = mm
|
||
}
|
||
|
||
matches, err := filepath.Glob("*.gotmpl")
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
for _, match := range matches {
|
||
t.Run(match, func(t *testing.T) {
|
||
bts, err := os.ReadFile(match)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
tmpl, err := Parse(string(bts))
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
for n, tt := range cases {
|
||
var actual bytes.Buffer
|
||
t.Run(n, func(t *testing.T) {
|
||
if err := tmpl.Execute(&actual, Values{Messages: tt}); err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
expect, err := os.ReadFile(filepath.Join("testdata", match, n))
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
bts := actual.Bytes()
|
||
|
||
if slices.Contains([]string{"chatqa.gotmpl", "llama2-chat.gotmpl", "mistral-instruct.gotmpl", "openchat.gotmpl", "vicuna.gotmpl"}, match) && bts[len(bts)-1] == ' ' {
|
||
t.Log("removing trailing space from output")
|
||
bts = bts[:len(bts)-1]
|
||
}
|
||
|
||
if diff := cmp.Diff(bts, expect); diff != "" {
|
||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||
}
|
||
})
|
||
|
||
t.Run("legacy", func(t *testing.T) {
|
||
t.Skip("legacy outputs are currently default outputs")
|
||
var legacy bytes.Buffer
|
||
if err := tmpl.Execute(&legacy, Values{Messages: tt, forceLegacy: true}); err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
legacyBytes := legacy.Bytes()
|
||
if slices.Contains([]string{"chatqa.gotmpl", "openchat.gotmpl", "vicuna.gotmpl"}, match) && legacyBytes[len(legacyBytes)-1] == ' ' {
|
||
t.Log("removing trailing space from legacy output")
|
||
legacyBytes = legacyBytes[:len(legacyBytes)-1]
|
||
} else if slices.Contains([]string{"codellama-70b-instruct.gotmpl", "llama2-chat.gotmpl", "mistral-instruct.gotmpl"}, match) {
|
||
t.Skip("legacy outputs cannot be compared to messages outputs")
|
||
}
|
||
|
||
if diff := cmp.Diff(legacyBytes, actual.Bytes()); diff != "" {
|
||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||
}
|
||
})
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestParse(t *testing.T) {
|
||
validCases := []struct {
|
||
name string
|
||
template string
|
||
vars []string
|
||
}{
|
||
{
|
||
name: "PromptOnly",
|
||
template: "{{ .Prompt }}",
|
||
vars: []string{"prompt", "response"},
|
||
},
|
||
{
|
||
name: "SystemAndPrompt",
|
||
template: "{{ .System }} {{ .Prompt }}",
|
||
vars: []string{"prompt", "response", "system"},
|
||
},
|
||
{
|
||
name: "PromptResponseSystem",
|
||
template: "{{ .System }} {{ .Prompt }} {{ .Response }}",
|
||
vars: []string{"prompt", "response", "system"},
|
||
},
|
||
{
|
||
name: "ToolsBlock",
|
||
template: "{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}",
|
||
vars: []string{"prompt", "response", "system", "tools"},
|
||
},
|
||
{
|
||
name: "MessagesRange",
|
||
template: "{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}",
|
||
vars: []string{"content", "messages", "role"},
|
||
},
|
||
{
|
||
name: "ToolResultConditional",
|
||
template: "{{ range .Messages }}{{ if eq .Role \"tool\" }}Tool Result: {{ .ToolName }} {{ .Content }}{{ end }}{{ end }}",
|
||
vars: []string{"content", "messages", "role", "toolname"},
|
||
},
|
||
{
|
||
name: "MultilineSystemUserAssistant",
|
||
template: `{{- range .Messages }}
|
||
{{- if eq .Role "system" }}SYSTEM:
|
||
{{- else if eq .Role "user" }}USER:
|
||
{{- else if eq .Role "assistant" }}ASSISTANT:
|
||
{{- else if eq .Role "tool" }}TOOL:
|
||
{{- end }} {{ .Content }}
|
||
{{- end }}`,
|
||
vars: []string{"content", "messages", "role"},
|
||
},
|
||
{
|
||
name: "ChatMLLike",
|
||
template: `{{- if .Messages }}
|
||
{{- range .Messages }}<|im_start|>{{ .Role }}
|
||
{{ .Content }}<|im_end|>
|
||
{{ end }}<|im_start|>assistant
|
||
{{ else -}}
|
||
{{ if .System }}<|im_start|>system
|
||
{{ .System }}<|im_end|>
|
||
{{ end }}{{ if .Prompt }}<|im_start|>user
|
||
{{ .Prompt }}<|im_end|>
|
||
{{ end }}<|im_start|>assistant
|
||
{{ .Response }}<|im_end|>
|
||
{{- end -}}`,
|
||
vars: []string{"content", "messages", "prompt", "response", "role", "system"},
|
||
},
|
||
}
|
||
|
||
for _, tt := range validCases {
|
||
tt := tt
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
tmpl, err := Parse(tt.template)
|
||
if err != nil {
|
||
t.Fatalf("Parse returned unexpected error: %v", err)
|
||
}
|
||
|
||
gotVars, err := tmpl.Vars()
|
||
if err != nil {
|
||
t.Fatalf("Vars returned unexpected error: %v", err)
|
||
}
|
||
|
||
if diff := cmp.Diff(gotVars, tt.vars); diff != "" {
|
||
t.Errorf("Vars mismatch (-got +want):\n%s", diff)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestParseError(t *testing.T) {
|
||
invalidCases := []struct {
|
||
name string
|
||
template string
|
||
errorStr string
|
||
}{
|
||
{
|
||
"TemplateNotClosed",
|
||
"{{ .Prompt ",
|
||
"unclosed action",
|
||
},
|
||
{
|
||
"Template",
|
||
`{{define "x"}}{{template "x"}}{{end}}{{template "x"}}`,
|
||
"undefined template specified",
|
||
},
|
||
}
|
||
|
||
for _, tt := range invalidCases {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
_, err := Parse(tt.template)
|
||
if err == nil {
|
||
t.Fatalf("expected Parse to return an error for an invalid template, got nil")
|
||
}
|
||
|
||
if !strings.Contains(strings.ToLower(err.Error()), strings.ToLower(tt.errorStr)) {
|
||
t.Errorf("unexpected error message.\n got: %q\n want substring (case‑insensitive): %q", err.Error(), tt.errorStr)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestExecuteWithMessages(t *testing.T) {
|
||
type template struct {
|
||
name string
|
||
template string
|
||
}
|
||
cases := []struct {
|
||
name string
|
||
templates []template
|
||
values Values
|
||
expected string
|
||
}{
|
||
{
|
||
"mistral",
|
||
[]template{
|
||
{"no response", `[INST] {{ if .System }}{{ .System }}
|
||
|
||
{{ end }}{{ .Prompt }}[/INST] `},
|
||
{"response", `[INST] {{ if .System }}{{ .System }}
|
||
|
||
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
|
||
{"messages", `[INST] {{ if .System }}{{ .System }}
|
||
|
||
{{ end }}
|
||
{{- range .Messages }}
|
||
{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
|
||
{{- end }}`},
|
||
},
|
||
Values{
|
||
Messages: []api.Message{
|
||
{Role: "user", Content: "Hello friend!"},
|
||
{Role: "assistant", Content: "Hello human!"},
|
||
{Role: "user", Content: "What is your name?"},
|
||
},
|
||
},
|
||
`[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
|
||
},
|
||
{
|
||
"mistral system",
|
||
[]template{
|
||
{"no response", `[INST] {{ if .System }}{{ .System }}
|
||
|
||
{{ end }}{{ .Prompt }}[/INST] `},
|
||
{"response", `[INST] {{ if .System }}{{ .System }}
|
||
|
||
{{ end }}{{ .Prompt }}[/INST] {{ .Response }}`},
|
||
{"messages", `[INST] {{ if .System }}{{ .System }}
|
||
|
||
{{ end }}
|
||
{{- range .Messages }}
|
||
{{- if eq .Role "user" }}{{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}[INST] {{ end }}
|
||
{{- end }}`},
|
||
},
|
||
Values{
|
||
Messages: []api.Message{
|
||
{Role: "system", Content: "You are a helpful assistant!"},
|
||
{Role: "user", Content: "Hello friend!"},
|
||
{Role: "assistant", Content: "Hello human!"},
|
||
{Role: "user", Content: "What is your name?"},
|
||
},
|
||
},
|
||
`[INST] You are a helpful assistant!
|
||
|
||
Hello friend![/INST] Hello human![INST] What is your name?[/INST] `,
|
||
},
|
||
{
|
||
"mistral assistant",
|
||
[]template{
|
||
{"no response", `[INST] {{ .Prompt }}[/INST] `},
|
||
{"response", `[INST] {{ .Prompt }}[/INST] {{ .Response }}`},
|
||
{"messages", `
|
||
{{- range $i, $m := .Messages }}
|
||
{{- if eq .Role "user" }}[INST] {{ .Content }}[/INST] {{ else if eq .Role "assistant" }}{{ .Content }}{{ end }}
|
||
{{- end }}`},
|
||
},
|
||
Values{
|
||
Messages: []api.Message{
|
||
{Role: "user", Content: "Hello friend!"},
|
||
{Role: "assistant", Content: "Hello human!"},
|
||
{Role: "user", Content: "What is your name?"},
|
||
{Role: "assistant", Content: "My name is Ollama and I"},
|
||
},
|
||
},
|
||
`[INST] Hello friend![/INST] Hello human![INST] What is your name?[/INST] My name is Ollama and I`,
|
||
},
|
||
{
|
||
"chatml",
|
||
[]template{
|
||
// this does not have a "no response" test because it's impossible to render the same output
|
||
{"response", `{{ if .System }}<|im_start|>system
|
||
{{ .System }}<|im_end|>
|
||
{{ end }}{{ if .Prompt }}<|im_start|>user
|
||
{{ .Prompt }}<|im_end|>
|
||
{{ end }}<|im_start|>assistant
|
||
{{ .Response }}<|im_end|>
|
||
`},
|
||
{"messages", `
|
||
{{- range $index, $_ := .Messages }}<|im_start|>{{ .Role }}
|
||
{{ .Content }}<|im_end|>
|
||
{{ end }}<|im_start|>assistant
|
||
`},
|
||
},
|
||
Values{
|
||
Messages: []api.Message{
|
||
{Role: "system", Content: "You are a helpful assistant!"},
|
||
{Role: "user", Content: "Hello friend!"},
|
||
{Role: "assistant", Content: "Hello human!"},
|
||
{Role: "user", Content: "What is your name?"},
|
||
},
|
||
},
|
||
`<|im_start|>system
|
||
You are a helpful assistant!<|im_end|>
|
||
<|im_start|>user
|
||
Hello friend!<|im_end|>
|
||
<|im_start|>assistant
|
||
Hello human!<|im_end|>
|
||
<|im_start|>user
|
||
What is your name?<|im_end|>
|
||
<|im_start|>assistant
|
||
`,
|
||
},
|
||
}
|
||
|
||
for _, tt := range cases {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
for _, ttt := range tt.templates {
|
||
t.Run(ttt.name, func(t *testing.T) {
|
||
tmpl, err := Parse(ttt.template)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
var b bytes.Buffer
|
||
if err := tmpl.Execute(&b, tt.values); err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
if diff := cmp.Diff(b.String(), tt.expected); diff != "" {
|
||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||
}
|
||
})
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestExecuteWithSuffix(t *testing.T) {
|
||
tmpl, err := Parse(`{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
|
||
{{- else }}{{ .Prompt }}
|
||
{{- end }}`)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
cases := []struct {
|
||
name string
|
||
values Values
|
||
expect string
|
||
}{
|
||
{
|
||
"message", Values{Messages: []api.Message{{Role: "user", Content: "hello"}}}, "hello",
|
||
},
|
||
{
|
||
"prompt suffix", Values{Prompt: "def add(", Suffix: "return x"}, "<PRE> def add( <SUF>return x <MID>",
|
||
},
|
||
}
|
||
|
||
for _, tt := range cases {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
var b bytes.Buffer
|
||
if err := tmpl.Execute(&b, tt.values); err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
if diff := cmp.Diff(b.String(), tt.expect); diff != "" {
|
||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestCollate(t *testing.T) {
|
||
cases := []struct {
|
||
name string
|
||
msgs []api.Message
|
||
expected []*api.Message
|
||
system string
|
||
}{
|
||
{
|
||
name: "consecutive user messages are merged",
|
||
msgs: []api.Message{
|
||
{Role: "user", Content: "Hello"},
|
||
{Role: "user", Content: "How are you?"},
|
||
},
|
||
expected: []*api.Message{
|
||
{Role: "user", Content: "Hello\n\nHow are you?"},
|
||
},
|
||
system: "",
|
||
},
|
||
{
|
||
name: "consecutive tool messages are NOT merged",
|
||
msgs: []api.Message{
|
||
{Role: "tool", Content: "sunny", ToolName: "get_weather"},
|
||
{Role: "tool", Content: "72F", ToolName: "get_temperature"},
|
||
},
|
||
expected: []*api.Message{
|
||
{Role: "tool", Content: "sunny", ToolName: "get_weather"},
|
||
{Role: "tool", Content: "72F", ToolName: "get_temperature"},
|
||
},
|
||
system: "",
|
||
},
|
||
{
|
||
name: "tool messages preserve all fields",
|
||
msgs: []api.Message{
|
||
{Role: "user", Content: "What's the weather?"},
|
||
{Role: "tool", Content: "sunny", ToolName: "get_conditions"},
|
||
{Role: "tool", Content: "72F", ToolName: "get_temperature"},
|
||
},
|
||
expected: []*api.Message{
|
||
{Role: "user", Content: "What's the weather?"},
|
||
{Role: "tool", Content: "sunny", ToolName: "get_conditions"},
|
||
{Role: "tool", Content: "72F", ToolName: "get_temperature"},
|
||
},
|
||
system: "",
|
||
},
|
||
{
|
||
name: "mixed messages with system",
|
||
msgs: []api.Message{
|
||
{Role: "system", Content: "You are helpful"},
|
||
{Role: "user", Content: "Hello"},
|
||
{Role: "assistant", Content: "Hi there!"},
|
||
{Role: "user", Content: "What's the weather?"},
|
||
{Role: "tool", Content: "sunny", ToolName: "get_weather"},
|
||
{Role: "tool", Content: "72F", ToolName: "get_temperature"},
|
||
{Role: "user", Content: "Thanks"},
|
||
},
|
||
expected: []*api.Message{
|
||
{Role: "system", Content: "You are helpful"},
|
||
{Role: "user", Content: "Hello"},
|
||
{Role: "assistant", Content: "Hi there!"},
|
||
{Role: "user", Content: "What's the weather?"},
|
||
{Role: "tool", Content: "sunny", ToolName: "get_weather"},
|
||
{Role: "tool", Content: "72F", ToolName: "get_temperature"},
|
||
{Role: "user", Content: "Thanks"},
|
||
},
|
||
system: "You are helpful",
|
||
},
|
||
}
|
||
|
||
for _, tt := range cases {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
system, collated := collate(tt.msgs)
|
||
if diff := cmp.Diff(system, tt.system); diff != "" {
|
||
t.Errorf("system mismatch (-got +want):\n%s", diff)
|
||
}
|
||
|
||
// Compare the messages
|
||
if len(collated) != len(tt.expected) {
|
||
t.Errorf("expected %d messages, got %d", len(tt.expected), len(collated))
|
||
return
|
||
}
|
||
|
||
for i := range collated {
|
||
if collated[i].Role != tt.expected[i].Role {
|
||
t.Errorf("message %d role mismatch: got %q, want %q", i, collated[i].Role, tt.expected[i].Role)
|
||
}
|
||
if collated[i].Content != tt.expected[i].Content {
|
||
t.Errorf("message %d content mismatch: got %q, want %q", i, collated[i].Content, tt.expected[i].Content)
|
||
}
|
||
if collated[i].ToolName != tt.expected[i].ToolName {
|
||
t.Errorf("message %d tool name mismatch: got %q, want %q", i, collated[i].ToolName, tt.expected[i].ToolName)
|
||
}
|
||
}
|
||
})
|
||
}
|
||
}
|