mirror of
https://github.com/jmorganca/ollama
synced 2025-10-06 00:32:49 +02:00
130 lines
3.0 KiB
Go
130 lines
3.0 KiB
Go
package convert
|
|
|
|
import (
|
|
"cmp"
|
|
"io"
|
|
"iter"
|
|
"path"
|
|
"slices"
|
|
"strings"
|
|
|
|
"github.com/pdevine/tensor"
|
|
"github.com/pdevine/tensor/native"
|
|
|
|
"github.com/ollama/ollama/fs/ggml"
|
|
)
|
|
|
|
type split struct {
|
|
*strings.Replacer
|
|
dim int
|
|
|
|
// fn is an optional function to apply to the tensor after slicing
|
|
fn func(tensor.Tensor) (tensor.Tensor, error)
|
|
}
|
|
|
|
// splitDim splits a tensor along a specified dimension into multiple tensors. The dimension
|
|
// is split evenly based on the number of replacers provided unless a specific count is given.
|
|
func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] {
|
|
return func(yield func(*ggml.Tensor) bool) {
|
|
var offset int
|
|
for _, split := range splits {
|
|
t := t.Clone()
|
|
shape := slices.Clone(t.Shape())
|
|
shape[dim] = cmp.Or(uint64(split.dim), shape[dim]/uint64(len(splits)))
|
|
|
|
slice := slices.Repeat([]tensor.Slice{nil}, len(shape))
|
|
slice[dim] = tensor.S(offset, offset+int(shape[dim]))
|
|
offset += int(shape[dim])
|
|
|
|
t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
|
dims := make([]int, len(shape))
|
|
for i := range shape {
|
|
dims[i] = int(shape[i])
|
|
}
|
|
|
|
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
|
tt, err := tt.Slice(slice...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
tt = tensor.Materialize(tt)
|
|
|
|
if split.fn != nil {
|
|
tt, err = split.fn(tt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
// flatten tensor so it can be written as a vector
|
|
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return native.VectorF32(tt.(*tensor.Dense))
|
|
})
|
|
|
|
if !yield(&ggml.Tensor{
|
|
Name: split.Replace(t.Name()),
|
|
Kind: t.Kind(),
|
|
Shape: shape,
|
|
WriterTo: t,
|
|
}) {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// merge describes one group of tensors that mergeTensors combines.
type merge struct {
	// pattern is matched against tensor names with path.Match.
	pattern string
	// name is the name given to the merged tensor.
	name string
}
|
|
|
|
// mergeTensors merges tensors that match a given pattern into a single tensor.
|
|
func mergeTensors(unmatched []Tensor, merges ...merge) (out []*ggml.Tensor, _ []Tensor) {
|
|
var matched []Tensor
|
|
for i := range merges {
|
|
matched, unmatched = slicesSplitFunc(unmatched, func(t Tensor) bool {
|
|
matched, _ := path.Match(merges[i].pattern, t.Name())
|
|
return matched
|
|
})
|
|
|
|
if len(matched) > 0 {
|
|
out = append(out, &ggml.Tensor{
|
|
Name: merges[i].name,
|
|
Kind: matched[0].Kind(),
|
|
Shape: append([]uint64{uint64(len(matched))}, matched[0].Shape()...),
|
|
WriterTo: mergeGroup(matched),
|
|
})
|
|
}
|
|
}
|
|
|
|
return out, unmatched
|
|
}
|
|
|
|
// slicesSplitFunc partitions s into two slices using the predicate fn.
//
// Elements for which fn returns true are appended to matched and all others to
// unmatched. Both results keep the relative order of s, s itself is not
// modified, and a result that receives no elements is returned as nil.
//
// Elements are only passed to fn and never compared with each other, so the
// element type is constrained by any rather than comparable.
func slicesSplitFunc[S ~[]E, E any](s S, fn func(e E) bool) (matched, unmatched S) {
	for _, e := range s {
		if fn(e) {
			matched = append(matched, e)
		} else {
			unmatched = append(unmatched, e)
		}
	}

	return matched, unmatched
}
|
|
|
|
type mergeGroup []Tensor
|
|
|
|
func (g mergeGroup) WriteTo(w io.Writer) (int64, error) {
|
|
for _, t := range g {
|
|
if _, err := t.WriteTo(w); err != nil {
|
|
return 0, err
|
|
}
|
|
}
|
|
|
|
return 0, nil
|
|
}
|