From 0a066cfd91abdddc6ee172776974a6720a3072d3 Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Fri, 20 Jun 2025 11:11:40 -0700
Subject: [PATCH] Reapply "feat: incremental gguf parser (#10822)" (#11114) (#11119)

* Reapply "feat: incremental gguf parser (#10822)" (#11114)

This reverts commit a6e64fbdf28f0d6cb97cc7f022ca493b905fe895.

* fix older ggufs
---
 fs/gguf/gguf.go             | 347 ++++++++++++++++++++++++++++++++++++
 fs/gguf/gguf_test.go        | 249 ++++++++++++++++++++++++++
 fs/gguf/keyvalue.go         |  90 ++++++++++
 fs/gguf/keyvalue_test.go    | 208 +++++++++++++++++++++
 fs/gguf/lazy.go             |  89 +++++++++
 fs/gguf/reader.go           |  23 +++
 fs/gguf/tensor.go           | 288 ++++++++++++++++++++++++++++++
 go.mod                      |   2 +-
 go.sum                      |   4 +-
 server/images.go            |  24 ++-
 server/images_test.go       | 165 +++++------------
 server/quantization_test.go |  12 +-
 server/sched_test.go        |  20 +--
 13 files changed, 1357 insertions(+), 164 deletions(-)
 create mode 100644 fs/gguf/gguf.go
 create mode 100644 fs/gguf/gguf_test.go
 create mode 100644 fs/gguf/keyvalue.go
 create mode 100644 fs/gguf/keyvalue_test.go
 create mode 100644 fs/gguf/lazy.go
 create mode 100644 fs/gguf/reader.go
 create mode 100644 fs/gguf/tensor.go

diff --git a/fs/gguf/gguf.go b/fs/gguf/gguf.go
new file mode 100644
index 000000000..bbb9bb410
--- /dev/null
+++ b/fs/gguf/gguf.go
@@ -0,0 +1,371 @@
+package gguf
+
+import (
+	"bytes"
+	"cmp"
+	"encoding/binary"
+	"errors"
+	"fmt"
+	"io"
+	"iter"
+	"os"
+	"slices"
+	"strings"
+)
+
+const (
+	typeUint8 uint32 = iota
+	typeInt8
+	typeUint16
+	typeInt16
+	typeUint32
+	typeInt32
+	typeFloat32
+	typeBool
+	typeString
+	typeArray
+	typeUint64
+	typeInt64
+	typeFloat64
+)
+
+var ErrUnsupported = errors.New("unsupported")
+
+type File struct {
+	Magic   [4]byte
+	Version uint32
+
+	keyValues *lazy[KeyValue]
+	tensors   *lazy[TensorInfo]
+	offset    int64
+
+	file   *os.File
+	reader *bufferedReader
+	bts    []byte
+}
+
+// Open opens the GGUF file at path and reads its header: the magic bytes, the
+// format version, and the tensor and key-value counts. Key-value pairs and
+// tensor infos are not decoded here; they are read lazily on first access.
+//
+// It returns an error wrapping ErrUnsupported when the magic is not "GGUF" or
+// the version is older than 2. On any error the underlying file is closed.
+func Open(path string) (f *File, err error) {
+	file, err := os.Open(path)
+	if err != nil {
+		return nil, err
+	}
+
+	// No *File is handed back on failure, so nobody else can release the
+	// descriptor; close it here whenever an error is returned.
+	defer func() {
+		if err != nil {
+			file.Close()
+		}
+	}()
+
+	f = &File{bts: make([]byte, 4096), file: file}
+	f.reader = newBufferedReader(f.file, 32<<10)
+
+	if err := binary.Read(f.reader, binary.LittleEndian, &f.Magic); err != nil {
+		return nil, err
+	}
+
+	// The magic is the four ASCII bytes "GGUF"; anything else is not a GGUF file.
+	if !bytes.Equal(f.Magic[:], []byte("GGUF")) {
+		return nil, fmt.Errorf("%w file type %v", ErrUnsupported, f.Magic)
+	}
+
+	if err := binary.Read(f.reader, binary.LittleEndian, &f.Version); err != nil {
+		return nil, err
+	}
+
+	if f.Version < 2 {
+		return nil, fmt.Errorf("%w version %v", ErrUnsupported, f.Version)
+	}
+
+	f.tensors, err = newLazy(f, f.readTensor)
+	if err != nil {
+		return nil, err
+	}
+
+	// Invoked by the lazy tensor reader after its read loop. TensorReader adds
+	// f.offset to each tensor's Offset, so round up to the alignment here.
+	f.tensors.successFunc = func() error {
+		offset := f.reader.offset
+
+		// Int() returns 0 for unsigned kinds, so an alignment stored as an
+		// unsigned integer must be fetched with Uint(); fall back to 32.
+		kv := f.KeyValue("general.alignment")
+		alignment := cmp.Or(int64(kv.Uint()), kv.Int(), 32)
+		if alignment < 0 {
+			alignment = 32
+		}
+
+		f.offset = offset + (alignment-offset%alignment)%alignment
+		return nil
+	}
+
+	f.keyValues, err = newLazy(f, f.readKeyValue)
+	if err != nil {
+		return nil, err
+	}
+
+	return f, nil
+}
+
+func (f *File) readTensor() (TensorInfo, error) {
+	name, err := readString(f)
+	if err != nil {
+		return TensorInfo{}, err
+	}
+
+	dims, err := read[uint32](f)
+	if err != nil {
+		return TensorInfo{}, err
+	}
+
+	shape := make([]uint64, dims)
+	for i := range dims {
+		shape[i], err = read[uint64](f)
+		if err != nil {
+			return TensorInfo{}, err
+		}
+	}
+
+	type_, err := read[uint32](f)
+	if err != nil {
+		return TensorInfo{}, err
+	}
+
+	offset, err := read[uint64](f)
+	if err != nil {
+		return TensorInfo{}, err
+	}
+
+	return TensorInfo{
+		Name:   name,
+		Offset: offset,
+		Shape:  shape,
+		Type:   TensorType(type_),
+	}, nil
+}
+
+func (f *File) readKeyValue() (KeyValue, error) {
+	key, err := readString(f)
+	if err != nil {
+		return KeyValue{}, err
+	}
+
+	t, err := read[uint32](f)
+	if err != nil {
+		return KeyValue{}, err
+	}
+
+	value, err := func() (any, error) {
+		switch t {
+		case typeUint8:
+			return read[uint8](f)
+		case typeInt8:
+			return read[int8](f)
+		case typeUint16:
+			return read[uint16](f)
+		case typeInt16:
+			return read[int16](f)
+		case typeUint32:
+
return read[uint32](f) + case typeInt32: + return read[int32](f) + case typeUint64: + return read[uint64](f) + case typeInt64: + return read[int64](f) + case typeFloat32: + return read[float32](f) + case typeFloat64: + return read[float64](f) + case typeBool: + return read[bool](f) + case typeString: + return readString(f) + case typeArray: + return readArray(f) + default: + return nil, fmt.Errorf("%w type %d", ErrUnsupported, t) + } + }() + if err != nil { + return KeyValue{}, err + } + + return KeyValue{ + Key: key, + Value: Value{value}, + }, nil +} + +func read[T any](f *File) (t T, err error) { + err = binary.Read(f.reader, binary.LittleEndian, &t) + return t, err +} + +func readString(f *File) (string, error) { + n, err := read[uint64](f) + if err != nil { + return "", err + } + + if int(n) > len(f.bts) { + f.bts = make([]byte, n) + } + + bts := f.bts[:n] + if _, err := io.ReadFull(f.reader, bts); err != nil { + return "", err + } + defer clear(bts) + + return string(bts), nil +} + +func readArray(f *File) (any, error) { + t, err := read[uint32](f) + if err != nil { + return nil, err + } + + n, err := read[uint64](f) + if err != nil { + return nil, err + } + + switch t { + case typeUint8: + return readArrayData[uint8](f, n) + case typeInt8: + return readArrayData[int8](f, n) + case typeUint16: + return readArrayData[uint16](f, n) + case typeInt16: + return readArrayData[int16](f, n) + case typeUint32: + return readArrayData[uint32](f, n) + case typeInt32: + return readArrayData[int32](f, n) + case typeUint64: + return readArrayData[uint64](f, n) + case typeInt64: + return readArrayData[int64](f, n) + case typeFloat32: + return readArrayData[float32](f, n) + case typeFloat64: + return readArrayData[float64](f, n) + case typeBool: + return readArrayData[bool](f, n) + case typeString: + return readArrayString(f, n) + default: + return nil, fmt.Errorf("%w type %d", ErrUnsupported, t) + } +} + +func readArrayData[T any](f *File, n uint64) (s []T, err error) { + s 
= make([]T, n) + for i := range n { + e, err := read[T](f) + if err != nil { + return nil, err + } + + s[i] = e + } + + return s, nil +} + +func readArrayString(f *File, n uint64) (s []string, err error) { + s = make([]string, n) + for i := range n { + e, err := readString(f) + if err != nil { + return nil, err + } + + s[i] = e + } + + return s, nil +} + +func (f *File) Close() error { + f.keyValues.stop() + f.tensors.stop() + return f.file.Close() +} + +func (f *File) KeyValue(key string) KeyValue { + if !strings.HasPrefix(key, "general.") && !strings.HasPrefix(key, "tokenizer.") { + key = f.KeyValue("general.architecture").String() + "." + key + } + + if index := slices.IndexFunc(f.keyValues.values, func(kv KeyValue) bool { + return kv.Key == key + }); index >= 0 { + return f.keyValues.values[index] + } + + for keyValue, ok := f.keyValues.next(); ok; keyValue, ok = f.keyValues.next() { + if keyValue.Key == key { + return keyValue + } + } + + return KeyValue{} +} + +func (f *File) NumKeyValues() int { + return int(f.keyValues.count) +} + +func (f *File) KeyValues() iter.Seq2[int, KeyValue] { + return f.keyValues.All() +} + +func (f *File) TensorInfo(name string) TensorInfo { + if index := slices.IndexFunc(f.tensors.values, func(t TensorInfo) bool { + return t.Name == name + }); index >= 0 { + return f.tensors.values[index] + } + + // fast-forward through key values if we haven't already + _ = f.keyValues.rest() + for tensor, ok := f.tensors.next(); ok; tensor, ok = f.tensors.next() { + if tensor.Name == name { + return tensor + } + } + + return TensorInfo{} +} + +func (f *File) NumTensors() int { + return int(f.tensors.count) +} + +func (f *File) TensorInfos() iter.Seq2[int, TensorInfo] { + // fast forward through key values if we haven't already + f.keyValues.rest() + return f.tensors.All() +} + +func (f *File) TensorReader(name string) (TensorInfo, io.Reader, error) { + t := f.TensorInfo(name) + if t.NumBytes() == 0 { + return TensorInfo{}, nil, 
fmt.Errorf("tensor %s not found", name) + } + + // fast forward through tensor info if we haven't already + _ = f.tensors.rest() + return t, io.NewSectionReader(f.file, f.offset+int64(t.Offset), t.NumBytes()), nil +} diff --git a/fs/gguf/gguf_test.go b/fs/gguf/gguf_test.go new file mode 100644 index 000000000..eea28a480 --- /dev/null +++ b/fs/gguf/gguf_test.go @@ -0,0 +1,249 @@ +package gguf_test + +import ( + "bytes" + "os" + "strconv" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/ollama/ollama/fs/ggml" + "github.com/ollama/ollama/fs/gguf" +) + +func createBinFile(tb testing.TB) string { + tb.Helper() + f, err := os.CreateTemp(tb.TempDir(), "") + if err != nil { + tb.Fatal(err) + } + defer f.Close() + + kv := ggml.KV{ + "general.architecture": "llama", + "llama.block_count": uint32(8), + "llama.embedding_length": uint32(3), + "llama.attention.head_count": uint32(2), + "llama.attention.head_count_kv": uint32(2), + "llama.attention.key_length": uint32(3), + "llama.rope.dimension_count": uint32(4), + "llama.rope.freq_base": float32(10000.0), + "llama.rope.freq_scale": float32(1.0), + "llama.attention.layer_norm_rms_epsilon": float32(1e-6), + "tokenizer.ggml.eos_token_id": uint32(0), + "tokenizer.ggml.eos_token_ids": []int32{1, 2, 3}, + "tokenizer.ggml.tokens": []string{"hello", "world"}, + "tokenizer.ggml.scores": []float32{0, 1}, + } + + tensors := []*ggml.Tensor{ + { + Name: "token_embd.weight", + Kind: 0, + Shape: []uint64{2, 3}, + WriterTo: bytes.NewBuffer(make([]byte, 4*2*3)), + }, + { + Name: "output.weight", + Kind: 0, + Shape: []uint64{3, 2}, + WriterTo: bytes.NewBuffer(make([]byte, 4*3*2)), + }, + } + + for i := range 8 { + tensors = append(tensors, &ggml.Tensor{ + Name: "blk." + strconv.Itoa(i) + ".attn_q.weight", + Kind: 0, + Shape: []uint64{3, 3}, + WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)), + }, &ggml.Tensor{ + Name: "blk." 
+ strconv.Itoa(i) + ".attn_k.weight", + Kind: 0, + Shape: []uint64{3, 3}, + WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)), + }, &ggml.Tensor{ + Name: "blk." + strconv.Itoa(i) + ".attn_v.weight", + Kind: 0, + Shape: []uint64{3, 3}, + WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)), + }, &ggml.Tensor{ + Name: "blk." + strconv.Itoa(i) + ".attn_output.weight", + Kind: 0, + Shape: []uint64{3, 3}, + WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)), + }) + } + + if err := ggml.WriteGGUF(f, kv, tensors); err != nil { + tb.Fatal(err) + } + + return f.Name() +} + +func TestRead(t *testing.T) { + f, err := gguf.Open(createBinFile(t)) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + if got := f.KeyValue("does.not.exist").Valid(); got { + t.Errorf(`KeyValue("does.not.exist").Exists() = %v, want false`, got) + } + + if got := f.KeyValue("general.architecture").String(); got != "llama" { + t.Errorf(`KeyValue("general.architecture").String() = %q, want %q`, got, "llama") + } + + if got := f.TensorInfo("token_embd.weight"); got.Name != "token_embd.weight" { + t.Errorf(`TensorInfo("token_embd.weight").Name = %q, want %q`, got.Name, "token_embd.weight") + } else if diff := cmp.Diff(got.Shape, []uint64{2, 3}); diff != "" { + t.Errorf(`TensorInfo("token_embd.weight").Shape mismatch (-got +want):\n%s`, diff) + } else if got.Type != gguf.TensorTypeF32 { + t.Errorf(`TensorInfo("token_embd.weight").Type = %d, want %d`, got.Type, gguf.TensorTypeF32) + } + + if got := f.KeyValue("block_count").Uint(); got != 8 { + t.Errorf(`KeyValue("block_count").Uint() = %d, want %d`, got, 8) + } + + if diff := cmp.Diff(f.KeyValue("tokenizer.ggml.tokens").Strings(), []string{"hello", "world"}); diff != "" { + t.Errorf("KeyValue(\"tokenizer.ggml.tokens\").Strings() mismatch (-got +want):\n%s", diff) + } + + if diff := cmp.Diff(f.KeyValue("tokenizer.ggml.scores").Floats(), []float64{0, 1}); diff != "" { + t.Errorf("KeyValue(\"tokenizer.ggml.scores\").Ints() mismatch (-got +want):\n%s", diff) + } + + 
var kvs []string + for _, kv := range f.KeyValues() { + if !kv.Valid() { + t.Error("found invalid key-value pair:", kv) + } + + kvs = append(kvs, kv.Key) + } + + if len(kvs) != f.NumKeyValues() { + t.Errorf("iterated key count = %d, want %d", len(kvs), f.NumKeyValues()) + } + + if diff := cmp.Diff(kvs, []string{ + "general.architecture", + "llama.block_count", + "llama.embedding_length", + "llama.attention.head_count", + "llama.attention.head_count_kv", + "llama.attention.key_length", + "llama.rope.dimension_count", + "llama.rope.freq_base", + "llama.rope.freq_scale", + "llama.attention.layer_norm_rms_epsilon", + "tokenizer.ggml.eos_token_id", + "tokenizer.ggml.eos_token_ids", + "tokenizer.ggml.tokens", + "tokenizer.ggml.scores", + }, cmpopts.SortSlices(strings.Compare)); diff != "" { + t.Errorf("KeyValues() mismatch (-got +want):\n%s", diff) + } + + var tis []string + for _, ti := range f.TensorInfos() { + if !ti.Valid() { + t.Error("found invalid tensor info:", ti) + } + + tis = append(tis, ti.Name) + } + + if len(tis) != f.NumTensors() { + t.Errorf("iterated tensor count = %d, want %d", len(tis), f.NumTensors()) + } + + if diff := cmp.Diff(tis, []string{ + "token_embd.weight", + "output.weight", + "blk.0.attn_q.weight", + "blk.0.attn_k.weight", + "blk.0.attn_v.weight", + "blk.0.attn_output.weight", + "blk.1.attn_q.weight", + "blk.1.attn_k.weight", + "blk.1.attn_v.weight", + "blk.1.attn_output.weight", + "blk.2.attn_q.weight", + "blk.2.attn_k.weight", + "blk.2.attn_v.weight", + "blk.2.attn_output.weight", + "blk.3.attn_q.weight", + "blk.3.attn_k.weight", + "blk.3.attn_v.weight", + "blk.3.attn_output.weight", + "blk.4.attn_q.weight", + "blk.4.attn_k.weight", + "blk.4.attn_v.weight", + "blk.4.attn_output.weight", + "blk.5.attn_q.weight", + "blk.5.attn_k.weight", + "blk.5.attn_v.weight", + "blk.5.attn_output.weight", + "blk.6.attn_q.weight", + "blk.6.attn_k.weight", + "blk.6.attn_v.weight", + "blk.6.attn_output.weight", + "blk.7.attn_q.weight", + 
"blk.7.attn_k.weight", + "blk.7.attn_v.weight", + "blk.7.attn_output.weight", + }, cmpopts.SortSlices(strings.Compare)); diff != "" { + t.Errorf("TensorInfos() mismatch (-got +want):\n%s", diff) + } + + ti, r, err := f.TensorReader("output.weight") + if err != nil { + t.Fatalf(`TensorReader("output.weight") error: %v`, err) + } + + if ti.Name != "output.weight" { + t.Errorf(`TensorReader("output.weight").Name = %q, want %q`, ti.Name, "output.weight") + } else if diff := cmp.Diff(ti.Shape, []uint64{3, 2}); diff != "" { + t.Errorf(`TensorReader("output.weight").Shape mismatch (-got +want):\n%s`, diff) + } else if ti.Type != gguf.TensorTypeF32 { + t.Errorf(`TensorReader("output.weight").Type = %d, want %d`, ti.Type, gguf.TensorTypeF32) + } + + var b bytes.Buffer + if _, err := b.ReadFrom(r); err != nil { + t.Fatalf(`ReadFrom TensorReader("output.weight") error: %v`, err) + } + + if b.Len() != int(ti.NumBytes()) { + t.Errorf(`ReadFrom TensorReader("output.weight") length = %d, want %d`, b.Len(), ti.NumBytes()) + } +} + +func BenchmarkRead(b *testing.B) { + b.ReportAllocs() + + p := createBinFile(b) + for b.Loop() { + f, err := gguf.Open(p) + if err != nil { + b.Fatal(err) + } + + if got := f.KeyValue("general.architecture").String(); got != "llama" { + b.Errorf("got = %q, want %q", got, "llama") + } + + // Iterate through some tensors + for range f.TensorInfos() { + } + + f.Close() + } +} diff --git a/fs/gguf/keyvalue.go b/fs/gguf/keyvalue.go new file mode 100644 index 000000000..5843326c1 --- /dev/null +++ b/fs/gguf/keyvalue.go @@ -0,0 +1,90 @@ +package gguf + +import ( + "reflect" + "slices" +) + +type KeyValue struct { + Key string + Value +} + +func (kv KeyValue) Valid() bool { + return kv.Key != "" && kv.Value.value != nil +} + +type Value struct { + value any +} + +func value[T any](v Value, kinds ...reflect.Kind) (t T) { + vv := reflect.ValueOf(v.value) + if slices.Contains(kinds, vv.Kind()) { + t = vv.Convert(reflect.TypeOf(t)).Interface().(T) + } + return +} + 
+func values[T any](v Value, kinds ...reflect.Kind) (ts []T) { + switch vv := reflect.ValueOf(v.value); vv.Kind() { + case reflect.Slice: + if slices.Contains(kinds, vv.Type().Elem().Kind()) { + ts = make([]T, vv.Len()) + for i := range vv.Len() { + ts[i] = vv.Index(i).Convert(reflect.TypeOf(ts[i])).Interface().(T) + } + } + } + return +} + +// Int returns Value as a signed integer. If it is not a signed integer, it returns 0. +func (v Value) Int() int64 { + return value[int64](v, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64) +} + +// Ints returns Value as a signed integer slice. If it is not a signed integer slice, it returns nil. +func (v Value) Ints() (i64s []int64) { + return values[int64](v, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64) +} + +// Uint converts an unsigned integer value to uint64. If the value is not a unsigned integer, it returns 0. +func (v Value) Uint() uint64 { + return value[uint64](v, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64) +} + +// Uints returns Value as a unsigned integer slice. If it is not a unsigned integer slice, it returns nil. +func (v Value) Uints() (u64s []uint64) { + return values[uint64](v, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64) +} + +// Float returns Value as a float. If it is not a float, it returns 0. +func (v Value) Float() float64 { + return value[float64](v, reflect.Float32, reflect.Float64) +} + +// Floats returns Value as a float slice. If it is not a float slice, it returns nil. +func (v Value) Floats() (f64s []float64) { + return values[float64](v, reflect.Float32, reflect.Float64) +} + +// Bool returns Value as a boolean. If it is not a boolean, it returns false. +func (v Value) Bool() bool { + return value[bool](v, reflect.Bool) +} + +// Bools returns Value as a boolean slice. If it is not a boolean slice, it returns nil. 
+func (v Value) Bools() (bools []bool) {
+	return values[bool](v, reflect.Bool)
+}
+
+// String returns Value as a string. If it is not a string, it returns an empty string.
+func (v Value) String() string {
+	return value[string](v, reflect.String)
+}
+
+// Strings returns Value as a string slice. If it is not a string slice, it returns nil.
+func (v Value) Strings() (strings []string) {
+	return values[string](v, reflect.String)
+}
diff --git a/fs/gguf/keyvalue_test.go b/fs/gguf/keyvalue_test.go
new file mode 100644
index 000000000..2caacc538
--- /dev/null
+++ b/fs/gguf/keyvalue_test.go
@@ -0,0 +1,210 @@
+package gguf
+
+import (
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+)
+
+func split(name string, values map[string][]any) (matched []any, unmatched []any) {
+	for key, value := range values {
+		if key == name {
+			matched = value
+		} else {
+			unmatched = append(unmatched, value...)
+		}
+	}
+	return
+}
+
+// TestValue checks that each scalar accessor converts every value of a
+// compatible kind and yields the zero value for values of any other kind.
+func TestValue(t *testing.T) {
+	values := map[string][]any{
+		"int64":   {int(42), int8(42), int16(42), int32(42), int64(42)},
+		"uint64":  {uint(42), uint8(42), uint16(42), uint32(42), uint64(42)},
+		"float64": {float32(42), float64(42)},
+		"string":  {"42", "hello"},
+		"bool":    {true, false},
+	}
+
+	t.Run("int64", func(t *testing.T) {
+		matched, unmatched := split("int64", values)
+		for _, v := range matched {
+			kv := KeyValue{"key", Value{v}}
+			if i64 := kv.Int(); i64 != 42 {
+				t.Errorf("expected 42, got %d", i64)
+			}
+		}
+
+		for _, v := range unmatched {
+			kv := KeyValue{"key", Value{v}}
+			if i64 := kv.Int(); i64 != 0 {
+				t.Errorf("expected 0, got %d", i64)
+			}
+		}
+	})
+
+	t.Run("uint64", func(t *testing.T) {
+		matched, unmatched := split("uint64", values)
+		for _, v := range matched {
+			kv := KeyValue{"key", Value{v}}
+			if u64 := kv.Uint(); u64 != 42 {
+				t.Errorf("expected 42, got %d", u64)
+			}
+		}
+
+		for _, v := range unmatched {
+			kv := KeyValue{"key", Value{v}}
+			if u64 := kv.Uint(); u64 != 0 {
+				t.Errorf("expected 0, got %d", u64)
+			}
+		}
+	})
+
+	t.Run("float64", func(t *testing.T) {
+		matched, unmatched := split("float64", values)
+		for _, v := range matched {
+			kv := KeyValue{"key", Value{v}}
+			if f64 := kv.Float(); f64 != 42 {
+				t.Errorf("expected 42, got %f", f64)
+			}
+		}
+
+		for _, v := range unmatched {
+			kv := KeyValue{"key", Value{v}}
+			if f64 := kv.Float(); f64 != 0 {
+				t.Errorf("expected 0, got %f", f64)
+			}
+		}
+	})
+
+	t.Run("string", func(t *testing.T) {
+		matched, unmatched := split("string", values)
+		for _, v := range matched {
+			kv := KeyValue{"key", Value{v}}
+			if s := kv.String(); s != v {
+				t.Errorf("expected %q, got %q", v, s)
+			}
+		}
+
+		for _, v := range unmatched {
+			kv := KeyValue{"key", Value{v}}
+			if s := kv.String(); s != "" {
+				t.Errorf("expected empty string, got %q", s)
+			}
+		}
+	})
+
+	t.Run("bool", func(t *testing.T) {
+		matched, unmatched := split("bool", values)
+		for _, v := range matched {
+			kv := KeyValue{"key", Value{v}}
+			if b := kv.Bool(); b != v {
+				t.Errorf("expected %v, got %v", v, b)
+			}
+		}
+
+		for _, v := range unmatched {
+			kv := KeyValue{"key", Value{v}}
+			if b := kv.Bool(); b != false {
+				t.Errorf("expected false, got %v", b)
+			}
+		}
+	})
+}
+
+func TestValues(t *testing.T) {
+	values := map[string][]any{
+		"int64s":   {[]int{42}, []int8{42}, []int16{42}, []int32{42}, []int64{42}},
+		"uint64s":  {[]uint{42}, []uint8{42}, []uint16{42}, []uint32{42}, []uint64{42}},
+		"float64s": {[]float32{42}, []float64{42}},
+		"strings":  {[]string{"42"}, []string{"hello"}},
+		"bools":    {[]bool{true}, []bool{false}},
+	}
+
+	t.Run("int64s", func(t *testing.T) {
+		matched, unmatched := split("int64s", values)
+		for _, v := range matched {
+			kv := KeyValue{"key", Value{v}}
+			if diff := cmp.Diff(kv.Ints(), []int64{42}); diff != "" {
+				t.Errorf("diff: %s", diff)
+			}
+		}
+
+		for _, v := range unmatched {
+			kv := KeyValue{"key", Value{v}}
+			if i64s := kv.Ints(); i64s != nil {
+				t.Errorf("expected nil, got %v", i64s)
+			}
+		}
+	})
+
+	t.Run("uint64s", func(t *testing.T) {
+		matched, unmatched :=
split("uint64s", values)
+		for _, v := range matched {
+			kv := KeyValue{"key", Value{v}}
+			if diff := cmp.Diff(kv.Uints(), []uint64{42}); diff != "" {
+				t.Errorf("diff: %s", diff)
+			}
+		}
+
+		for _, v := range unmatched {
+			kv := KeyValue{"key", Value{v}}
+			if u64s := kv.Uints(); u64s != nil {
+				t.Errorf("expected nil, got %v", u64s)
+			}
+		}
+	})
+
+	t.Run("float64s", func(t *testing.T) {
+		matched, unmatched := split("float64s", values)
+		for _, v := range matched {
+			kv := KeyValue{"key", Value{v}}
+			if diff := cmp.Diff(kv.Floats(), []float64{42}); diff != "" {
+				t.Errorf("diff: %s", diff)
+			}
+		}
+
+		for _, v := range unmatched {
+			kv := KeyValue{"key", Value{v}}
+			if f64s := kv.Floats(); f64s != nil {
+				t.Errorf("expected nil, got %v", f64s)
+			}
+		}
+	})
+
+	t.Run("strings", func(t *testing.T) {
+		matched, unmatched := split("strings", values)
+		for _, v := range matched {
+			kv := KeyValue{"key", Value{v}}
+			if diff := cmp.Diff(kv.Strings(), v); diff != "" {
+				t.Errorf("diff: %s", diff)
+			}
+		}
+
+		for _, v := range unmatched {
+			kv := KeyValue{"key", Value{v}}
+			if s := kv.Strings(); s != nil {
+				t.Errorf("expected nil, got %v", s)
+			}
+		}
+	})
+
+	t.Run("bools", func(t *testing.T) {
+		matched, unmatched := split("bools", values)
+		for _, v := range matched {
+			kv := KeyValue{"key", Value{v}}
+			if diff := cmp.Diff(kv.Bools(), v); diff != "" {
+				t.Errorf("diff: %s", diff)
+			}
+		}
+
+		for _, v := range unmatched {
+			kv := KeyValue{"key", Value{v}}
+			if b := kv.Bools(); b != nil {
+				t.Errorf("expected nil, got %v", b)
+			}
+		}
+	})
+}
diff --git a/fs/gguf/lazy.go b/fs/gguf/lazy.go
new file mode 100644
index 000000000..16ab99093
--- /dev/null
+++ b/fs/gguf/lazy.go
@@ -0,0 +1,102 @@
+package gguf
+
+import (
+	"encoding/binary"
+	"iter"
+	"log/slog"
+)
+
+// lazy is a sequence of count values that are decoded on demand and cached in
+// values in the order they were read.
+type lazy[T any] struct {
+	count  uint64
+	next   func() (T, bool)
+	stop   func()
+	values []T
+
+	// successFunc is called when all values have been successfully read.
+	successFunc func() error
+}
+
+// newLazy reads the number of values from f's reader and returns a lazy
+// sequence whose values are decoded one at a time by fn as next is called.
+// A decoding error is logged and ends the sequence early.
+func newLazy[T any](f *File, fn func() (T, error)) (*lazy[T], error) {
+	it := lazy[T]{}
+	if err := binary.Read(f.reader, binary.LittleEndian, &it.count); err != nil {
+		return nil, err
+	}
+
+	it.values = make([]T, 0)
+	it.next, it.stop = iter.Pull(func(yield func(T) bool) {
+		for i := range it.count {
+			t, err := fn()
+			if err != nil {
+				slog.Error("error reading value", "index", i, "error", err)
+				return
+			}
+
+			it.values = append(it.values, t)
+			if !yield(t) {
+				break
+			}
+		}
+
+		// yield also reports false when stop is called part-way through, so
+		// confirm that every value was read before running successFunc.
+		if it.successFunc != nil && uint64(len(it.values)) == it.count {
+			if err := it.successFunc(); err != nil {
+				slog.Error("error completing read", "error", err)
+			}
+		}
+	})
+
+	return &it, nil
+}
+
+// Values returns an iterator over the values alone, in the order All yields them.
+func (g *lazy[T]) Values() iter.Seq[T] {
+	return func(yield func(T) bool) {
+		for _, v := range g.All() {
+			if !yield(v) {
+				break
+			}
+		}
+	}
+}
+
+// All returns an iterator over index-value pairs that replays cached values
+// first and then pulls the remaining ones from next.
+func (g *lazy[T]) All() iter.Seq2[int, T] {
+	return func(yield func(int, T) bool) {
+		for i := range int(g.count) {
+			if i < len(g.values) {
+				if !yield(i, g.values[i]) {
+					break
+				}
+			} else {
+				t, ok := g.next()
+				if !ok {
+					break
+				}
+
+				if !yield(i, t) {
+					break
+				}
+			}
+		}
+	}
+}
+
+// rest drains next and reports whether at least one value was pulled.
+func (g *lazy[T]) rest() (collected bool) {
+	for {
+		_, ok := g.next()
+		collected = collected || ok
+		if !ok {
+			break
+		}
+	}
+
+	return collected
+}
diff --git a/fs/gguf/reader.go b/fs/gguf/reader.go
new file mode 100644
index 000000000..0bd761840
--- /dev/null
+++ b/fs/gguf/reader.go
@@ -0,0 +1,23 @@
+package gguf
+
+import (
+	"bufio"
+	"io"
+)
+
+type bufferedReader struct {
+	offset int64
+	*bufio.Reader
+}
+
+func newBufferedReader(rs io.ReadSeeker, size int) *bufferedReader {
+	return &bufferedReader{
+		Reader: bufio.NewReaderSize(rs, size),
+	}
+}
+
+func (rs *bufferedReader) Read(p []byte) (n int, err error) {
+	n, err = rs.Reader.Read(p)
+	rs.offset += int64(n)
+	return n, err
+}
diff --git a/fs/gguf/tensor.go b/fs/gguf/tensor.go
new file mode 100644
index 000000000..194c1d739
--- /dev/null
+++ b/fs/gguf/tensor.go
@@ -0,0 +1,288 @@
+package gguf
+
+import (
+	"log/slog"
+	"strings"
+)
+
+type TensorInfo struct {
+	Name   string
+
Offset uint64 + Shape []uint64 + Type TensorType +} + +func (ti TensorInfo) Valid() bool { + return ti.Name != "" && ti.NumBytes() > 0 +} + +func (ti TensorInfo) NumValues() int64 { + var numItems int64 = 1 + for _, dim := range ti.Shape { + numItems *= int64(dim) + } + return numItems +} + +// NumBytes returns the number of bytes in the tensor. +func (ti TensorInfo) NumBytes() int64 { + return int64(float64(ti.NumValues()) * ti.Type.NumBytes()) +} + +func (ti TensorInfo) LogValue() slog.Value { + return slog.GroupValue( + slog.String("name", ti.Name), + slog.Int64("offset", int64(ti.Offset)), + slog.Any("shape", ti.Shape), + slog.Int64("num_values", ti.NumValues()), + slog.Int64("num_bytes", ti.NumBytes()), + slog.Any("type", ti.Type), + ) +} + +type TensorType uint32 + +const ( + TensorTypeF32 TensorType = iota + TensorTypeF16 + TensorTypeQ4_0 + TensorTypeQ4_1 + + // unexported // unused in gguf + tensorTypeQ4_2 + tensorTypeQ4_3 + + TensorTypeQ5_0 + TensorTypeQ5_1 + TensorTypeQ8_0 + TensorTypeQ8_1 + TensorTypeQ2_K + TensorTypeQ3_K + TensorTypeQ4_K + TensorTypeQ5_K + TensorTypeQ6_K + TensorTypeQ8_K + + // unexported // unquantizable by ollama + tensorTypeIQ2_XXS + tensorTypeIQ2_XS + tensorTypeIQ3_XXS + tensorTypeIQ1_S + tensorTypeIQ4_NL + tensorTypeIQ3_S + tensorTypeIQ2_S + tensorTypeIQ4_XS + + TensorTypeI8 + TensorTypeI16 + TensorTypeI32 + TensorTypeI64 + TensorTypeF64 + + // unexported // unquantizable by ollama + tensorTypeIQ1_M + + TensorTypeBF16 + + // unexported // unused in gguf + tensorTypeQ4_0_4_4 + tensorTypeQ4_0_4_8 + tensorTypeQ4_0_8_8 + + // unexported // unquantizable by ollama + tensorTypeTQ1_0 + tensorTypeTQ2_0 + + // unexported // unused in gguf + tensorTypeIQ4_NL_4_4 + tensorTypeIQ4_NL_4_8 + tensorTypeIQ4_NL_8_8 +) + +func (tt TensorType) NumBytes() float64 { + return float64(tt.typeSize()) / float64(tt.blockSize()) +} + +func (tt TensorType) typeSize() int64 { + switch tt { + case TensorTypeF32: + return 4 + case TensorTypeF16: + return 2 + case 
TensorTypeQ4_0: + return 2 + tt.blockSize()/2 + case TensorTypeQ4_1: + return 2 + 2 + tt.blockSize()/2 + case TensorTypeQ5_0: + return 2 + 4 + tt.blockSize()/2 + case TensorTypeQ5_1: + return 2 + 2 + 4 + tt.blockSize()/2 + case TensorTypeQ8_0: + return 2 + tt.blockSize() + case TensorTypeQ8_1: + return 2 + 2 + tt.blockSize() + case TensorTypeQ2_K: + return tt.blockSize()/16 + tt.blockSize()/4 + 2 + 2 + case TensorTypeQ3_K: + return tt.blockSize()/8 + tt.blockSize()/4 + 12 + 2 + case TensorTypeQ4_K: + return 2 + 2 + 12 + tt.blockSize()/2 + case TensorTypeQ5_K: + return 2 + 2 + 12 + tt.blockSize()/8 + tt.blockSize()/2 + case TensorTypeQ6_K: + return tt.blockSize()/2 + tt.blockSize()/4 + tt.blockSize()/16 + 2 + case TensorTypeQ8_K: + return 4 + tt.blockSize() + 2*tt.blockSize()/16 + case tensorTypeIQ2_XXS: + return 2 + 2*tt.blockSize()/8 + case tensorTypeIQ2_XS: + return 2 + 2*tt.blockSize()/8 + tt.blockSize()/32 + case tensorTypeIQ3_XXS: + return 2 + tt.blockSize()/4 + tt.blockSize()/8 + case tensorTypeIQ1_S: + return 2 + tt.blockSize()/8 + tt.blockSize()/16 + case tensorTypeIQ4_NL: + return 2 + tt.blockSize()/2 + case tensorTypeIQ3_S: + return 2 + tt.blockSize()/4 + tt.blockSize()/8 + tt.blockSize()/32 + 4 + case tensorTypeIQ2_S: + return 2 + tt.blockSize()/4 + tt.blockSize()/16 + case tensorTypeIQ4_XS: + return 2 + 2 + tt.blockSize()/2 + tt.blockSize()/64 + case TensorTypeI8: + return 1 + case TensorTypeI16: + return 2 + case TensorTypeI32: + return 4 + case TensorTypeI64: + return 8 + case TensorTypeF64: + return 8 + case tensorTypeIQ1_M: + return tt.blockSize()/8 + tt.blockSize()/16 + tt.blockSize()/32 + case TensorTypeBF16: + return 2 + default: + return 0 + } +} + +func (tt TensorType) blockSize() int64 { + switch tt { + case TensorTypeF32, + TensorTypeF16, + TensorTypeI8, + TensorTypeI16, + TensorTypeI32, + TensorTypeI64, + TensorTypeF64, + TensorTypeBF16: + return 1 + case TensorTypeQ4_0, + TensorTypeQ4_1, + TensorTypeQ5_0, + TensorTypeQ5_1, + TensorTypeQ8_0, 
+ TensorTypeQ8_1, + tensorTypeIQ4_NL: + return 32 + default: + return 256 + } +} + +func (tt TensorType) String() string { + switch tt { + case TensorTypeF32: + return "f32" + case TensorTypeF16: + return "f16" + case TensorTypeQ4_0: + return "q4_0" + case TensorTypeQ4_1: + return "q4_1" + case tensorTypeQ4_2: + return "q4_2" + case tensorTypeQ4_3: + return "q4_3" + case TensorTypeQ5_0: + return "q5_0" + case TensorTypeQ5_1: + return "q5_1" + case TensorTypeQ8_0: + return "q8_0" + case TensorTypeQ8_1: + return "q8_1" + case TensorTypeQ2_K: + return "q2_k" + case TensorTypeQ3_K: + return "q3_k" + case TensorTypeQ4_K: + return "q4_k" + case TensorTypeQ5_K: + return "q5_k" + case TensorTypeQ6_K: + return "q6_k" + case TensorTypeQ8_K: + return "q8_k" + case tensorTypeIQ2_XXS: + return "iq2_xxs" + case tensorTypeIQ2_XS: + return "iq2_xs" + case tensorTypeIQ3_XXS: + return "iq3_xxs" + case tensorTypeIQ1_S: + return "iq1_s" + case tensorTypeIQ4_NL: + return "iq4_nl" + case tensorTypeIQ3_S: + return "iq3_s" + case tensorTypeIQ2_S: + return "iq2_s" + case tensorTypeIQ4_XS: + return "iq4_xs" + case TensorTypeI8: + return "i8" + case TensorTypeI16: + return "i16" + case TensorTypeI32: + return "i32" + case TensorTypeI64: + return "i64" + case TensorTypeF64: + return "f64" + case tensorTypeIQ1_M: + return "iq1_m" + case TensorTypeBF16: + return "bf16" + case tensorTypeQ4_0_4_4: + return "q4_0_4_4" + case tensorTypeQ4_0_4_8: + return "q4_0_4_8" + case tensorTypeQ4_0_8_8: + return "q4_0_8_8" + case tensorTypeTQ1_0: + return "tq1_0" + case tensorTypeTQ2_0: + return "tq2_0" + case tensorTypeIQ4_NL_4_4: + return "iq4_nl_4_4" + case tensorTypeIQ4_NL_4_8: + return "iq4_nl_4_8" + case tensorTypeIQ4_NL_8_8: + return "iq4_nl_8_8" + default: + return "unknown" + } +} + +func (tt TensorType) LogValue() slog.Value { + return slog.GroupValue( + slog.Uint64("value", uint64(tt)), + slog.String("name", strings.ToUpper(tt.String())), + slog.Int64("size", tt.typeSize()), + 
slog.Int64("block_size", tt.blockSize()), + slog.Float64("num_bytes", tt.NumBytes()), + ) +} diff --git a/go.mod b/go.mod index 283286b7d..6de5959be 100644 --- a/go.mod +++ b/go.mod @@ -19,7 +19,7 @@ require ( github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1 github.com/dlclark/regexp2 v1.11.4 github.com/emirpasic/gods/v2 v2.0.0-alpha - github.com/google/go-cmp v0.6.0 + github.com/google/go-cmp v0.7.0 github.com/mattn/go-runewidth v0.0.14 github.com/nlpodyssey/gopickle v0.3.0 github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c diff --git a/go.sum b/go.sum index 5755616f6..c0ab53aab 100644 --- a/go.sum +++ b/go.sum @@ -112,8 +112,8 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= diff --git a/server/images.go b/server/images.go index d6cceff4c..38505cc51 100644 --- a/server/images.go +++ b/server/images.go @@ -23,7 +23,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/envconfig" - "github.com/ollama/ollama/fs/ggml" + "github.com/ollama/ollama/fs/gguf" "github.com/ollama/ollama/parser" "github.com/ollama/ollama/template" "github.com/ollama/ollama/thinking" 
@@ -73,22 +73,18 @@ func (m *Model) Capabilities() []model.Capability { capabilities := []model.Capability{} // Check for completion capability - r, err := os.Open(m.ModelPath) + f, err := gguf.Open(m.ModelPath) if err == nil { - defer r.Close() + defer f.Close() - f, err := ggml.Decode(r, 1024) - if err == nil { - if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok { - capabilities = append(capabilities, model.CapabilityEmbedding) - } else { - capabilities = append(capabilities, model.CapabilityCompletion) - } - if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok { - capabilities = append(capabilities, model.CapabilityVision) - } + if f.KeyValue("pooling_type").Valid() { + capabilities = append(capabilities, model.CapabilityEmbedding) } else { - slog.Error("couldn't decode ggml", "error", err) + // If no embedding is specified, we assume the model supports completion + capabilities = append(capabilities, model.CapabilityCompletion) + } + if f.KeyValue("vision.block_count").Valid() { + capabilities = append(capabilities, model.CapabilityVision) } } else { slog.Error("couldn't open model file", "error", err) diff --git a/server/images_test.go b/server/images_test.go index 363b298e1..a2fba8d98 100644 --- a/server/images_test.go +++ b/server/images_test.go @@ -1,123 +1,42 @@ package server import ( - "bytes" - "encoding/binary" - "errors" - "os" - "path/filepath" "strings" "testing" + "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/template" "github.com/ollama/ollama/types/model" ) -// Constants for GGUF magic bytes and version -var ( - ggufMagic = []byte{0x47, 0x47, 0x55, 0x46} // "GGUF" - ggufVer = uint32(3) // Version 3 -) - -// Helper function to create mock GGUF data -func createMockGGUFData(architecture string, vision bool) []byte { - var buf bytes.Buffer - - // Write GGUF header - buf.Write(ggufMagic) - binary.Write(&buf, binary.LittleEndian, ggufVer) - - // Write tensor count (0 for our 
test) - var numTensors uint64 = 0 - binary.Write(&buf, binary.LittleEndian, numTensors) - - // Calculate number of metadata entries - numMetaEntries := uint64(1) // architecture entry - if vision { - numMetaEntries++ - } - // Add embedding entry if architecture is "bert" - if architecture == "bert" { - numMetaEntries++ - } - binary.Write(&buf, binary.LittleEndian, numMetaEntries) - - // Write architecture metadata - archKey := "general.architecture" - keyLen := uint64(len(archKey)) - binary.Write(&buf, binary.LittleEndian, keyLen) - buf.WriteString(archKey) - - // String type (8) - var strType uint32 = 8 - binary.Write(&buf, binary.LittleEndian, strType) - - // String length - strLen := uint64(len(architecture)) - binary.Write(&buf, binary.LittleEndian, strLen) - buf.WriteString(architecture) - - if vision { - visionKey := architecture + ".vision.block_count" - keyLen = uint64(len(visionKey)) - binary.Write(&buf, binary.LittleEndian, keyLen) - buf.WriteString(visionKey) - - // uint32 type (4) - var uint32Type uint32 = 4 - binary.Write(&buf, binary.LittleEndian, uint32Type) - - // uint32 value (1) - var countVal uint32 = 1 - binary.Write(&buf, binary.LittleEndian, countVal) - } - // Write embedding metadata if architecture is "bert" - if architecture == "bert" { - poolKey := architecture + ".pooling_type" - keyLen = uint64(len(poolKey)) - binary.Write(&buf, binary.LittleEndian, keyLen) - buf.WriteString(poolKey) - - // uint32 type (4) - var uint32Type uint32 = 4 - binary.Write(&buf, binary.LittleEndian, uint32Type) - - // uint32 value (1) - var poolingVal uint32 = 1 - binary.Write(&buf, binary.LittleEndian, poolingVal) - } - - return buf.Bytes() -} - func TestModelCapabilities(t *testing.T) { - // Create a temporary directory for test files - tempDir := t.TempDir() + // Create completion model (llama architecture without vision) + completionModelPath, _ := createBinFile(t, ggml.KV{ + "general.architecture": "llama", + }, []*ggml.Tensor{}) - // Create different types 
of mock model files - completionModelPath := filepath.Join(tempDir, "model.bin") - visionModelPath := filepath.Join(tempDir, "vision_model.bin") - embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin") - // Create a simple model file for tests that don't depend on GGUF content - simpleModelPath := filepath.Join(tempDir, "simple_model.bin") + // Create vision model (llama architecture with vision block count) + visionModelPath, _ := createBinFile(t, ggml.KV{ + "general.architecture": "llama", + "llama.vision.block_count": uint32(1), + }, []*ggml.Tensor{}) - if err := errors.Join( - os.WriteFile(completionModelPath, createMockGGUFData("llama", false), 0o644), - os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644), - os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644), - os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644), - ); err != nil { - t.Fatalf("Failed to create model files: %v", err) - } + // Create embedding model (bert architecture with pooling type) + embeddingModelPath, _ := createBinFile(t, ggml.KV{ + "general.architecture": "bert", + "bert.pooling_type": uint32(1), + }, []*ggml.Tensor{}) toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}") if err != nil { t.Fatalf("Failed to parse template: %v", err) } + chatTemplate, err := template.Parse("{{ .prompt }}") if err != nil { t.Fatalf("Failed to parse template: %v", err) } + toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}") if err != nil { t.Fatalf("Failed to parse template: %v", err) @@ -145,21 +64,13 @@ func TestModelCapabilities(t *testing.T) { }, expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools, model.CapabilityInsert}, }, - { - name: "model with tools and insert capability", - model: Model{ - ModelPath: simpleModelPath, - Template: toolsInsertTemplate, - }, - expectedCaps: 
[]model.Capability{model.CapabilityTools, model.CapabilityInsert}, - }, { name: "model with tools capability", model: Model{ - ModelPath: simpleModelPath, + ModelPath: completionModelPath, Template: toolsTemplate, }, - expectedCaps: []model.Capability{model.CapabilityTools}, + expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools}, }, { name: "model with vision capability", @@ -224,29 +135,33 @@ func TestModelCapabilities(t *testing.T) { } func TestModelCheckCapabilities(t *testing.T) { - // Create a temporary directory for test files - tempDir := t.TempDir() + // Create simple model file for tests that don't depend on GGUF content + completionModelPath, _ := createBinFile(t, ggml.KV{ + "general.architecture": "llama", + }, []*ggml.Tensor{}) - visionModelPath := filepath.Join(tempDir, "vision_model.bin") - simpleModelPath := filepath.Join(tempDir, "model.bin") - embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin") + // Create vision model (llama architecture with vision block count) + visionModelPath, _ := createBinFile(t, ggml.KV{ + "general.architecture": "llama", + "llama.vision.block_count": uint32(1), + }, []*ggml.Tensor{}) - if err := errors.Join( - os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644), - os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644), - os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644), - ); err != nil { - t.Fatalf("Failed to create model files: %v", err) - } + // Create embedding model (bert architecture with pooling type) + embeddingModelPath, _ := createBinFile(t, ggml.KV{ + "general.architecture": "bert", + "bert.pooling_type": uint32(1), + }, []*ggml.Tensor{}) toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}") if err != nil { t.Fatalf("Failed to parse template: %v", err) } + chatTemplate, err := template.Parse("{{ .prompt }}") if err != nil { 
t.Fatalf("Failed to parse template: %v", err) } + toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}") if err != nil { t.Fatalf("Failed to parse template: %v", err) @@ -261,7 +176,7 @@ func TestModelCheckCapabilities(t *testing.T) { { name: "completion model without tools capability", model: Model{ - ModelPath: simpleModelPath, + ModelPath: completionModelPath, Template: chatTemplate, }, checkCaps: []model.Capability{model.CapabilityTools}, @@ -270,7 +185,7 @@ func TestModelCheckCapabilities(t *testing.T) { { name: "model with all needed capabilities", model: Model{ - ModelPath: simpleModelPath, + ModelPath: completionModelPath, Template: toolsInsertTemplate, }, checkCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert}, @@ -278,7 +193,7 @@ func TestModelCheckCapabilities(t *testing.T) { { name: "model missing insert capability", model: Model{ - ModelPath: simpleModelPath, + ModelPath: completionModelPath, Template: toolsTemplate, }, checkCaps: []model.Capability{model.CapabilityInsert}, @@ -287,7 +202,7 @@ func TestModelCheckCapabilities(t *testing.T) { { name: "model missing vision capability", model: Model{ - ModelPath: simpleModelPath, + ModelPath: completionModelPath, Template: toolsTemplate, }, checkCaps: []model.Capability{model.CapabilityVision}, @@ -312,7 +227,7 @@ func TestModelCheckCapabilities(t *testing.T) { { name: "unknown capability", model: Model{ - ModelPath: simpleModelPath, + ModelPath: completionModelPath, Template: chatTemplate, }, checkCaps: []model.Capability{"unknown"}, diff --git a/server/quantization_test.go b/server/quantization_test.go index 4f717c2c2..8b726c836 100644 --- a/server/quantization_test.go +++ b/server/quantization_test.go @@ -257,16 +257,8 @@ func TestQuantizeModel(t *testing.T) { for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { - f, err := os.CreateTemp(t.TempDir(), tt.name) - if err != nil { - t.Fatal(err.Error()) - } - defer f.Close() - err = 
fsggml.WriteGGUF(f, tt.kv, tt.tensors) - if err != nil { - t.Fatalf("failed to create initial model: %s", err) - } - fp, err := os.Open(f.Name()) + p, _ := createBinFile(t, tt.kv, tt.tensors) + fp, err := os.Open(p) if err != nil { t.Fatal(err.Error()) } diff --git a/server/sched_test.go b/server/sched_test.go index 01fb9a703..3892fbbab 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -112,11 +112,7 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est b.ctx, b.ctxDone = context.WithCancel(ctx) t.Helper() - f, err := os.CreateTemp(t.TempDir(), modelName) - require.NoError(t, err) - defer f.Close() - - require.NoError(t, ggml.WriteGGUF(f, ggml.KV{ + p, _ := createBinFile(t, ggml.KV{ "general.architecture": "llama", "llama.context_length": uint32(32), "llama.embedding_length": uint32(4096), @@ -129,14 +125,14 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est }, []*ggml.Tensor{ {Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))}, {Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))}, - })) - require.NoError(t, err) - - fname := f.Name() - model := &Model{Name: modelName, ModelPath: fname} - b.f, err = llm.LoadModel(model.ModelPath, 0) - require.NoError(t, err) + }) + model := &Model{Name: modelName, ModelPath: p} + f, err := llm.LoadModel(model.ModelPath, 0) + if err != nil { + t.Fatal(err) + } + b.f = f if duration == nil { duration = &api.Duration{Duration: 5 * time.Millisecond} }