templates: fix crash in improperly defined templates (#12483)

This commit is contained in:
Patrick Devine
2025-10-02 17:25:55 -07:00
committed by GitHub
parent 0bda72892c
commit 1ed2881ef0
3 changed files with 65 additions and 20 deletions

View File

@@ -105,12 +105,16 @@ func (m *Model) Capabilities() []model.Capability {
builtinParser := parsers.ParserForName(m.Config.Parser)
// Check for tools capability
if slices.Contains(m.Template.Vars(), "tools") || (builtinParser != nil && builtinParser.HasToolSupport()) {
v, err := m.Template.Vars()
if err != nil {
slog.Warn("model template contains errors", "error", err)
}
if slices.Contains(v, "tools") || (builtinParser != nil && builtinParser.HasToolSupport()) {
capabilities = append(capabilities, model.CapabilityTools)
}
// Check for insert capability
if slices.Contains(m.Template.Vars(), "suffix") {
if slices.Contains(v, "suffix") {
capabilities = append(capabilities, model.CapabilityInsert)
}

View File

@@ -148,7 +148,12 @@ func Parse(s string) (*Template, error) {
}
t := Template{Template: tmpl, raw: s}
if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
vars, err := t.Vars()
if err != nil {
return nil, err
}
if !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
// touch up the template and append {{ .Response }}
tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response)
}
@@ -160,11 +165,15 @@ func (t *Template) String() string {
return t.raw
}
func (t *Template) Vars() []string {
func (t *Template) Vars() ([]string, error) {
var vars []string
for _, tt := range t.Templates() {
for _, n := range tt.Root.Nodes {
vars = append(vars, Identifiers(n)...)
v, err := Identifiers(n)
if err != nil {
return vars, err
}
vars = append(vars, v...)
}
}
@@ -173,7 +182,7 @@ func (t *Template) Vars() []string {
set[strings.ToLower(n)] = struct{}{}
}
return slices.Sorted(maps.Keys(set))
return slices.Sorted(maps.Keys(set)), nil
}
func (t *Template) Contains(s string) bool {
@@ -244,6 +253,10 @@ func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
func (t *Template) Execute(w io.Writer, v Values) error {
system, messages := collate(v.Messages)
vars, err := t.Vars()
if err != nil {
return err
}
if v.Prompt != "" && v.Suffix != "" {
return t.Template.Execute(w, map[string]any{
"Prompt": v.Prompt,
@@ -253,7 +266,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
"ThinkLevel": v.ThinkLevel,
"IsThinkSet": v.IsThinkSet,
})
} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
} else if !v.forceLegacy && slices.Contains(vars, "messages") {
return t.Template.Execute(w, map[string]any{
"System": system,
"Messages": messages,
@@ -329,7 +342,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
return err
}
_, err := io.Copy(w, &b)
_, err = io.Copy(w, &b)
return err
}
@@ -358,27 +371,47 @@ func collate(msgs []api.Message) (string, []*api.Message) {
}
// Identifiers walks the node tree returning any identifiers it finds along the way
func Identifiers(n parse.Node) []string {
func Identifiers(n parse.Node) ([]string, error) {
switch n := n.(type) {
case *parse.ListNode:
var names []string
for _, n := range n.Nodes {
names = append(names, Identifiers(n)...)
i, err := Identifiers(n)
if err != nil {
return names, err
}
names = append(names, i...)
}
return names
return names, nil
case *parse.TemplateNode:
if n.Pipe == nil {
return nil, errors.New("undefined template specified")
}
return Identifiers(n.Pipe)
case *parse.ActionNode:
if n.Pipe == nil {
return nil, errors.New("undefined action in template")
}
return Identifiers(n.Pipe)
case *parse.BranchNode:
names := Identifiers(n.Pipe)
if n.Pipe == nil {
return nil, errors.New("undefined branch")
}
names, err := Identifiers(n.Pipe)
if err != nil {
return names, err
}
for _, n := range []*parse.ListNode{n.List, n.ElseList} {
if n != nil {
names = append(names, Identifiers(n)...)
i, err := Identifiers(n)
if err != nil {
return names, err
}
names = append(names, i...)
}
}
return names
return names, nil
case *parse.IfNode:
return Identifiers(&n.BranchNode)
case *parse.RangeNode:
@@ -389,17 +422,21 @@ func Identifiers(n parse.Node) []string {
var names []string
for _, c := range n.Cmds {
for _, a := range c.Args {
names = append(names, Identifiers(a)...)
i, err := Identifiers(a)
if err != nil {
return names, err
}
names = append(names, i...)
}
}
return names
return names, nil
case *parse.FieldNode:
return n.Ident
return n.Ident, nil
case *parse.VariableNode:
return n.Ident
return n.Ident, nil
}
return nil
return nil, nil
}
// deleteNode walks the node list and deletes nodes that match the predicate

View File

@@ -192,7 +192,11 @@ func TestParse(t *testing.T) {
t.Fatal(err)
}
if diff := cmp.Diff(tmpl.Vars(), tt.vars); diff != "" {
v, err := tmpl.Vars()
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(v, tt.vars); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})