mirror of
https://github.com/jmorganca/ollama
synced 2025-10-06 00:32:49 +02:00
model: handle multiple eos tokens (#10577)
* get eos_token_id from generation_config.json * refactor * include both ids and strings in trace * comments * remove special case for gemma3 special vocab (#10743)
This commit is contained in:
@@ -53,8 +53,11 @@ func (ModelParameters) KV(t *Tokenizer) ggml.KV {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, sv := range t.SpecialVocabulary {
|
for _, sv := range t.SpecialVocabulary {
|
||||||
kv[fmt.Sprintf("tokenizer.ggml.%s_token_id", sv.Key())] = uint32(sv.ID)
|
|
||||||
kv[fmt.Sprintf("tokenizer.ggml.add_%s_token", sv.Key())] = sv.AddToken
|
kv[fmt.Sprintf("tokenizer.ggml.add_%s_token", sv.Key())] = sv.AddToken
|
||||||
|
kv[fmt.Sprintf("tokenizer.ggml.%s_token_id", sv.Key())] = uint32(sv.ID)
|
||||||
|
if len(sv.IDs) > 0 {
|
||||||
|
kv[fmt.Sprintf("tokenizer.ggml.%s_token_ids", sv.Key())] = sv.IDs
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return kv
|
return kv
|
||||||
|
@@ -110,6 +110,7 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if f, err := fsys.Open("tokenizer_config.json"); errors.Is(err, os.ErrNotExist) {
|
if f, err := fsys.Open("tokenizer_config.json"); errors.Is(err, os.ErrNotExist) {
|
||||||
|
// noop
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else {
|
} else {
|
||||||
@@ -171,6 +172,34 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if f, err := fsys.Open("generation_config.json"); errors.Is(err, os.ErrNotExist) {
|
||||||
|
} else if err != nil {
|
||||||
|
return nil, err
|
||||||
|
} else {
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
var p map[string]json.RawMessage
|
||||||
|
if err := json.NewDecoder(f).Decode(&p); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, st := range specialTokenTypes {
|
||||||
|
if bts, ok := p[fmt.Sprintf("%s_token_id", st)]; ok {
|
||||||
|
var ids []int32
|
||||||
|
if err := json.Unmarshal(bts, &ids); err != nil {
|
||||||
|
// value is not a list so the existing ID is used
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if i := slices.IndexFunc(t.SpecialVocabulary, func(sv *SpecialVocabulary) bool {
|
||||||
|
return sv.Type == st
|
||||||
|
}); i >= 0 {
|
||||||
|
t.SpecialVocabulary[i].IDs = ids
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -280,6 +309,9 @@ type SpecialVocabulary struct {
|
|||||||
ID int
|
ID int
|
||||||
Content string
|
Content string
|
||||||
AddToken bool
|
AddToken bool
|
||||||
|
|
||||||
|
// IDs is populated by generation_config.json
|
||||||
|
IDs []int32
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sv SpecialVocabulary) Key() string {
|
func (sv SpecialVocabulary) Key() string {
|
||||||
|
@@ -247,6 +247,67 @@ func TestParseTokenizer(t *testing.T) {
|
|||||||
Pre: "default",
|
Pre: "default",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "generation config eos token ids",
|
||||||
|
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
|
||||||
|
"tokenizer.json": strings.NewReader(`{
|
||||||
|
"added_tokens": [
|
||||||
|
{
|
||||||
|
"id": 0,
|
||||||
|
"content": "<bos>",
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"content": "<eos>",
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 2,
|
||||||
|
"content": "<eot>",
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 3,
|
||||||
|
"content": "<eom>",
|
||||||
|
"special": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"model": {
|
||||||
|
"vocab": {
|
||||||
|
"<bos>": 0,
|
||||||
|
"<eos>": 1,
|
||||||
|
"<eot>": 2,
|
||||||
|
"<eom>": 3
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`),
|
||||||
|
"tokenizer_config.json": strings.NewReader(`{
|
||||||
|
"add_bos_token": true,
|
||||||
|
"add_eos_token": false,
|
||||||
|
"bos_token": "<bos>",
|
||||||
|
"eos_token": "<eos>"
|
||||||
|
}`),
|
||||||
|
"generation_config.json": strings.NewReader(`{
|
||||||
|
"bos_token_id": 0,
|
||||||
|
"eos_token_id": [1, 2, 3]
|
||||||
|
}`),
|
||||||
|
}),
|
||||||
|
specialTokenTypes: []string{"pad", "eos", "bos", "unk"},
|
||||||
|
want: &Tokenizer{
|
||||||
|
Vocabulary: &Vocabulary{
|
||||||
|
Model: "gpt2",
|
||||||
|
Tokens: []string{"<bos>", "<eos>", "<eot>", "<eom>"},
|
||||||
|
Scores: []float32{0, 1, 2, 3},
|
||||||
|
Types: []int32{3, 3, 3, 3},
|
||||||
|
},
|
||||||
|
SpecialVocabulary: []*SpecialVocabulary{
|
||||||
|
{Type: "eos", Content: "<eos>", ID: 1, IDs: []int32{1, 2, 3}, AddToken: false},
|
||||||
|
{Type: "bos", Content: "<bos>", ID: 0, AddToken: true},
|
||||||
|
},
|
||||||
|
Pre: "default",
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range cases {
|
for _, tt := range cases {
|
||||||
|
@@ -602,7 +602,7 @@ type Grammar struct {
|
|||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewGrammar(grammar string, vocabIds []uint32, vocabValues []string, eogTokens []uint32) *Grammar {
|
func NewGrammar(grammar string, vocabIds []uint32, vocabValues []string, eogTokens []int32) *Grammar {
|
||||||
cGrammar := C.CString(grammar)
|
cGrammar := C.CString(grammar)
|
||||||
defer C.free(unsafe.Pointer(cGrammar))
|
defer C.free(unsafe.Pointer(cGrammar))
|
||||||
|
|
||||||
@@ -622,7 +622,7 @@ func NewGrammar(grammar string, vocabIds []uint32, vocabValues []string, eogToke
|
|||||||
cEogTokens[i] = C.uint32_t(token)
|
cEogTokens[i] = C.uint32_t(token)
|
||||||
}
|
}
|
||||||
|
|
||||||
g := C.grammar_init(cGrammar, (*C.uint32_t)(unsafe.Pointer(&cTokens[0])), C.size_t(len(cTokens)), (**C.char)(unsafe.Pointer(&cPieces[0])), (*C.uint32_t)(unsafe.Pointer(&cEogTokens[0])), C.size_t(len(cEogTokens)))
|
g := C.grammar_init(cGrammar, unsafe.SliceData(cTokens), C.size_t(len(cTokens)), unsafe.SliceData(cPieces), unsafe.SliceData(cEogTokens), C.size_t(len(cEogTokens)))
|
||||||
if g == nil {
|
if g == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@@ -5,116 +5,13 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"iter"
|
"iter"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"slices"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/dlclark/regexp2"
|
"github.com/dlclark/regexp2"
|
||||||
heap "github.com/emirpasic/gods/v2/trees/binaryheap"
|
heap "github.com/emirpasic/gods/v2/trees/binaryheap"
|
||||||
"github.com/ollama/ollama/logutil"
|
"github.com/ollama/ollama/logutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Special int32
|
|
||||||
|
|
||||||
const (
|
|
||||||
SpecialBOS Special = iota
|
|
||||||
SpecialEOS
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
TOKEN_TYPE_NORMAL = iota + 1
|
|
||||||
TOKEN_TYPE_UNKNOWN
|
|
||||||
TOKEN_TYPE_CONTROL
|
|
||||||
TOKEN_TYPE_USER_DEFINED
|
|
||||||
TOKEN_TYPE_UNUSED
|
|
||||||
TOKEN_TYPE_BYTE
|
|
||||||
)
|
|
||||||
|
|
||||||
type TextProcessor interface {
|
|
||||||
Encode(s string, addSpecial bool) ([]int32, error)
|
|
||||||
Decode([]int32) (string, error)
|
|
||||||
Is(int32, Special) bool
|
|
||||||
Vocabulary() *Vocabulary
|
|
||||||
}
|
|
||||||
|
|
||||||
type Vocabulary struct {
|
|
||||||
Values []string
|
|
||||||
Types []int32
|
|
||||||
Scores []float32
|
|
||||||
Merges []string
|
|
||||||
|
|
||||||
BOS, EOS, EOT int32
|
|
||||||
AddBOS, AddEOS, AddEOT bool
|
|
||||||
|
|
||||||
specialOnce sync.Once
|
|
||||||
special []string
|
|
||||||
|
|
||||||
valuesOnce sync.Once
|
|
||||||
values map[string]int32
|
|
||||||
|
|
||||||
mergeOnce sync.Once
|
|
||||||
merge map[string]int32
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *Vocabulary) Is(id int32, special Special) bool {
|
|
||||||
switch special {
|
|
||||||
case SpecialBOS:
|
|
||||||
return id == v.BOS
|
|
||||||
case SpecialEOS:
|
|
||||||
return id == v.EOS || id == v.EOT
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *Vocabulary) Encode(s string) int32 {
|
|
||||||
v.valuesOnce.Do(func() {
|
|
||||||
v.values = make(map[string]int32, len(v.Values))
|
|
||||||
for i, value := range v.Values {
|
|
||||||
v.values[value] = int32(i)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
if id, ok := v.values[s]; ok {
|
|
||||||
return id
|
|
||||||
}
|
|
||||||
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *Vocabulary) Decode(id int32) string {
|
|
||||||
return v.Values[id]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *Vocabulary) SpecialVocabulary() []string {
|
|
||||||
v.specialOnce.Do(func() {
|
|
||||||
for i := range v.Values {
|
|
||||||
if slices.Contains([]int{105, 106}, i) {
|
|
||||||
v.special = append(v.special, v.Values[i])
|
|
||||||
} else if v.Types[i] == TOKEN_TYPE_CONTROL {
|
|
||||||
v.special = append(v.special, v.Values[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
return v.special
|
|
||||||
}
|
|
||||||
|
|
||||||
func (v *Vocabulary) Merge(left, right string) int {
|
|
||||||
v.mergeOnce.Do(func() {
|
|
||||||
v.merge = make(map[string]int32, len(v.Merges))
|
|
||||||
for i, merge := range v.Merges {
|
|
||||||
v.merge[merge] = int32(i)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
if id, ok := v.merge[left+" "+right]; ok {
|
|
||||||
return int(id)
|
|
||||||
}
|
|
||||||
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
|
|
||||||
type BytePairEncoding struct {
|
type BytePairEncoding struct {
|
||||||
pre *regexp2.Regexp
|
pre *regexp2.Regexp
|
||||||
vocab *Vocabulary
|
vocab *Vocabulary
|
||||||
@@ -304,27 +201,12 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "string", s, "ids", ids)
|
||||||
|
|
||||||
if addSpecial && len(ids) > 0 {
|
if addSpecial && len(ids) > 0 {
|
||||||
if bpe.vocab.AddBOS {
|
ids = bpe.vocab.addSpecials(ids)
|
||||||
if ids[0] == bpe.vocab.BOS {
|
|
||||||
slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS)
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Debug("adding bos token to prompt", "id", bpe.vocab.BOS)
|
|
||||||
ids = append([]int32{bpe.vocab.BOS}, ids...)
|
|
||||||
}
|
|
||||||
|
|
||||||
if bpe.vocab.AddEOS {
|
|
||||||
if ids[len(ids)-1] == bpe.vocab.EOS {
|
|
||||||
slog.Warn("adding eos token to prompt which already has it", "id", bpe.vocab.EOS)
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Debug("adding eos token to prompt", "id", bpe.vocab.EOS)
|
|
||||||
ids = append(ids, bpe.vocab.EOS)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "ids", ids)
|
|
||||||
return ids, nil
|
return ids, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -352,6 +234,6 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "string", sb.String())
|
slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "ids", ids, "string", sb.String())
|
||||||
return sb.String(), nil
|
return sb.String(), nil
|
||||||
}
|
}
|
@@ -43,10 +43,13 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||||
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
|
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||||
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||||
// TODO: set EOT to EOS otherwise 0 will stop generation
|
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||||
EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
EOS: append(
|
||||||
|
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||||
|
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||||
|
),
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
Layers: make([]Layer, c.Uint("block_count")),
|
Layers: make([]Layer, c.Uint("block_count")),
|
||||||
|
@@ -60,12 +60,16 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||||
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
|
|
||||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||||
EOS: int32(1),
|
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||||
EOT: int32(106),
|
EOS: append(
|
||||||
AddEOT: c.Bool("tokenizer.ggml.add_eot_token", false),
|
[]int32{
|
||||||
|
int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
||||||
|
int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
|
||||||
|
},
|
||||||
|
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||||
|
),
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
ImageProcessor: newImageProcessor(c),
|
ImageProcessor: newImageProcessor(c),
|
||||||
|
@@ -43,13 +43,13 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||||
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
|
|
||||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||||
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||||
// TODO: set EOT to EOS otherwise 0 will stop generation
|
EOS: append(
|
||||||
EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||||
AddEOT: c.Bool("tokenizer.ggml.add_eos_token", false),
|
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||||
|
),
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
Layers: make([]Layer, c.Uint("block_count")),
|
Layers: make([]Layer, c.Uint("block_count")),
|
||||||
|
@@ -40,13 +40,13 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||||
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
|
|
||||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||||
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||||
// TODO: set EOT to EOS otherwise 0 will stop generation
|
EOS: append(
|
||||||
EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||||
AddEOT: c.Bool("tokenizer.ggml.add_eos_token", false),
|
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||||
|
),
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
ImageProcessor: newImageProcessor(c),
|
ImageProcessor: newImageProcessor(c),
|
||||||
|
@@ -37,25 +37,25 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
m := &Model{
|
m := &Model{
|
||||||
TextModel: textModel,
|
|
||||||
VisionModel: newVisionModel(c),
|
|
||||||
ImageProcessor: newImageProcessor(c),
|
|
||||||
MultiModalProjector: newMultiModalProjector(c),
|
|
||||||
BytePairEncoding: model.NewBytePairEncoding(
|
BytePairEncoding: model.NewBytePairEncoding(
|
||||||
c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||||
&model.Vocabulary{
|
&model.Vocabulary{
|
||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||||
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id", 1)),
|
|
||||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||||
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id", 2)),
|
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||||
// TODO: set EOT to EOS otherwise 0 will stop generation
|
EOS: append(
|
||||||
EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||||
AddEOT: c.Bool("tokenizer.ggml.add_eos_token", false),
|
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||||
|
),
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
|
TextModel: textModel,
|
||||||
|
VisionModel: newVisionModel(c),
|
||||||
|
ImageProcessor: newImageProcessor(c),
|
||||||
|
MultiModalProjector: newMultiModalProjector(c),
|
||||||
}
|
}
|
||||||
|
|
||||||
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
|
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
|
||||||
|
@@ -38,13 +38,13 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||||
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
|
|
||||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||||
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||||
// TODO: set EOT to EOS otherwise 0 will stop generation
|
EOS: append(
|
||||||
EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||||
AddEOT: c.Bool("tokenizer.ggml.add_eos_token", false),
|
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||||
|
),
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
ImageProcessor: newImageProcessor(c),
|
ImageProcessor: newImageProcessor(c),
|
||||||
|
@@ -34,12 +34,13 @@ func New(c fs.Config) (model.Model, error) {
|
|||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||||
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
|
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||||
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
|
||||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||||
EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
EOS: append(
|
||||||
AddEOT: c.Bool("tokenizer.ggml.add_eos_token", false),
|
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||||
|
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||||
|
),
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
TextModel: NewTextModel(c),
|
TextModel: NewTextModel(c),
|
||||||
|
@@ -182,27 +182,12 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "string", s, "ids", ids)
|
||||||
|
|
||||||
if addSpecial && len(ids) > 0 {
|
if addSpecial && len(ids) > 0 {
|
||||||
if spm.vocab.AddBOS {
|
ids = spm.vocab.addSpecials(ids)
|
||||||
if ids[0] == spm.vocab.BOS {
|
|
||||||
slog.Warn("adding bos token to prompt which already has it", "id", spm.vocab.BOS)
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Debug("adding bos token to prompt", "id", spm.vocab.BOS)
|
|
||||||
ids = append([]int32{spm.vocab.BOS}, ids...)
|
|
||||||
}
|
|
||||||
|
|
||||||
if spm.vocab.AddEOS {
|
|
||||||
if ids[len(ids)-1] == spm.vocab.EOS {
|
|
||||||
slog.Warn("adding eos token to prompt which already has it", "id", spm.vocab.EOS)
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Debug("adding eos token to prompt", "id", spm.vocab.EOS)
|
|
||||||
ids = append(ids, spm.vocab.EOS)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "ids", ids)
|
|
||||||
return ids, nil
|
return ids, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -261,6 +246,6 @@ func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "string", sb.String())
|
slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "ids", ids, "string", sb.String())
|
||||||
return sb.String(), nil
|
return sb.String(), nil
|
||||||
}
|
}
|
17
model/textprocessor.go
Normal file
17
model/textprocessor.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
const (
|
||||||
|
TOKEN_TYPE_NORMAL = iota + 1
|
||||||
|
TOKEN_TYPE_UNKNOWN
|
||||||
|
TOKEN_TYPE_CONTROL
|
||||||
|
TOKEN_TYPE_USER_DEFINED
|
||||||
|
TOKEN_TYPE_UNUSED
|
||||||
|
TOKEN_TYPE_BYTE
|
||||||
|
)
|
||||||
|
|
||||||
|
type TextProcessor interface {
|
||||||
|
Encode(s string, addSpecial bool) ([]int32, error)
|
||||||
|
Decode([]int32) (string, error)
|
||||||
|
Is(int32, Special) bool
|
||||||
|
Vocabulary() *Vocabulary
|
||||||
|
}
|
112
model/vocabulary.go
Normal file
112
model/vocabulary.go
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"slices"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Special int32
|
||||||
|
|
||||||
|
const (
|
||||||
|
SpecialBOS Special = iota
|
||||||
|
SpecialEOS
|
||||||
|
)
|
||||||
|
|
||||||
|
type Vocabulary struct {
|
||||||
|
Values []string
|
||||||
|
Types []int32
|
||||||
|
Scores []float32
|
||||||
|
Merges []string
|
||||||
|
|
||||||
|
BOS, EOS []int32
|
||||||
|
AddBOS, AddEOS bool
|
||||||
|
|
||||||
|
specialOnce sync.Once
|
||||||
|
special []string
|
||||||
|
|
||||||
|
valuesOnce sync.Once
|
||||||
|
values map[string]int32
|
||||||
|
|
||||||
|
mergeOnce sync.Once
|
||||||
|
merge map[string]int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *Vocabulary) Is(id int32, special Special) bool {
|
||||||
|
switch special {
|
||||||
|
case SpecialBOS:
|
||||||
|
return slices.Contains(v.BOS, id)
|
||||||
|
case SpecialEOS:
|
||||||
|
return slices.Contains(v.EOS, id)
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *Vocabulary) addSpecials(ids []int32) []int32 {
|
||||||
|
if v.AddBOS && len(v.BOS) > 0 {
|
||||||
|
if slices.Contains(v.BOS, ids[0]) {
|
||||||
|
slog.Warn("adding bos token to prompt which already has it", "id", v.BOS)
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Debug("adding bos token to prompt", "id", v.BOS)
|
||||||
|
ids = append([]int32{v.BOS[0]}, ids...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if v.AddEOS && len(v.EOS) > 0 {
|
||||||
|
if slices.Contains(v.BOS, ids[len(ids)-1]) {
|
||||||
|
slog.Warn("adding eos token to prompt which already has it", "id", v.EOS)
|
||||||
|
}
|
||||||
|
|
||||||
|
slog.Debug("adding eos token to prompt", "id", v.EOS)
|
||||||
|
ids = append(ids, v.EOS[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
return ids
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *Vocabulary) Encode(s string) int32 {
|
||||||
|
v.valuesOnce.Do(func() {
|
||||||
|
v.values = make(map[string]int32, len(v.Values))
|
||||||
|
for i, value := range v.Values {
|
||||||
|
v.values[value] = int32(i)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if id, ok := v.values[s]; ok {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *Vocabulary) Decode(id int32) string {
|
||||||
|
return v.Values[id]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *Vocabulary) SpecialVocabulary() []string {
|
||||||
|
v.specialOnce.Do(func() {
|
||||||
|
for i := range v.Values {
|
||||||
|
if v.Types[i] == TOKEN_TYPE_CONTROL {
|
||||||
|
v.special = append(v.special, v.Values[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return v.special
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *Vocabulary) Merge(left, right string) int {
|
||||||
|
v.mergeOnce.Do(func() {
|
||||||
|
v.merge = make(map[string]int32, len(v.Merges))
|
||||||
|
for i, merge := range v.Merges {
|
||||||
|
v.merge[merge] = int32(i)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if id, ok := v.merge[left+" "+right]; ok {
|
||||||
|
return int(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
return -1
|
||||||
|
}
|
@@ -176,7 +176,7 @@ func NewGrammarSampler(model model.TextProcessor, grammarStr string) (*GrammarSa
|
|||||||
vocabIds[i] = uint32(i)
|
vocabIds[i] = uint32(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, []uint32{uint32(model.Vocabulary().EOS), uint32(model.Vocabulary().EOT)})
|
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, model.Vocabulary().EOS)
|
||||||
if grammar == nil {
|
if grammar == nil {
|
||||||
return nil, errors.New("sample: failed to initialize grammar")
|
return nil, errors.New("sample: failed to initialize grammar")
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user