mirror of
https://github.com/jmorganca/ollama
synced 2025-10-06 00:32:49 +02:00
auth: fix problems with the ollama keypairs (#12373)
* auth: fix problems with the ollama keypairs This change adds several fixes including: - reading in the pubkey files correctly - fixing the push unit test to create a keypair file in a temp directory - not return 500 errors for normal status error
This commit is contained in:
@@ -45,6 +45,12 @@ func checkError(resp *http.Response, body []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized {
|
||||
authError := AuthorizationError{StatusCode: resp.StatusCode}
|
||||
json.Unmarshal(body, &authError)
|
||||
return authError
|
||||
}
|
||||
|
||||
apiError := StatusError{StatusCode: resp.StatusCode}
|
||||
|
||||
err := json.Unmarshal(body, &apiError)
|
||||
@@ -214,7 +220,8 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||
scanner.Buffer(scanBuf, maxBufferSize)
|
||||
for scanner.Scan() {
|
||||
var errorResponse struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
SigninURL string `json:"signin_url,omitempty"`
|
||||
}
|
||||
|
||||
bts := scanner.Bytes()
|
||||
@@ -223,14 +230,10 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||
}
|
||||
|
||||
if response.StatusCode == http.StatusUnauthorized {
|
||||
pubKey, pkErr := auth.GetPublicKey()
|
||||
if pkErr != nil {
|
||||
return pkErr
|
||||
}
|
||||
return AuthorizationError{
|
||||
StatusCode: response.StatusCode,
|
||||
Status: response.Status,
|
||||
PublicKey: pubKey,
|
||||
SigninURL: errorResponse.SigninURL,
|
||||
}
|
||||
} else if response.StatusCode >= http.StatusBadRequest {
|
||||
return StatusError{
|
||||
@@ -439,8 +442,13 @@ func (c *Client) Version(ctx context.Context) (string, error) {
|
||||
return version.Version, nil
|
||||
}
|
||||
|
||||
// Signout will disconnect an ollama instance from ollama.com
|
||||
func (c *Client) Signout(ctx context.Context, encodedKey string) error {
|
||||
// Signout will signout a client for a local ollama server.
|
||||
func (c *Client) Signout(ctx context.Context) error {
|
||||
return c.do(ctx, http.MethodPost, "/api/signout", nil, nil)
|
||||
}
|
||||
|
||||
// Disconnect will disconnect an ollama instance from ollama.com.
|
||||
func (c *Client) Disconnect(ctx context.Context, encodedKey string) error {
|
||||
return c.do(ctx, http.MethodDelete, fmt.Sprintf("/api/user/keys/%s", encodedKey), nil, nil)
|
||||
}
|
||||
|
||||
|
@@ -41,7 +41,7 @@ func (e StatusError) Error() string {
|
||||
type AuthorizationError struct {
|
||||
StatusCode int
|
||||
Status string
|
||||
PublicKey string `json:"public_key"`
|
||||
SigninURL string `json:"signin_url"`
|
||||
}
|
||||
|
||||
func (e AuthorizationError) Error() string {
|
||||
|
40
auth/auth.go
40
auth/auth.go
@@ -18,46 +18,13 @@ import (
|
||||
|
||||
const defaultPrivateKey = "id_ed25519"
|
||||
|
||||
func keyPath() (string, error) {
|
||||
fileIsReadable := func(fp string) bool {
|
||||
info, err := os.Stat(fp)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check that it's a regular file, not a directory or other file type
|
||||
if !info.Mode().IsRegular() {
|
||||
return false
|
||||
}
|
||||
|
||||
// Try to open it to check readability
|
||||
file, err := os.Open(fp)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
file.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
systemPath := filepath.Join("/usr/share/ollama/.ollama", defaultPrivateKey)
|
||||
if fileIsReadable(systemPath) {
|
||||
return systemPath, nil
|
||||
}
|
||||
|
||||
func GetPublicKey() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
|
||||
}
|
||||
|
||||
func GetPublicKey() (string, error) {
|
||||
keyPath, err := keyPath()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
|
||||
privateKeyFile, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||
@@ -84,11 +51,12 @@ func NewNonce(r io.Reader, length int) (string, error) {
|
||||
}
|
||||
|
||||
func Sign(ctx context.Context, bts []byte) (string, error) {
|
||||
keyPath, err := keyPath()
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
|
||||
privateKeyFile, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||
|
56
cmd/cmd.go
56
cmd/cmd.go
@@ -5,7 +5,6 @@ import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
@@ -15,7 +14,6 @@ import (
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
@@ -37,7 +35,6 @@ import (
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/parser"
|
||||
@@ -50,7 +47,7 @@ import (
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
const ConnectInstructions = "To sign in, navigate to:\n https://ollama.com/connect?name=%s&key=%s\n\n"
|
||||
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
|
||||
|
||||
// ensureThinkingSupport emits a warning if the model does not advertise thinking support
|
||||
func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
|
||||
@@ -452,16 +449,10 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
var sErr api.AuthorizationError
|
||||
if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
|
||||
pubKey, pkErr := auth.GetPublicKey()
|
||||
if pkErr != nil {
|
||||
return pkErr
|
||||
}
|
||||
// the server and the client both have the same public key
|
||||
if pubKey == sErr.PublicKey {
|
||||
h, _ := os.Hostname()
|
||||
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
|
||||
fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n")
|
||||
fmt.Printf(ConnectInstructions, url.PathEscape(h), encKey)
|
||||
fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n")
|
||||
|
||||
if sErr.SigninURL != "" {
|
||||
fmt.Printf(ConnectInstructions, sErr.SigninURL)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -493,6 +484,16 @@ func SigninHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
user, err := client.Whoami(cmd.Context())
|
||||
if err != nil {
|
||||
var aErr api.AuthorizationError
|
||||
if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized {
|
||||
fmt.Println("You need to be signed in to Ollama to run Cloud models.")
|
||||
fmt.Println()
|
||||
|
||||
if aErr.SigninURL != "" {
|
||||
fmt.Printf(ConnectInstructions, aErr.SigninURL)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -502,34 +503,27 @@ func SigninHandler(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
pubKey, pkErr := auth.GetPublicKey()
|
||||
if pkErr != nil {
|
||||
return pkErr
|
||||
}
|
||||
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
|
||||
|
||||
h, _ := os.Hostname()
|
||||
fmt.Printf(ConnectInstructions, url.PathEscape(h), encKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func SignoutHandler(cmd *cobra.Command, args []string) error {
|
||||
pubKey, pkErr := auth.GetPublicKey()
|
||||
if pkErr != nil {
|
||||
return pkErr
|
||||
}
|
||||
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = client.Signout(cmd.Context(), encKey)
|
||||
err = client.Signout(cmd.Context())
|
||||
if err != nil {
|
||||
return err
|
||||
var aErr api.AuthorizationError
|
||||
if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized {
|
||||
fmt.Println("You are not signed in to ollama.com")
|
||||
fmt.Println()
|
||||
return nil
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("You have signed out of ollama.com")
|
||||
fmt.Println()
|
||||
return nil
|
||||
|
@@ -525,6 +525,9 @@ func TestPushHandler(t *testing.T) {
|
||||
defer mockServer.Close()
|
||||
|
||||
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||
tmpDir := t.TempDir()
|
||||
t.Setenv("HOME", tmpDir)
|
||||
t.Setenv("USERPROFILE", tmpDir)
|
||||
initializeKeypair()
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
|
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -48,6 +49,8 @@ import (
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
|
||||
|
||||
func shouldUseHarmony(model *Model) bool {
|
||||
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
|
||||
// heuristic to check whether the template expects to be parsed via harmony:
|
||||
@@ -150,6 +153,17 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
|
||||
return runner.llama, model, &opts, nil
|
||||
}
|
||||
|
||||
func signinURL() (string, error) {
|
||||
pubKey, err := auth.GetPublicKey()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
|
||||
h, _ := os.Hostname()
|
||||
return fmt.Sprintf(signinURLStr, url.PathEscape(h), encKey), nil
|
||||
}
|
||||
|
||||
func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
checkpointStart := time.Now()
|
||||
var req api.GenerateRequest
|
||||
@@ -250,18 +264,21 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
client := api.NewClient(remoteURL, http.DefaultClient)
|
||||
err = client.Generate(c, &req, fn)
|
||||
if err != nil {
|
||||
var sErr api.AuthorizationError
|
||||
if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
|
||||
pk, pkErr := auth.GetPublicKey()
|
||||
if pkErr != nil {
|
||||
slog.Error("couldn't get public key", "error", pkErr)
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "error getting public key"})
|
||||
var authError api.AuthorizationError
|
||||
if errors.As(err, &authError) {
|
||||
sURL, sErr := signinURL()
|
||||
if sErr != nil {
|
||||
slog.Error(sErr.Error())
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "unauthorized",
|
||||
"public_key": pk,
|
||||
})
|
||||
|
||||
c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "signin_url": sURL})
|
||||
return
|
||||
}
|
||||
var apiError api.StatusError
|
||||
if errors.As(err, &apiError) {
|
||||
c.JSON(apiError.StatusCode, apiError)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -1412,9 +1429,12 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.POST("/api/show", s.ShowHandler)
|
||||
r.DELETE("/api/delete", s.DeleteHandler)
|
||||
|
||||
r.DELETE("/api/user/keys/:encodedKey", s.SignoutHandler)
|
||||
r.POST("/api/me", s.WhoamiHandler)
|
||||
|
||||
r.POST("/api/signout", s.SignoutHandler)
|
||||
// deprecated
|
||||
r.DELETE("/api/user/keys/:encodedKey", s.SignoutHandler)
|
||||
|
||||
// Create
|
||||
r.POST("/api/create", s.CreateHandler)
|
||||
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
|
||||
@@ -1625,11 +1645,32 @@ func (s *Server) WhoamiHandler(c *gin.Context) {
|
||||
if err != nil {
|
||||
slog.Error(err.Error())
|
||||
}
|
||||
|
||||
// user isn't signed in
|
||||
if user != nil && user.Name == "" {
|
||||
sURL, sErr := signinURL()
|
||||
if sErr != nil {
|
||||
slog.Error(sErr.Error())
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized", "signin_url": sURL})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, user)
|
||||
}
|
||||
|
||||
func (s *Server) SignoutHandler(c *gin.Context) {
|
||||
encodedKey := c.Param("encodedKey")
|
||||
pubKey, err := auth.GetPublicKey()
|
||||
if err != nil {
|
||||
slog.Error("couldn't get public key", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"})
|
||||
return
|
||||
}
|
||||
|
||||
encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
|
||||
|
||||
// todo allow other hosts
|
||||
u, err := url.Parse("https://ollama.com")
|
||||
@@ -1640,11 +1681,11 @@ func (s *Server) SignoutHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
client := api.NewClient(u, http.DefaultClient)
|
||||
err = client.Signout(c, encodedKey)
|
||||
err = client.Disconnect(c, encKey)
|
||||
if err != nil {
|
||||
slog.Error(err.Error())
|
||||
if strings.Contains(err.Error(), "page not found") || strings.Contains(err.Error(), "invalid credentials") {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "you are not currently signed in"})
|
||||
var authError api.AuthorizationError
|
||||
if errors.As(err, &authError) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "you are not currently signed in"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"})
|
||||
@@ -1802,18 +1843,21 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
client := api.NewClient(remoteURL, http.DefaultClient)
|
||||
err = client.Chat(c, &req, fn)
|
||||
if err != nil {
|
||||
var sErr api.AuthorizationError
|
||||
if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
|
||||
pk, pkErr := auth.GetPublicKey()
|
||||
if pkErr != nil {
|
||||
slog.Error("couldn't get public key", "error", pkErr)
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "error getting public key"})
|
||||
var authError api.AuthorizationError
|
||||
if errors.As(err, &authError) {
|
||||
sURL, sErr := signinURL()
|
||||
if sErr != nil {
|
||||
slog.Error(sErr.Error())
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "unauthorized",
|
||||
"public_key": pk,
|
||||
})
|
||||
|
||||
c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "signin_url": sURL})
|
||||
return
|
||||
}
|
||||
var apiError api.StatusError
|
||||
if errors.As(err, &apiError) {
|
||||
c.JSON(apiError.StatusCode, apiError)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
|
Reference in New Issue
Block a user