templates: fix crash in improperly defined templates (#12483)

This commit is contained in:
Patrick Devine
2025-10-02 17:25:55 -07:00
committed by GitHub
parent 0bda72892c
commit 1ed2881ef0
3 changed files with 65 additions and 20 deletions

View File

@@ -105,12 +105,16 @@ func (m *Model) Capabilities() []model.Capability {
builtinParser := parsers.ParserForName(m.Config.Parser) builtinParser := parsers.ParserForName(m.Config.Parser)
// Check for tools capability // Check for tools capability
if slices.Contains(m.Template.Vars(), "tools") || (builtinParser != nil && builtinParser.HasToolSupport()) { v, err := m.Template.Vars()
if err != nil {
slog.Warn("model template contains errors", "error", err)
}
if slices.Contains(v, "tools") || (builtinParser != nil && builtinParser.HasToolSupport()) {
capabilities = append(capabilities, model.CapabilityTools) capabilities = append(capabilities, model.CapabilityTools)
} }
// Check for insert capability // Check for insert capability
if slices.Contains(m.Template.Vars(), "suffix") { if slices.Contains(v, "suffix") {
capabilities = append(capabilities, model.CapabilityInsert) capabilities = append(capabilities, model.CapabilityInsert)
} }

View File

@@ -148,7 +148,12 @@ func Parse(s string) (*Template, error) {
} }
t := Template{Template: tmpl, raw: s} t := Template{Template: tmpl, raw: s}
if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") { vars, err := t.Vars()
if err != nil {
return nil, err
}
if !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
// touch up the template and append {{ .Response }} // touch up the template and append {{ .Response }}
tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response) tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response)
} }
@@ -160,11 +165,15 @@ func (t *Template) String() string {
return t.raw return t.raw
} }
func (t *Template) Vars() []string { func (t *Template) Vars() ([]string, error) {
var vars []string var vars []string
for _, tt := range t.Templates() { for _, tt := range t.Templates() {
for _, n := range tt.Root.Nodes { for _, n := range tt.Root.Nodes {
vars = append(vars, Identifiers(n)...) v, err := Identifiers(n)
if err != nil {
return vars, err
}
vars = append(vars, v...)
} }
} }
@@ -173,7 +182,7 @@ func (t *Template) Vars() []string {
set[strings.ToLower(n)] = struct{}{} set[strings.ToLower(n)] = struct{}{}
} }
return slices.Sorted(maps.Keys(set)) return slices.Sorted(maps.Keys(set)), nil
} }
func (t *Template) Contains(s string) bool { func (t *Template) Contains(s string) bool {
@@ -244,6 +253,10 @@ func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
func (t *Template) Execute(w io.Writer, v Values) error { func (t *Template) Execute(w io.Writer, v Values) error {
system, messages := collate(v.Messages) system, messages := collate(v.Messages)
vars, err := t.Vars()
if err != nil {
return err
}
if v.Prompt != "" && v.Suffix != "" { if v.Prompt != "" && v.Suffix != "" {
return t.Template.Execute(w, map[string]any{ return t.Template.Execute(w, map[string]any{
"Prompt": v.Prompt, "Prompt": v.Prompt,
@@ -253,7 +266,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
"ThinkLevel": v.ThinkLevel, "ThinkLevel": v.ThinkLevel,
"IsThinkSet": v.IsThinkSet, "IsThinkSet": v.IsThinkSet,
}) })
} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") { } else if !v.forceLegacy && slices.Contains(vars, "messages") {
return t.Template.Execute(w, map[string]any{ return t.Template.Execute(w, map[string]any{
"System": system, "System": system,
"Messages": messages, "Messages": messages,
@@ -329,7 +342,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
return err return err
} }
_, err := io.Copy(w, &b) _, err = io.Copy(w, &b)
return err return err
} }
@@ -358,27 +371,47 @@ func collate(msgs []api.Message) (string, []*api.Message) {
} }
// Identifiers walks the node tree returning any identifiers it finds along the way // Identifiers walks the node tree returning any identifiers it finds along the way
func Identifiers(n parse.Node) []string { func Identifiers(n parse.Node) ([]string, error) {
switch n := n.(type) { switch n := n.(type) {
case *parse.ListNode: case *parse.ListNode:
var names []string var names []string
for _, n := range n.Nodes { for _, n := range n.Nodes {
names = append(names, Identifiers(n)...) i, err := Identifiers(n)
if err != nil {
return names, err
}
names = append(names, i...)
} }
return names return names, nil
case *parse.TemplateNode: case *parse.TemplateNode:
if n.Pipe == nil {
return nil, errors.New("undefined template specified")
}
return Identifiers(n.Pipe) return Identifiers(n.Pipe)
case *parse.ActionNode: case *parse.ActionNode:
if n.Pipe == nil {
return nil, errors.New("undefined action in template")
}
return Identifiers(n.Pipe) return Identifiers(n.Pipe)
case *parse.BranchNode: case *parse.BranchNode:
names := Identifiers(n.Pipe) if n.Pipe == nil {
return nil, errors.New("undefined branch")
}
names, err := Identifiers(n.Pipe)
if err != nil {
return names, err
}
for _, n := range []*parse.ListNode{n.List, n.ElseList} { for _, n := range []*parse.ListNode{n.List, n.ElseList} {
if n != nil { if n != nil {
names = append(names, Identifiers(n)...) i, err := Identifiers(n)
if err != nil {
return names, err
}
names = append(names, i...)
} }
} }
return names return names, nil
case *parse.IfNode: case *parse.IfNode:
return Identifiers(&n.BranchNode) return Identifiers(&n.BranchNode)
case *parse.RangeNode: case *parse.RangeNode:
@@ -389,17 +422,21 @@ func Identifiers(n parse.Node) []string {
var names []string var names []string
for _, c := range n.Cmds { for _, c := range n.Cmds {
for _, a := range c.Args { for _, a := range c.Args {
names = append(names, Identifiers(a)...) i, err := Identifiers(a)
if err != nil {
return names, err
}
names = append(names, i...)
} }
} }
return names return names, nil
case *parse.FieldNode: case *parse.FieldNode:
return n.Ident return n.Ident, nil
case *parse.VariableNode: case *parse.VariableNode:
return n.Ident return n.Ident, nil
} }
return nil return nil, nil
} }
// deleteNode walks the node list and deletes nodes that match the predicate // deleteNode walks the node list and deletes nodes that match the predicate

View File

@@ -192,7 +192,11 @@ func TestParse(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if diff := cmp.Diff(tmpl.Vars(), tt.vars); diff != "" { v, err := tmpl.Vars()
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(v, tt.vars); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
}) })