auth: fix problems with the ollama keypairs (#12373)

* auth: fix problems with the ollama keypairs

This change adds several fixes including:
  - reading in the pubkey files correctly
  - fixing the push unit test to create a keypair file in a temp directory
  - not return 500 errors for normal status error
This commit is contained in:
Patrick Devine
2025-09-22 23:20:20 -07:00
committed by GitHub
parent 41efdd4048
commit 64883e3c4c
6 changed files with 119 additions and 102 deletions

View File

@@ -45,6 +45,12 @@ func checkError(resp *http.Response, body []byte) error {
return nil
}
if resp.StatusCode == http.StatusUnauthorized {
authError := AuthorizationError{StatusCode: resp.StatusCode}
json.Unmarshal(body, &authError)
return authError
}
apiError := StatusError{StatusCode: resp.StatusCode}
err := json.Unmarshal(body, &apiError)
@@ -214,7 +220,8 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
scanner.Buffer(scanBuf, maxBufferSize)
for scanner.Scan() {
var errorResponse struct {
Error string `json:"error,omitempty"`
Error string `json:"error,omitempty"`
SigninURL string `json:"signin_url,omitempty"`
}
bts := scanner.Bytes()
@@ -223,14 +230,10 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
}
if response.StatusCode == http.StatusUnauthorized {
pubKey, pkErr := auth.GetPublicKey()
if pkErr != nil {
return pkErr
}
return AuthorizationError{
StatusCode: response.StatusCode,
Status: response.Status,
PublicKey: pubKey,
SigninURL: errorResponse.SigninURL,
}
} else if response.StatusCode >= http.StatusBadRequest {
return StatusError{
@@ -439,8 +442,13 @@ func (c *Client) Version(ctx context.Context) (string, error) {
return version.Version, nil
}
// Signout will disconnect an ollama instance from ollama.com
func (c *Client) Signout(ctx context.Context, encodedKey string) error {
// Signout will signout a client for a local ollama server.
func (c *Client) Signout(ctx context.Context) error {
return c.do(ctx, http.MethodPost, "/api/signout", nil, nil)
}
// Disconnect will disconnect an ollama instance from ollama.com.
func (c *Client) Disconnect(ctx context.Context, encodedKey string) error {
return c.do(ctx, http.MethodDelete, fmt.Sprintf("/api/user/keys/%s", encodedKey), nil, nil)
}

View File

@@ -41,7 +41,7 @@ func (e StatusError) Error() string {
type AuthorizationError struct {
StatusCode int
Status string
PublicKey string `json:"public_key"`
SigninURL string `json:"signin_url"`
}
func (e AuthorizationError) Error() string {

View File

@@ -18,46 +18,13 @@ import (
const defaultPrivateKey = "id_ed25519"
func keyPath() (string, error) {
fileIsReadable := func(fp string) bool {
info, err := os.Stat(fp)
if err != nil {
return false
}
// Check that it's a regular file, not a directory or other file type
if !info.Mode().IsRegular() {
return false
}
// Try to open it to check readability
file, err := os.Open(fp)
if err != nil {
return false
}
file.Close()
return true
}
systemPath := filepath.Join("/usr/share/ollama/.ollama", defaultPrivateKey)
if fileIsReadable(systemPath) {
return systemPath, nil
}
func GetPublicKey() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
}
func GetPublicKey() (string, error) {
keyPath, err := keyPath()
if err != nil {
return "", err
}
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
privateKeyFile, err := os.ReadFile(keyPath)
if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
@@ -84,11 +51,12 @@ func NewNonce(r io.Reader, length int) (string, error) {
}
func Sign(ctx context.Context, bts []byte) (string, error) {
keyPath, err := keyPath()
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
privateKeyFile, err := os.ReadFile(keyPath)
if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))

View File

@@ -5,7 +5,6 @@ import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/base64"
"encoding/json"
"encoding/pem"
"errors"
@@ -15,7 +14,6 @@ import (
"math"
"net"
"net/http"
"net/url"
"os"
"os/signal"
"path/filepath"
@@ -37,7 +35,6 @@ import (
"golang.org/x/term"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser"
@@ -50,7 +47,7 @@ import (
"github.com/ollama/ollama/version"
)
const ConnectInstructions = "To sign in, navigate to:\n https://ollama.com/connect?name=%s&key=%s\n\n"
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
// ensureThinkingSupport emits a warning if the model does not advertise thinking support
func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
@@ -452,16 +449,10 @@ func RunHandler(cmd *cobra.Command, args []string) error {
if err := loadOrUnloadModel(cmd, &opts); err != nil {
var sErr api.AuthorizationError
if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
pubKey, pkErr := auth.GetPublicKey()
if pkErr != nil {
return pkErr
}
// the server and the client both have the same public key
if pubKey == sErr.PublicKey {
h, _ := os.Hostname()
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n")
fmt.Printf(ConnectInstructions, url.PathEscape(h), encKey)
fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n")
if sErr.SigninURL != "" {
fmt.Printf(ConnectInstructions, sErr.SigninURL)
}
return nil
}
@@ -493,6 +484,16 @@ func SigninHandler(cmd *cobra.Command, args []string) error {
user, err := client.Whoami(cmd.Context())
if err != nil {
var aErr api.AuthorizationError
if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized {
fmt.Println("You need to be signed in to Ollama to run Cloud models.")
fmt.Println()
if aErr.SigninURL != "" {
fmt.Printf(ConnectInstructions, aErr.SigninURL)
}
return nil
}
return err
}
@@ -502,34 +503,27 @@ func SigninHandler(cmd *cobra.Command, args []string) error {
return nil
}
pubKey, pkErr := auth.GetPublicKey()
if pkErr != nil {
return pkErr
}
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
h, _ := os.Hostname()
fmt.Printf(ConnectInstructions, url.PathEscape(h), encKey)
return nil
}
func SignoutHandler(cmd *cobra.Command, args []string) error {
pubKey, pkErr := auth.GetPublicKey()
if pkErr != nil {
return pkErr
}
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
err = client.Signout(cmd.Context(), encKey)
err = client.Signout(cmd.Context())
if err != nil {
return err
var aErr api.AuthorizationError
if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized {
fmt.Println("You are not signed in to ollama.com")
fmt.Println()
return nil
} else {
return err
}
}
fmt.Println("You have signed out of ollama.com")
fmt.Println()
return nil

View File

@@ -525,6 +525,9 @@ func TestPushHandler(t *testing.T) {
defer mockServer.Close()
t.Setenv("OLLAMA_HOST", mockServer.URL)
tmpDir := t.TempDir()
t.Setenv("HOME", tmpDir)
t.Setenv("USERPROFILE", tmpDir)
initializeKeypair()
cmd := &cobra.Command{}

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"cmp"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
@@ -48,6 +49,8 @@ import (
"github.com/ollama/ollama/version"
)
const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
func shouldUseHarmony(model *Model) bool {
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
// heuristic to check whether the template expects to be parsed via harmony:
@@ -150,6 +153,17 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
return runner.llama, model, &opts, nil
}
func signinURL() (string, error) {
pubKey, err := auth.GetPublicKey()
if err != nil {
return "", err
}
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
h, _ := os.Hostname()
return fmt.Sprintf(signinURLStr, url.PathEscape(h), encKey), nil
}
func (s *Server) GenerateHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.GenerateRequest
@@ -250,18 +264,21 @@ func (s *Server) GenerateHandler(c *gin.Context) {
client := api.NewClient(remoteURL, http.DefaultClient)
err = client.Generate(c, &req, fn)
if err != nil {
var sErr api.AuthorizationError
if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
pk, pkErr := auth.GetPublicKey()
if pkErr != nil {
slog.Error("couldn't get public key", "error", pkErr)
c.JSON(http.StatusUnauthorized, gin.H{"error": "error getting public key"})
var authError api.AuthorizationError
if errors.As(err, &authError) {
sURL, sErr := signinURL()
if sErr != nil {
slog.Error(sErr.Error())
c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"})
return
}
c.JSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"public_key": pk,
})
c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "signin_url": sURL})
return
}
var apiError api.StatusError
if errors.As(err, &apiError) {
c.JSON(apiError.StatusCode, apiError)
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -1412,9 +1429,12 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.POST("/api/show", s.ShowHandler)
r.DELETE("/api/delete", s.DeleteHandler)
r.DELETE("/api/user/keys/:encodedKey", s.SignoutHandler)
r.POST("/api/me", s.WhoamiHandler)
r.POST("/api/signout", s.SignoutHandler)
// deprecated
r.DELETE("/api/user/keys/:encodedKey", s.SignoutHandler)
// Create
r.POST("/api/create", s.CreateHandler)
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
@@ -1625,11 +1645,32 @@ func (s *Server) WhoamiHandler(c *gin.Context) {
if err != nil {
slog.Error(err.Error())
}
// user isn't signed in
if user != nil && user.Name == "" {
sURL, sErr := signinURL()
if sErr != nil {
slog.Error(sErr.Error())
c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"})
return
}
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized", "signin_url": sURL})
return
}
c.JSON(http.StatusOK, user)
}
func (s *Server) SignoutHandler(c *gin.Context) {
encodedKey := c.Param("encodedKey")
pubKey, err := auth.GetPublicKey()
if err != nil {
slog.Error("couldn't get public key", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"})
return
}
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
// todo allow other hosts
u, err := url.Parse("https://ollama.com")
@@ -1640,11 +1681,11 @@ func (s *Server) SignoutHandler(c *gin.Context) {
}
client := api.NewClient(u, http.DefaultClient)
err = client.Signout(c, encodedKey)
err = client.Disconnect(c, encKey)
if err != nil {
slog.Error(err.Error())
if strings.Contains(err.Error(), "page not found") || strings.Contains(err.Error(), "invalid credentials") {
c.JSON(http.StatusNotFound, gin.H{"error": "you are not currently signed in"})
var authError api.AuthorizationError
if errors.As(err, &authError) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "you are not currently signed in"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"})
@@ -1802,18 +1843,21 @@ func (s *Server) ChatHandler(c *gin.Context) {
client := api.NewClient(remoteURL, http.DefaultClient)
err = client.Chat(c, &req, fn)
if err != nil {
var sErr api.AuthorizationError
if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
pk, pkErr := auth.GetPublicKey()
if pkErr != nil {
slog.Error("couldn't get public key", "error", pkErr)
c.JSON(http.StatusUnauthorized, gin.H{"error": "error getting public key"})
var authError api.AuthorizationError
if errors.As(err, &authError) {
sURL, sErr := signinURL()
if sErr != nil {
slog.Error(sErr.Error())
c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"})
return
}
c.JSON(http.StatusUnauthorized, gin.H{
"error": "unauthorized",
"public_key": pk,
})
c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "signin_url": sURL})
return
}
var apiError api.StatusError
if errors.As(err, &apiError) {
c.JSON(apiError.StatusCode, apiError)
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})