mirror of
https://github.com/jmorganca/ollama
synced 2025-10-06 00:32:49 +02:00
tests: reduce stress on CPU to 2 models (#12161)
* tests: reduce stress on CPU to 2 models This should avoid flakes due to systems getting overloaded with 3 (or more) models running concurrently * tests: allow slow systems to pass on timeout If a slow system is still streaming a response, and the response will pass validation, don't fail just because the system is slow. * test: unload embedding models more quickly
This commit is contained in:
@@ -121,6 +121,7 @@ func TestMultiModelStress(t *testing.T) {
|
||||
// The intent is to go 1 over what can fit so we force the scheduler to thrash
|
||||
targetLoadCount := 0
|
||||
slog.Info("Loading models to find how many can fit in VRAM before overflowing")
|
||||
chooseModels:
|
||||
for i, model := range chosenModels {
|
||||
req := &api.GenerateRequest{Model: model}
|
||||
slog.Info("loading", "model", model)
|
||||
@@ -142,6 +143,13 @@ func TestMultiModelStress(t *testing.T) {
|
||||
slog.Info("found model load capacity", "target", targetLoadCount, "current", loaded, "chosen", chosenModels[:targetLoadCount])
|
||||
break
|
||||
}
|
||||
// Effectively limit model count to 2 on CPU only systems to avoid thrashing and timeouts
|
||||
for _, m := range models.Models {
|
||||
if m.SizeVRAM == 0 {
|
||||
slog.Info("model running on CPU", "name", m.Name, "target", targetLoadCount, "chosen", chosenModels[:targetLoadCount])
|
||||
break chooseModels
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if targetLoadCount == len(chosenModels) {
|
||||
|
@@ -38,8 +38,9 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
req := api.EmbeddingRequest{
|
||||
Model: "all-minilm",
|
||||
Prompt: "why is the sky blue?",
|
||||
Model: "all-minilm",
|
||||
Prompt: "why is the sky blue?",
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
}
|
||||
|
||||
res, err := embeddingTestHelper(ctx, client, t, req)
|
||||
|
@@ -502,6 +502,22 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
|
||||
done <- 0
|
||||
}()
|
||||
|
||||
var response string
|
||||
verify := func() {
|
||||
// Verify the response contains the expected data
|
||||
response = buf.String()
|
||||
atLeastOne := false
|
||||
for _, resp := range anyResp {
|
||||
if strings.Contains(strings.ToLower(response), resp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !atLeastOne {
|
||||
t.Fatalf("%s: none of %v found in %s", genReq.Model, anyResp, response)
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-stallTimer.C:
|
||||
if buf.Len() == 0 {
|
||||
@@ -517,21 +533,14 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
|
||||
if genErr != nil {
|
||||
t.Fatalf("%s failed with %s request prompt %s", genErr, genReq.Model, genReq.Prompt)
|
||||
}
|
||||
// Verify the response contains the expected data
|
||||
response := buf.String()
|
||||
atLeastOne := false
|
||||
for _, resp := range anyResp {
|
||||
if strings.Contains(strings.ToLower(response), resp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !atLeastOne {
|
||||
t.Fatalf("%s: none of %v found in %s", genReq.Model, anyResp, response)
|
||||
}
|
||||
verify()
|
||||
slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for generate")
|
||||
// On slow systems, we might timeout before some models finish rambling, so check what we have so far to see
|
||||
// if it's considered a pass - the stallTimer will detect hangs, but we want to consider slow systems a pass
|
||||
// if they are still generating valid responses
|
||||
slog.Warn("outer test context done while waiting for generate")
|
||||
verify()
|
||||
}
|
||||
return context
|
||||
}
|
||||
@@ -599,6 +608,22 @@ func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatR
|
||||
done <- 0
|
||||
}()
|
||||
|
||||
var response string
|
||||
verify := func() {
|
||||
// Verify the response contains the expected data
|
||||
response = buf.String()
|
||||
atLeastOne := false
|
||||
for _, resp := range anyResp {
|
||||
if strings.Contains(strings.ToLower(response), resp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !atLeastOne {
|
||||
t.Fatalf("%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages)
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-stallTimer.C:
|
||||
if buf.Len() == 0 {
|
||||
@@ -614,23 +639,14 @@ func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatR
|
||||
if genErr != nil {
|
||||
t.Fatalf("%s failed with %s request prompt %v", genErr, req.Model, req.Messages)
|
||||
}
|
||||
|
||||
// Verify the response contains the expected data
|
||||
response := buf.String()
|
||||
atLeastOne := false
|
||||
for _, resp := range anyResp {
|
||||
if strings.Contains(strings.ToLower(response), resp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !atLeastOne {
|
||||
t.Fatalf("%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages)
|
||||
}
|
||||
|
||||
verify()
|
||||
slog.Info("test pass", "model", req.Model, "messages", req.Messages, "contains", anyResp, "response", response)
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for generate")
|
||||
// On slow systems, we might timeout before some models finish rambling, so check what we have so far to see
|
||||
// if it's considered a pass - the stallTimer will detect hangs, but we want to consider slow systems a pass
|
||||
// if they are still generating valid responses
|
||||
slog.Warn("outer test context done while waiting for chat")
|
||||
verify()
|
||||
}
|
||||
return &api.Message{Role: role, Content: buf.String()}
|
||||
}
|
||||
|
Reference in New Issue
Block a user