mirror of
https://github.com/jmorganca/ollama
synced 2025-10-06 00:32:49 +02:00
templates: fix crash in improperly defined templates (#12483)
This commit is contained in:
@@ -105,12 +105,16 @@ func (m *Model) Capabilities() []model.Capability {
|
|||||||
|
|
||||||
builtinParser := parsers.ParserForName(m.Config.Parser)
|
builtinParser := parsers.ParserForName(m.Config.Parser)
|
||||||
// Check for tools capability
|
// Check for tools capability
|
||||||
if slices.Contains(m.Template.Vars(), "tools") || (builtinParser != nil && builtinParser.HasToolSupport()) {
|
v, err := m.Template.Vars()
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("model template contains errors", "error", err)
|
||||||
|
}
|
||||||
|
if slices.Contains(v, "tools") || (builtinParser != nil && builtinParser.HasToolSupport()) {
|
||||||
capabilities = append(capabilities, model.CapabilityTools)
|
capabilities = append(capabilities, model.CapabilityTools)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for insert capability
|
// Check for insert capability
|
||||||
if slices.Contains(m.Template.Vars(), "suffix") {
|
if slices.Contains(v, "suffix") {
|
||||||
capabilities = append(capabilities, model.CapabilityInsert)
|
capabilities = append(capabilities, model.CapabilityInsert)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -148,7 +148,12 @@ func Parse(s string) (*Template, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
t := Template{Template: tmpl, raw: s}
|
t := Template{Template: tmpl, raw: s}
|
||||||
if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
|
vars, err := t.Vars()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
|
||||||
// touch up the template and append {{ .Response }}
|
// touch up the template and append {{ .Response }}
|
||||||
tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response)
|
tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response)
|
||||||
}
|
}
|
||||||
@@ -160,11 +165,15 @@ func (t *Template) String() string {
|
|||||||
return t.raw
|
return t.raw
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Template) Vars() []string {
|
func (t *Template) Vars() ([]string, error) {
|
||||||
var vars []string
|
var vars []string
|
||||||
for _, tt := range t.Templates() {
|
for _, tt := range t.Templates() {
|
||||||
for _, n := range tt.Root.Nodes {
|
for _, n := range tt.Root.Nodes {
|
||||||
vars = append(vars, Identifiers(n)...)
|
v, err := Identifiers(n)
|
||||||
|
if err != nil {
|
||||||
|
return vars, err
|
||||||
|
}
|
||||||
|
vars = append(vars, v...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -173,7 +182,7 @@ func (t *Template) Vars() []string {
|
|||||||
set[strings.ToLower(n)] = struct{}{}
|
set[strings.ToLower(n)] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
return slices.Sorted(maps.Keys(set))
|
return slices.Sorted(maps.Keys(set)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Template) Contains(s string) bool {
|
func (t *Template) Contains(s string) bool {
|
||||||
@@ -244,6 +253,10 @@ func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
|
|||||||
|
|
||||||
func (t *Template) Execute(w io.Writer, v Values) error {
|
func (t *Template) Execute(w io.Writer, v Values) error {
|
||||||
system, messages := collate(v.Messages)
|
system, messages := collate(v.Messages)
|
||||||
|
vars, err := t.Vars()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
if v.Prompt != "" && v.Suffix != "" {
|
if v.Prompt != "" && v.Suffix != "" {
|
||||||
return t.Template.Execute(w, map[string]any{
|
return t.Template.Execute(w, map[string]any{
|
||||||
"Prompt": v.Prompt,
|
"Prompt": v.Prompt,
|
||||||
@@ -253,7 +266,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
|||||||
"ThinkLevel": v.ThinkLevel,
|
"ThinkLevel": v.ThinkLevel,
|
||||||
"IsThinkSet": v.IsThinkSet,
|
"IsThinkSet": v.IsThinkSet,
|
||||||
})
|
})
|
||||||
} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
|
} else if !v.forceLegacy && slices.Contains(vars, "messages") {
|
||||||
return t.Template.Execute(w, map[string]any{
|
return t.Template.Execute(w, map[string]any{
|
||||||
"System": system,
|
"System": system,
|
||||||
"Messages": messages,
|
"Messages": messages,
|
||||||
@@ -329,7 +342,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := io.Copy(w, &b)
|
_, err = io.Copy(w, &b)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -358,27 +371,47 @@ func collate(msgs []api.Message) (string, []*api.Message) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Identifiers walks the node tree returning any identifiers it finds along the way
|
// Identifiers walks the node tree returning any identifiers it finds along the way
|
||||||
func Identifiers(n parse.Node) []string {
|
func Identifiers(n parse.Node) ([]string, error) {
|
||||||
switch n := n.(type) {
|
switch n := n.(type) {
|
||||||
case *parse.ListNode:
|
case *parse.ListNode:
|
||||||
var names []string
|
var names []string
|
||||||
for _, n := range n.Nodes {
|
for _, n := range n.Nodes {
|
||||||
names = append(names, Identifiers(n)...)
|
i, err := Identifiers(n)
|
||||||
|
if err != nil {
|
||||||
|
return names, err
|
||||||
|
}
|
||||||
|
names = append(names, i...)
|
||||||
}
|
}
|
||||||
|
|
||||||
return names
|
return names, nil
|
||||||
case *parse.TemplateNode:
|
case *parse.TemplateNode:
|
||||||
|
if n.Pipe == nil {
|
||||||
|
return nil, errors.New("undefined template specified")
|
||||||
|
}
|
||||||
return Identifiers(n.Pipe)
|
return Identifiers(n.Pipe)
|
||||||
case *parse.ActionNode:
|
case *parse.ActionNode:
|
||||||
|
if n.Pipe == nil {
|
||||||
|
return nil, errors.New("undefined action in template")
|
||||||
|
}
|
||||||
return Identifiers(n.Pipe)
|
return Identifiers(n.Pipe)
|
||||||
case *parse.BranchNode:
|
case *parse.BranchNode:
|
||||||
names := Identifiers(n.Pipe)
|
if n.Pipe == nil {
|
||||||
|
return nil, errors.New("undefined branch")
|
||||||
|
}
|
||||||
|
names, err := Identifiers(n.Pipe)
|
||||||
|
if err != nil {
|
||||||
|
return names, err
|
||||||
|
}
|
||||||
for _, n := range []*parse.ListNode{n.List, n.ElseList} {
|
for _, n := range []*parse.ListNode{n.List, n.ElseList} {
|
||||||
if n != nil {
|
if n != nil {
|
||||||
names = append(names, Identifiers(n)...)
|
i, err := Identifiers(n)
|
||||||
|
if err != nil {
|
||||||
|
return names, err
|
||||||
|
}
|
||||||
|
names = append(names, i...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return names
|
return names, nil
|
||||||
case *parse.IfNode:
|
case *parse.IfNode:
|
||||||
return Identifiers(&n.BranchNode)
|
return Identifiers(&n.BranchNode)
|
||||||
case *parse.RangeNode:
|
case *parse.RangeNode:
|
||||||
@@ -389,17 +422,21 @@ func Identifiers(n parse.Node) []string {
|
|||||||
var names []string
|
var names []string
|
||||||
for _, c := range n.Cmds {
|
for _, c := range n.Cmds {
|
||||||
for _, a := range c.Args {
|
for _, a := range c.Args {
|
||||||
names = append(names, Identifiers(a)...)
|
i, err := Identifiers(a)
|
||||||
|
if err != nil {
|
||||||
|
return names, err
|
||||||
|
}
|
||||||
|
names = append(names, i...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return names
|
return names, nil
|
||||||
case *parse.FieldNode:
|
case *parse.FieldNode:
|
||||||
return n.Ident
|
return n.Ident, nil
|
||||||
case *parse.VariableNode:
|
case *parse.VariableNode:
|
||||||
return n.Ident
|
return n.Ident, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// deleteNode walks the node list and deletes nodes that match the predicate
|
// deleteNode walks the node list and deletes nodes that match the predicate
|
||||||
|
@@ -192,7 +192,11 @@ func TestParse(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(tmpl.Vars(), tt.vars); diff != "" {
|
v, err := tmpl.Vars()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(v, tt.vars); diff != "" {
|
||||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
Reference in New Issue
Block a user