mirror of https://github.com/jmorganca/ollama, synced 2025-10-06 00:32:49 +02:00
* test: improve scheduler/concurrency stress tests
  The scheduler test used to use approximate memory figures and would often over- or undershoot a system's capacity, leading to flaky test results. This should improve the reliability of this scenario by leveraging ps output to determine exactly how many models it takes to trigger thrashing. The concurrency test is also refined to target num_parallel + 1 and handle timeouts better. With these refinements, TestMultiModelConcurrency was redundant.
* test: add parallel generate with history
  TestGenerateWithHistory will help verify caching and context are properly handled while making requests.
* test: focus embed tests on embedding models
  Remove non-embedding models from the embedding tests.
208 lines
5.7 KiB
Go
//go:build integration

package integration

import (
	"context"
	"fmt"
	"log/slog"
	"math"
	"math/rand"
	"os"
	"strconv"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/require"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/format"
)

// Send multiple requests in parallel (concurrently) to a single model and ensure responses are expected
func TestConcurrentGenerate(t *testing.T) {
	// Assumes all requests have the same model
	req, resp := GenerateRequests()
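	// Issue one more concurrent request than the server's configured parallel slots
	// so at least one request always has to queue behind the others.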
	numParallel := int(envconfig.NumParallel() + 1)
	iterLimit := 3

	softTimeout, hardTimeout := getTimeouts(t)
	ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
	defer cancel()
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	// Get the server running (if applicable) and warm the model up with a single initial request
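	// A request with no prompt only loads the model; the short KeepAlive lets it be
	// evicted soon after the test finishes.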
slog.Info("loading", "model", req[0].Model)
|
|
err := client.Generate(ctx,
|
|
&api.GenerateRequest{Model: req[0].Model, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
|
|
func(response api.GenerateResponse) error { return nil },
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("failed to load model %s: %s", req[0].Model, err)
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
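	// A fixed seed keeps the mix of prompts deterministic from run to run.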
	r := rand.New(rand.NewSource(0))
	wg.Add(numParallel)
	for i := range numParallel {
		go func(i int) {
			defer wg.Done()
			for j := 0; j < iterLimit; j++ {
				if time.Since(started) > softTimeout {
					slog.Info("exceeded soft timeout, winding down test")
					return
				}
				k := r.Int() % len(req)
				slog.Info("Starting", "thread", i, "iter", j)
				// On slower GPUs it can take a while to process the concurrent requests
				// so we allow a much longer initial timeout
				DoGenerate(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
			}
		}(i)
	}
	wg.Wait()
}
// Stress the scheduler and attempt to load more models than will fit to cause thrashing
// This test will always load at least 2 models even on CPU based systems
func TestMultiModelStress(t *testing.T) {
	s := os.Getenv("OLLAMA_MAX_VRAM")
	if s == "" {
		s = "0"
	}

	maxVram, err := strconv.ParseUint(s, 10, 64)
	if err != nil {
		t.Fatal(err)
	}

	smallModels := []string{
		"llama3.2:1b",
		"qwen3:0.6b",
		"gemma:2b",
		"deepseek-r1:1.5b",
		"starcoder2:3b",
	}
	mediumModels := []string{
		"qwen3:8b",
		"llama2",
		"deepseek-r1:7b",
		"mistral",
		"dolphin-mistral",
		"gemma:7b",
		"codellama:7b",
	}

	var chosenModels []string
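	// When OLLAMA_MAX_VRAM is unset (maxVram == 0) or below roughly 10 GiB, stick to
	// the small models; otherwise use the medium set so larger GPUs can still overflow.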
	switch {
	case maxVram < 10000*format.MebiByte:
		slog.Info("selecting small models")
		chosenModels = smallModels
	default:
		slog.Info("selecting medium models")
		chosenModels = mediumModels
	}

	softTimeout, hardTimeout := getTimeouts(t)
	ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
	defer cancel()
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	// Make sure all the models are pulled before we get started
	for _, model := range chosenModels {
		require.NoError(t, PullIfMissing(ctx, client, model))
	}

	// Determine how many models we can load in parallel before we exceed VRAM
	// The intent is to go 1 over what can fit so we force the scheduler to thrash
	targetLoadCount := 0
	slog.Info("Loading models to find how many can fit in VRAM before overflowing")
	for i, model := range chosenModels {
		req := &api.GenerateRequest{Model: model}
		slog.Info("loading", "model", model)
		err = client.Generate(ctx, req, func(response api.GenerateResponse) error { return nil })
		if err != nil {
			t.Fatalf("failed to load model %s: %s", model, err)
		}
		targetLoadCount++
		if i > 0 {
			models, err := client.ListRunning(ctx)
			if err != nil {
				t.Fatalf("failed to list running models: %s", err)
			}
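			// If fewer models are still running than we have requested so far, the
			// scheduler had to evict one to make room, so we have just gone one past
			// what fits in VRAM.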
			if len(models.Models) < targetLoadCount {
				loaded := []string{}
				for _, m := range models.Models {
					loaded = append(loaded, m.Name)
				}
				slog.Info("found model load capacity", "target", targetLoadCount, "current", loaded, "chosen", chosenModels[:targetLoadCount])
				break
			}
		}
	}
	if targetLoadCount == len(chosenModels) {
		// TODO consider retrying the medium models
		slog.Warn("all models loaded without exceeding VRAM, set OLLAMA_MAX_VRAM so the test can pick larger models")
	}

	r := rand.New(rand.NewSource(0))
	var wg sync.WaitGroup
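	// Run one worker per model slot; targetLoadCount is (normally) one more model than
	// fits in VRAM, so the workers force the scheduler to keep evicting and reloading.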
	for i := range targetLoadCount {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			reqs, resps := GenerateRequests()
			for j := 0; j < 3; j++ {
				if time.Since(started) > softTimeout {
					slog.Info("exceeded soft timeout, winding down test")
					return
				}
				k := r.Int() % len(reqs)
				reqs[k].Model = chosenModels[i]
				slog.Info("Starting", "model", reqs[k].Model, "iteration", j, "request", reqs[k].Prompt)
				DoGenerate(ctx, t, client, reqs[k], resps[k],
					120*time.Second, // Be extra patient for the model to load initially
					10*time.Second,  // Once results start streaming, fail if they stall
				)
			}
		}(i)
	}
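	// While the workers run, periodically log a ps-style snapshot of the loaded models
	// and their CPU/GPU split so thrashing is visible in the test output.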
	go func() {
		for {
			time.Sleep(10 * time.Second)
			select {
			case <-ctx.Done():
				return
			default:
				models, err := client.ListRunning(ctx)
				if err != nil {
					slog.Warn("failed to list running models", "error", err)
					continue
				}
				for _, m := range models.Models {
					var procStr string
					switch {
					case m.SizeVRAM == 0:
						procStr = "100% CPU"
					case m.SizeVRAM == m.Size:
						procStr = "100% GPU"
					case m.SizeVRAM > m.Size || m.Size == 0:
						procStr = "Unknown"
					default:
						sizeCPU := m.Size - m.SizeVRAM
						cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100)
						procStr = fmt.Sprintf("%d%%/%d%%", int(cpuPercent), int(100-cpuPercent))
					}

					slog.Info("loaded model snapshot", "model", m.Name, "CPU/GPU", procStr, "expires", format.HumanTime(m.ExpiresAt, "Never"))
				}
			}
		}
	}()
	wg.Wait()
}