ollama/integration/concurrency_test.go
Daniel Hiltgen d6f7233a1c test: improve scheduler/concurrency stress tests (#11906)
* test: improve scheduler/concurrency stress tests

The scheduler test previously relied on approximate memory figures and would
often overshoot or undershoot a system's capacity, leading to flaky test results.
This should improve the reliability of this scenario by leveraging
ps output to determine exactly how many models it takes to
trigger thrashing.
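
In essence, the test now loads candidate models one at a time and, after each
load, asks the server which models are still resident (the same information
ollama ps reports); the first load that pushes the resident count below the
number requested marks the capacity. A simplified sketch of that loop, pulled
out of TestMultiModelStress below (findCapacity is a hypothetical helper name):

func findCapacity(ctx context.Context, client *api.Client, candidates []string) (int, error) {
	loaded := 0
	for i, model := range candidates {
		// Load one more model with an empty generate request.
		err := client.Generate(ctx, &api.GenerateRequest{Model: model},
			func(api.GenerateResponse) error { return nil })
		if err != nil {
			return 0, err
		}
		loaded++
		if i == 0 {
			continue
		}
		// Ask the server what is actually resident, i.e. what `ollama ps` shows.
		running, err := client.ListRunning(ctx)
		if err != nil {
			return 0, err
		}
		if len(running.Models) < loaded {
			// The latest load evicted something, so this count is enough to thrash.
			return loaded, nil
		}
	}
	return loaded, nil
}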

The concurrency test is also refined to target num_parallel + 1 and handle
timeouts better.

With these refinements, TestMultiModelConcurrency became redundant.

* test: add parallel generate with history

TestGenerateWithHistory will help verify that caching and context
are handled properly while making requests.
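
The new test itself lives in a separate file; as a rough illustration of the
idea, each follow-up turn is sent with the accumulated conversation so the
server must reuse its cached context. The helper name, prompts, use of the
chat endpoint, and the strings import are illustrative assumptions here,
not the actual test:

func generateWithHistory(ctx context.Context, client *api.Client, model string) error {
	history := []api.Message{
		{Role: "user", Content: "Why is the sky blue?"},
	}
	for turn := 0; turn < 3; turn++ {
		var reply strings.Builder
		err := client.Chat(ctx, &api.ChatRequest{Model: model, Messages: history},
			func(r api.ChatResponse) error {
				reply.WriteString(r.Message.Content)
				return nil
			})
		if err != nil {
			return err
		}
		// Feed the answer back in so the next turn depends on the accumulated context.
		history = append(history,
			api.Message{Role: "assistant", Content: reply.String()},
			api.Message{Role: "user", Content: "Tell me more about that."},
		)
	}
	return nil
}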

* test: focus embed tests on embedding models

remove non-embedding models from the embedding tests
2025-08-15 14:37:54 -07:00


//go:build integration

package integration

import (
	"context"
	"fmt"
	"log/slog"
	"math"
	"math/rand"
	"os"
	"strconv"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/require"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/format"
)
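
// Note: shared helpers referenced below (GenerateRequests, DoGenerate, InitServerConnection,
// PullIfMissing, getTimeouts) and the package-level started timestamp are defined elsewhere
// in this integration package rather than in this file.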
// Send multiple requests in parallel (concurrently) to a single model and ensure the responses are as expected
func TestConcurrentGenerate(t *testing.T) {
// Assumes all requests have the same model
req, resp := GenerateRequests()
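// Target one more worker than the server's configured parallel slots so that
// at least one request is always queued behind the others.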
numParallel := int(envconfig.NumParallel() + 1)
iterLimit := 3
softTimeout, hardTimeout := getTimeouts(t)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Get the server running (if applicable) and warm the model up with a single initial request
slog.Info("loading", "model", req[0].Model)
err := client.Generate(ctx,
&api.GenerateRequest{Model: req[0].Model, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
func(response api.GenerateResponse) error { return nil },
)
if err != nil {
t.Fatalf("failed to load model %s: %s", req[0].Model, err)
}
var wg sync.WaitGroup
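// Fixed seed keeps the request selection deterministic across runs.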
r := rand.New(rand.NewSource(0))
wg.Add(numParallel)
for i := range numParallel {
go func(i int) {
defer wg.Done()
for j := 0; j < iterLimit; j++ {
if time.Now().Sub(started) > softTimeout {
slog.Info("exceeded soft timeout, winding down test")
return
}
k := r.Int() % len(req)
slog.Info("Starting", "thread", i, "iter", j)
// On slower GPUs it can take a while to process the concurrent requests
// so we allow a much longer initial timeout
DoGenerate(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
}
}(i)
}
wg.Wait()
}
// Stress the scheduler and attempt to load more models than will fit to cause thrashing
// This test will always load at least 2 models even on CPU based systems
func TestMultiModelStress(t *testing.T) {
s := os.Getenv("OLLAMA_MAX_VRAM")
if s == "" {
s = "0"
}
maxVram, err := strconv.ParseUint(s, 10, 64)
if err != nil {
t.Fatal(err)
}
smallModels := []string{
"llama3.2:1b",
"qwen3:0.6b",
"gemma:2b",
"deepseek-r1:1.5b",
"starcoder2:3b",
}
mediumModels := []string{
"qwen3:8b",
"llama2",
"deepseek-r1:7b",
"mistral",
"dolphin-mistral",
"gemma:7b",
"codellama:7b",
}
var chosenModels []string
switch {
case maxVram < 10000*format.MebiByte:
slog.Info("selecting small models")
chosenModels = smallModels
default:
slog.Info("selecting medium models")
chosenModels = mediumModels
}
softTimeout, hardTimeout := getTimeouts(t)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Make sure all the models are pulled before we get started
for _, model := range chosenModels {
require.NoError(t, PullIfMissing(ctx, client, model))
}
// Determine how many models we can load in parallel before we exceed VRAM
// The intent is to go 1 over what can fit so we force the scheduler to thrash
targetLoadCount := 0
slog.Info("Loading models to find how many can fit in VRAM before overflowing")
for i, model := range chosenModels {
req := &api.GenerateRequest{Model: model}
slog.Info("loading", "model", model)
err = client.Generate(ctx, req, func(response api.GenerateResponse) error { return nil })
if err != nil {
t.Fatalf("failed to load model %s: %s", model, err)
}
targetLoadCount++
if i > 0 {
models, err := client.ListRunning(ctx)
if err != nil {
t.Fatalf("failed to list running models: %s", err)
}
if len(models.Models) < targetLoadCount {
loaded := []string{}
for _, m := range models.Models {
loaded = append(loaded, m.Name)
}
slog.Info("found model load capacity", "target", targetLoadCount, "current", loaded, "chosen", chosenModels[:targetLoadCount])
break
}
}
}
if targetLoadCount == len(chosenModels) {
// TODO consider retrying the medium models
slog.Warn("all models being used without exceeding VRAM, set OLLAMA_MAX_VRAM so test can pick larger models")
}
r := rand.New(rand.NewSource(0))
var wg sync.WaitGroup
for i := range targetLoadCount {
wg.Add(1)
go func(i int) {
defer wg.Done()
reqs, resps := GenerateRequests()
for j := 0; j < 3; j++ {
if time.Now().Sub(started) > softTimeout {
slog.Info("exceeded soft timeout, winding down test")
return
}
k := r.Int() % len(reqs)
reqs[k].Model = chosenModels[i]
slog.Info("Starting", "model", reqs[k].Model, "iteration", j, "request", reqs[k].Prompt)
DoGenerate(ctx, t, client, reqs[k], resps[k],
120*time.Second, // Be extra patient for the model to load initially
10*time.Second, // Once results start streaming, fail if they stall
)
}
}(i)
}
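// While the workers run, periodically log which models are resident and how each
// splits between CPU and GPU memory, until the test context is cancelled.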
go func() {
for {
time.Sleep(10 * time.Second)
select {
case <-ctx.Done():
return
default:
models, err := client.ListRunning(ctx)
if err != nil {
slog.Warn("failed to list running models", "error", err)
continue
}
for _, m := range models.Models {
var procStr string
switch {
case m.SizeVRAM == 0:
procStr = "100% CPU"
case m.SizeVRAM == m.Size:
procStr = "100% GPU"
case m.SizeVRAM > m.Size || m.Size == 0:
procStr = "Unknown"
default:
sizeCPU := m.Size - m.SizeVRAM
cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100)
procStr = fmt.Sprintf("%d%%/%d%%", int(cpuPercent), int(100-cpuPercent))
}
slog.Info("loaded model snapshot", "model", m.Name, "CPU/GPU", procStr, "expires", format.HumanTime(m.ExpiresAt, "Never"))
}
}
}
}()
wg.Wait()
}