Files
ollama/convert/tensor.go
2025-06-20 11:12:01 -07:00

130 lines
3.0 KiB
Go

package convert
import (
"cmp"
"io"
"iter"
"path"
"slices"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
// split describes one output tensor produced when splitDim divides a source tensor.
type split struct {
	// Replacer rewrites the source tensor's name into the name of this split's output.
	*strings.Replacer
	// dim is the size of this split along the split dimension. When it is zero,
	// splitDim falls back to an even share of that dimension.
	dim int
	// fn is an optional function to apply to the tensor after slicing
	fn func(tensor.Tensor) (tensor.Tensor, error)
}
// splitDim divides t along dimension dim and yields one *ggml.Tensor per entry in splits,
// in order. Each piece starts where the previous one ended. A split whose dim field is
// non-zero takes exactly that many entries along the dimension; otherwise it takes the
// dimension's size divided by the number of splits. Every yielded tensor is backed by its
// own clone of t, carries the source name rewritten by the split's replacer, and has a
// repacker installed that slices the data (and applies the split's optional fn).
// Iteration stops as soon as yield returns false.
func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] {
	return func(yield func(*ggml.Tensor) bool) {
		start := 0
		for _, sp := range splits {
			part := t.Clone()

			partShape := slices.Clone(part.Shape())
			partShape[dim] = cmp.Or(uint64(sp.dim), partShape[dim]/uint64(len(splits)))

			// select [start, end) along dim; a nil entry leaves the other dimensions whole
			end := start + int(partShape[dim])
			ranges := slices.Repeat([]tensor.Slice{nil}, len(partShape))
			ranges[dim] = tensor.S(start, end)
			start = end

			part.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
				// the tensor package works with int dimensions
				dims := make([]int, len(shape))
				for i, d := range shape {
					dims[i] = int(d)
				}

				var (
					piece tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
					err   error
				)
				if piece, err = piece.Slice(ranges...); err != nil {
					return nil, err
				}
				piece = tensor.Materialize(piece)

				if sp.fn != nil {
					if piece, err = sp.fn(piece); err != nil {
						return nil, err
					}
				}

				// collapse to a single dimension so the result can be written as a vector
				if err = piece.Reshape(piece.Shape().TotalSize()); err != nil {
					return nil, err
				}
				return native.VectorF32(piece.(*tensor.Dense))
			})

			out := &ggml.Tensor{
				Name:     sp.Replace(part.Name()),
				Kind:     part.Kind(),
				Shape:    partShape,
				WriterTo: part,
			}
			if !yield(out) {
				return
			}
		}
	}
}
// merge pairs a name pattern with the name given to the tensor that mergeTensors
// builds from every tensor matching that pattern.
type merge struct {
	// pattern is matched against tensor names with path.Match.
	pattern string
	// name is the name assigned to the merged tensor.
	name string
}
// mergeTensors walks merges in order and, for each one, pulls every tensor whose name
// matches the merge's pattern out of unmatched. When at least one tensor matches, a single
// *ggml.Tensor is appended to out: it is named after the merge, takes its kind from the
// first matching tensor, and has that tensor's shape prefixed by the number of matches.
// Its data is the matched tensors written one after another. The second result is the
// tensors that no pattern claimed.
func mergeTensors(unmatched []Tensor, merges ...merge) (out []*ggml.Tensor, _ []Tensor) {
	for _, m := range merges {
		var group []Tensor
		group, unmatched = slicesSplitFunc(unmatched, func(t Tensor) bool {
			// the error from path.Match is discarded, so a failed match counts as no match
			ok, _ := path.Match(m.pattern, t.Name())
			return ok
		})
		if len(group) == 0 {
			continue
		}

		first := group[0]
		kind := first.Kind()
		shape := append([]uint64{uint64(len(group))}, first.Shape()...)
		out = append(out, &ggml.Tensor{
			Name:     m.name,
			Kind:     kind,
			Shape:    shape,
			WriterTo: mergeGroup(group),
		})
	}
	return out, unmatched
}
// slicesSplitFunc partitions s into two slices using the predicate fn: elements for which
// fn returns true go to matched, all others go to unmatched. The relative order of the
// elements is preserved in both results, and s itself is not modified. A result is nil
// when no element lands in it.
//
// Elements are never compared with each other, so any element type is accepted.
func slicesSplitFunc[S ~[]E, E any](s S, fn func(e E) bool) (matched, unmatched S) {
	for _, e := range s {
		if fn(e) {
			matched = append(matched, e)
		} else {
			unmatched = append(unmatched, e)
		}
	}
	return matched, unmatched
}
// mergeGroup is a list of tensors that are written out back to back as one tensor.
type mergeGroup []Tensor

// WriteTo writes every tensor in the group to w, in order, and returns the total number
// of bytes the individual writes reported. It stops at the first error and returns it
// together with the byte count accumulated so far, including whatever the failing write
// reported.
func (g mergeGroup) WriteTo(w io.Writer) (int64, error) {
	var total int64
	for _, t := range g {
		n, err := t.WriteTo(w)
		total += n
		if err != nil {
			return total, err
		}
	}
	return total, nil
}