Files
ollama/convert/tensor_test.go
Michael Yang fa7776fd24 gpt-oss (#11672)
* bf16

* tests

* gpt-oss

* enable gptoss for engine

* rough estimate

* convert to mxfp4

* handle safetensors U8

* clamp glu/linear

* update tokenizer

* MXFP4 support

This implements the Open Compute Microscaling (MX) FP4 format
as a tensor type with backend implementations focusing
on mulmat and mulmatid on CPU, CUDA, and Metal.

* Unit tests for MXFP4 support

This exercises various operations and shapes on both CPU and GPU (if detected
on the system)

* cuda graph

* unit test adjustments

* cuda: optimize memory access

Read 4 bytes at a time (8 elements) when performing mul_mat_vec_mxfp4

* mac: fix crash on old macos versions

cblas_sgemm is only supported on v13.3 and up, however bf16 is
only supported on v14+ so we were falling back to ggml-blas and
crashing on bf16 tensors.  Checking for the function being null
seems to be the simplest way to condittionally avoid registering the
backend.

* server: Minimum context length for gptoss

This model requires a minimum context length of 8192 to function
effectively. Users can set higher values through all normal mechanisms
but lower values will be silently reset.

* ggml: Multiply by numParallel for gptoss sliding window

When computing the graph size estimate, the context size is already
multiplied by numParallel so estimates reflect that. However, since
sliding window models use a smaller, fixed context size, they need
to manually take numParallel into account.

* gpt-oss integration

includes harmony parser and thinking levels, etc.

* fix sync

* fix tests

* fix lint

---------

Co-authored-by: Daniel Hiltgen <daniel@ollama.com>
Co-authored-by: Jesse Gross <jesse@ollama.com>
Co-authored-by: Devon Rifkin <drifkin@drifkin.net>
2025-08-05 12:21:16 -07:00

954 lines
22 KiB
Go

package convert
import (
"bytes"
"encoding/binary"
"io"
"iter"
"slices"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/fs/ggml"
"github.com/pdevine/tensor"
)
type fakeTensor struct {
name string
shape []uint64
data []float32
repacker Repacker
}
func (f fakeTensor) Name() string {
return f.name
}
func (f fakeTensor) Shape() []uint64 {
return f.shape
}
func (f fakeTensor) Kind() uint32 {
return 0
}
func (f *fakeTensor) SetRepacker(fn Repacker) {
f.repacker = fn
}
func (f fakeTensor) Clone() Tensor {
return &fakeTensor{
name: f.name,
shape: slices.Clone(f.shape),
data: slices.Clone(f.data),
repacker: f.repacker,
}
}
func (f fakeTensor) WriteTo(w io.Writer) (n int64, err error) {
data := f.data
if f.repacker != nil {
data, err = f.repacker(f.name, data, f.shape)
if err != nil {
return 0, err
}
}
if err := binary.Write(w, binary.LittleEndian, data); err != nil {
return 0, err
}
return int64(len(data) * 4), nil
}
func mul(shape []uint64) int {
n := 1
for _, dim := range shape {
n *= int(dim)
}
return n
}
func TestSplitDim(t *testing.T) {
t.Run("2d", func(t *testing.T) {
r := fakeTensor{
name: "a.b",
shape: []uint64{3, 4},
data: []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
}
t.Run("no split", func(t *testing.T) {
for tt := range splitDim(&r, 0, split{Replacer: strings.NewReplacer("a", "x")}) {
if tt.Name != "x.b" {
t.Fatalf("expected name 'x', got '%s'", tt.Name)
}
if diff := cmp.Diff(tt.Shape, []uint64{3, 4}); diff != "" {
t.Errorf("unexpected shape (-want +got):\n%s", diff)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); diff != "" {
t.Errorf("unexpected data (-want +got):\n%s", diff)
}
}
})
t.Run("even split", func(t *testing.T) {
next, stop := iter.Pull(splitDim(&r, 1,
split{Replacer: strings.NewReplacer("a", "x")},
split{Replacer: strings.NewReplacer("b", "y")},
))
defer stop()
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "x.b" {
t.Fatal("expected name 'x.b', got", tt.Name)
}
if diff := cmp.Diff(tt.Shape, []uint64{3, 2}); diff != "" {
t.Errorf("unexpected shape (-want +got):\n%s", diff)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(f32s, []float32{0, 1, 4, 5, 8, 9}); diff != "" {
t.Errorf("unexpected data (-want +got):\n%s", diff)
}
}
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "a.y" {
t.Fatal("expected name 'a.y', got", tt.Name)
}
if diff := cmp.Diff(tt.Shape, []uint64{3, 2}); diff != "" {
t.Errorf("unexpected shape (-want +got):\n%s", diff)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(f32s, []float32{2, 3, 6, 7, 10, 11}); diff != "" {
t.Errorf("unexpected data (-want +got):\n%s", diff)
}
}
})
t.Run("uneven split", func(t *testing.T) {
next, stop := iter.Pull(splitDim(&r, 0,
split{Replacer: strings.NewReplacer("a", "x"), dim: 2},
split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
))
defer stop()
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "x.b" {
t.Fatal("expected name 'x.b', got", tt.Name)
}
if diff := cmp.Diff(tt.Shape, []uint64{2, 4}); diff != "" {
t.Errorf("unexpected shape (-want +got):\n%s", diff)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7}); diff != "" {
t.Errorf("unexpected data (-want +got):\n%s", diff)
}
}
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "a.y" {
t.Fatal("expected name 'a.y', got", tt.Name)
}
if diff := cmp.Diff(tt.Shape, []uint64{1, 4}); diff != "" {
t.Errorf("unexpected shape (-want +got):\n%s", diff)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(f32s, []float32{8, 9, 10, 11}); diff != "" {
t.Errorf("unexpected data (-want +got):\n%s", diff)
}
}
})
t.Run("three way split", func(t *testing.T) {
next, stop := iter.Pull(splitDim(&r, 0,
split{Replacer: strings.NewReplacer("a", "x"), dim: 1},
split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
split{Replacer: strings.NewReplacer("b", "z"), dim: 1},
))
defer stop()
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "x.b" {
t.Fatal("expected name 'x.b', got", tt.Name)
}
if diff := cmp.Diff(tt.Shape, []uint64{1, 4}); diff != "" {
t.Errorf("unexpected shape (-want +got):\n%s", diff)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3}); diff != "" {
t.Errorf("unexpected data (-want +got):\n%s", diff)
}
}
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "a.y" {
t.Fatal("expected name 'x.b', got", tt.Name)
}
if diff := cmp.Diff(tt.Shape, []uint64{1, 4}); diff != "" {
t.Errorf("unexpected shape (-want +got):\n%s", diff)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(f32s, []float32{4, 5, 6, 7}); diff != "" {
t.Errorf("unexpected data (-want +got):\n%s", diff)
}
}
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "a.z" {
t.Fatal("expected name 'x.b', got", tt.Name)
}
if diff := cmp.Diff(tt.Shape, []uint64{1, 4}); diff != "" {
t.Errorf("unexpected shape (-want +got):\n%s", diff)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(f32s, []float32{8, 9, 10, 11}); diff != "" {
t.Errorf("unexpected data (-want +got):\n%s", diff)
}
}
})
t.Run("uneven three way split", func(t *testing.T) {
next, stop := iter.Pull(splitDim(&r, 1,
split{Replacer: strings.NewReplacer("a", "x"), dim: 2},
split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
split{Replacer: strings.NewReplacer("b", "z"), dim: 1},
))
defer stop()
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "x.b" {
t.Fatal("expected name 'x.b', got", tt.Name)
}
if diff := cmp.Diff(tt.Shape, []uint64{3, 2}); diff != "" {
t.Errorf("unexpected shape (-want +got):\n%s", diff)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(f32s, []float32{0, 1, 4, 5, 8, 9}); diff != "" {
t.Errorf("unexpected data (-want +got):\n%s", diff)
}
}
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "a.y" {
t.Fatal("expected name 'x.b', got", tt.Name)
}
if diff := cmp.Diff(tt.Shape, []uint64{3, 1}); diff != "" {
t.Errorf("unexpected shape (-want +got):\n%s", diff)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(f32s, []float32{2, 6, 10}); diff != "" {
t.Errorf("unexpected data (-want +got):\n%s", diff)
}
}
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "a.z" {
t.Fatal("expected name 'x.b', got", tt.Name)
}
if diff := cmp.Diff(tt.Shape, []uint64{3, 1}); diff != "" {
t.Errorf("unexpected shape (-want +got):\n%s", diff)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(f32s, []float32{3, 7, 11}); diff != "" {
t.Errorf("unexpected data (-want +got):\n%s", diff)
}
}
})
t.Run("split with transpose", func(t *testing.T) {
next, stop := iter.Pull(splitDim(&r, 1,
split{Replacer: strings.NewReplacer("a", "x")},
split{Replacer: strings.NewReplacer("b", "y"), fn: func(tt tensor.Tensor) (tensor.Tensor, error) {
return tensor.Transpose(tt, 1, 0)
}},
))
defer stop()
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "x.b" {
t.Fatal("expected name 'x.b', got", tt.Name)
}
if diff := cmp.Diff(tt.Shape, []uint64{3, 2}); diff != "" {
t.Errorf("unexpected shape (-want +got):\n%s", diff)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(f32s, []float32{0, 1, 4, 5, 8, 9}); diff != "" {
t.Errorf("unexpected data (-want +got):\n%s", diff)
}
}
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "a.y" {
t.Fatal("expected name 'a.y', got", tt.Name)
}
if diff := cmp.Diff(tt.Shape, []uint64{3, 2}); diff != "" {
t.Errorf("unexpected shape (-want +got):\n%s", diff)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(f32s, []float32{2, 6, 10, 3, 7, 11}); diff != "" {
t.Errorf("unexpected data (-want +got):\n%s", diff)
}
}
})
})
t.Run("3d", func(t *testing.T) {
r := fakeTensor{
name: "a.b",
shape: []uint64{3, 4, 2},
data: []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
}
t.Run("no split", func(t *testing.T) {
for tt := range splitDim(&r, 0, split{Replacer: strings.NewReplacer("a", "x")}) {
if tt.Name != "x.b" {
t.Fatalf("expected name 'x', got '%s'", tt.Name)
}
if diff := cmp.Diff(tt.Shape, []uint64{3, 4, 2}); diff != "" {
t.Errorf("unexpected shape (-want +got):\n%s", diff)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); diff != "" {
t.Errorf("unexpected data (-want +got):\n%s", diff)
}
}
})
t.Run("even split", func(t *testing.T) {
next, stop := iter.Pull(splitDim(&r, 1,
split{Replacer: strings.NewReplacer("a", "x")},
split{Replacer: strings.NewReplacer("b", "y")},
))
defer stop()
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "x.b" {
t.Fatal("expected name 'x.b', got", tt.Name)
}
if diff := cmp.Diff(tt.Shape, []uint64{3, 2, 2}); diff != "" {
t.Errorf("unexpected shape (-want +got):\n%s", diff)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19}); diff != "" {
t.Errorf("unexpected data (-want +got):\n%s", diff)
}
}
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "a.y" {
t.Fatal("expected name 'a.y', got", tt.Name)
}
if diff := cmp.Diff(tt.Shape, []uint64{3, 2, 2}); diff != "" {
t.Errorf("unexpected shape (-want +got):\n%s", diff)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(f32s, []float32{4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23}); diff != "" {
t.Errorf("unexpected data (-want +got):\n%s", diff)
}
}
})
t.Run("uneven split", func(t *testing.T) {
next, stop := iter.Pull(splitDim(&r, 0,
split{Replacer: strings.NewReplacer("a", "x"), dim: 2},
split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
))
defer stop()
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "x.b" {
t.Fatal("expected name 'x.b', got", tt.Name)
}
if diff := cmp.Diff(tt.Shape, []uint64{2, 4, 2}); diff != "" {
t.Errorf("unexpected shape (-want +got):\n%s", diff)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); diff != "" {
t.Errorf("unexpected data (-want +got):\n%s", diff)
}
}
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "a.y" {
t.Fatal("expected name 'a.y', got", tt.Name)
}
if diff := cmp.Diff(tt.Shape, []uint64{1, 4, 2}); diff != "" {
t.Errorf("unexpected shape (-want +got):\n%s", diff)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(f32s, []float32{16, 17, 18, 19, 20, 21, 22, 23}); diff != "" {
t.Errorf("unexpected data (-want +got):\n%s", diff)
}
}
})
t.Run("three way split", func(t *testing.T) {
next, stop := iter.Pull(splitDim(&r, 0,
split{Replacer: strings.NewReplacer("a", "x"), dim: 1},
split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
split{Replacer: strings.NewReplacer("b", "z"), dim: 1},
))
defer stop()
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "x.b" {
t.Fatal("expected name 'x.b', got", tt.Name)
}
if diff := cmp.Diff(tt.Shape, []uint64{1, 4, 2}); diff != "" {
t.Errorf("unexpected shape (-want +got):\n%s", diff)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7}); diff != "" {
t.Errorf("unexpected data (-want +got):\n%s", diff)
}
}
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "a.y" {
t.Fatal("expected name 'x.b', got", tt.Name)
}
if diff := cmp.Diff(tt.Shape, []uint64{1, 4, 2}); diff != "" {
t.Errorf("unexpected shape (-want +got):\n%s", diff)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(f32s, []float32{8, 9, 10, 11, 12, 13, 14, 15}); diff != "" {
t.Errorf("unexpected data (-want +got):\n%s", diff)
}
}
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "a.z" {
t.Fatal("expected name 'x.b', got", tt.Name)
}
if diff := cmp.Diff(tt.Shape, []uint64{1, 4, 2}); diff != "" {
t.Errorf("unexpected shape (-want +got):\n%s", diff)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(f32s, []float32{16, 17, 18, 19, 20, 21, 22, 23}); diff != "" {
t.Errorf("unexpected data (-want +got):\n%s", diff)
}
}
})
t.Run("uneven three way split", func(t *testing.T) {
next, stop := iter.Pull(splitDim(&r, 1,
split{Replacer: strings.NewReplacer("a", "x"), dim: 2},
split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
split{Replacer: strings.NewReplacer("b", "z"), dim: 1},
))
defer stop()
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "x.b" {
t.Fatal("expected name 'x.b', got", tt.Name)
}
if diff := cmp.Diff(tt.Shape, []uint64{3, 2, 2}); diff != "" {
t.Errorf("unexpected shape (-want +got):\n%s", diff)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19}); diff != "" {
t.Errorf("unexpected data (-want +got):\n%s", diff)
}
}
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "a.y" {
t.Fatal("expected name 'x.b', got", tt.Name)
}
if diff := cmp.Diff(tt.Shape, []uint64{3, 1, 2}); diff != "" {
t.Errorf("unexpected shape (-want +got):\n%s", diff)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(f32s, []float32{4, 5, 12, 13, 20, 21}); diff != "" {
t.Errorf("unexpected data (-want +got):\n%s", diff)
}
}
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "a.z" {
t.Fatal("expected name 'x.b', got", tt.Name)
}
if diff := cmp.Diff(tt.Shape, []uint64{3, 1, 2}); diff != "" {
t.Errorf("unexpected shape (-want +got):\n%s", diff)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(f32s, []float32{6, 7, 14, 15, 22, 23}); diff != "" {
t.Errorf("unexpected data (-want +got):\n%s", diff)
}
}
})
})
}
func TestMerge(t *testing.T) {
unmatched := []Tensor{
&fakeTensor{
name: "a.0.b",
shape: []uint64{5, 2},
data: []float32{10, 11, 12, 13, 14, 15, 16, 17, 18, 19},
},
&fakeTensor{
name: "a.1.b",
shape: []uint64{5, 2},
data: []float32{20, 21, 22, 23, 24, 25, 26, 27, 28, 29},
},
&fakeTensor{
name: "c.0.d",
shape: []uint64{5, 2},
data: []float32{30, 31, 32, 33, 34, 35, 36, 37, 38, 39},
},
&fakeTensor{
name: "c.1.d",
shape: []uint64{5, 2},
data: []float32{40, 41, 42, 43, 44, 45, 46, 47, 48, 49},
},
&fakeTensor{
name: "e.0.f",
shape: []uint64{5, 2},
data: []float32{50, 51, 52, 53, 54, 55, 56, 57, 58, 59},
},
}
checkMatched := func(t *testing.T, n int, matched []*ggml.Tensor) {
for i := range n {
got := matched[i]
if diff := cmp.Diff([]uint64{2, 5, 2}, got.Shape); diff != "" {
t.Errorf("unexpected (-want +got):\n%s", diff)
}
var b bytes.Buffer
if _, err := got.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, 20)
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
offset := 10 + (i * 20)
want := make([]float32, 20)
for j := range 20 {
want[j] = float32(offset + j)
}
if diff := cmp.Diff(want, f32s); diff != "" {
t.Errorf("unexpected data (-want +got):\n%s", diff)
}
}
}
t.Run("single merge", func(t *testing.T) {
matched, unmatched := mergeTensors(unmatched, merge{"a.*.b", "a.b"})
if len(unmatched) != 3 {
t.Error("expected 3 remaining tensors, got", len(unmatched))
}
if len(matched) != 1 {
t.Error("expected 1 merged tensor, got", len(matched))
}
checkMatched(t, 1, matched)
})
t.Run("multiple merges", func(t *testing.T) {
matched, unmatched := mergeTensors(unmatched, merge{"a.*.b", "a.b"}, merge{"c.*.d", "c.d"})
if len(unmatched) != 1 {
t.Error("expected 1 remaining tensors, got", len(unmatched))
}
if len(matched) != 2 {
t.Error("expected 2 merged tensor, got", len(matched))
}
checkMatched(t, 2, matched)
})
t.Run("no match", func(t *testing.T) {
matched, unmatched := mergeTensors(unmatched, merge{"x.*.y", "x.y"})
if len(unmatched) != 5 {
t.Error("expected 5 remaining tensors, got", len(unmatched))
}
if len(matched) != 0 {
t.Error("expected no merged tensors, got", len(matched))
}
})
}