Skip to content

Commit

Permalink
api: add _defaultApiClient for resue and style fix
Browse files Browse the repository at this point in the history
  • Loading branch information
alwqx committed May 3, 2024
1 parent e9ae607 commit 75c5080
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 11 deletions.
13 changes: 11 additions & 2 deletions api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ func checkError(resp *http.Response, body []byte) error {
return apiError
}

// _defaultApiClient the default api client for reuse
var _defaultApiClient *Client

// ClientFromEnvironment creates a new [Client] using configuration from the
// environment variable OLLAMA_HOST, which points to the network host and
// port on which the ollama service is listenting. The format of this variable
Expand All @@ -58,18 +61,24 @@ func checkError(resp *http.Response, body []byte) error {
// If the variable is not specified, a default ollama host and port will be
// used.
func ClientFromEnvironment() (*Client, error) {
if _defaultApiClient != nil {
return _defaultApiClient, nil
}

ollamaHost, err := GetOllamaHost()
if err != nil {
return nil, err
}

return &Client{
_defaultApiClient = &Client{
base: &url.URL{
Scheme: ollamaHost.Scheme,
Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
},
http: http.DefaultClient,
}, nil
}

return _defaultApiClient, nil
}

type OllamaHost struct {
Expand Down
44 changes: 39 additions & 5 deletions api/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestClientFromEnvironment(t *testing.T) {
Expand All @@ -32,22 +33,55 @@ func TestClientFromEnvironment(t *testing.T) {
"trailing slash port": {value: "example.com:1234/", expect: "http://example.com:1234"},
}

// test ingnore _defaultApiClient
for k, v := range testCases {
t.Run(k, func(t *testing.T) {
t.Setenv("OLLAMA_HOST", v.value)
// set _defaultApiClient=nil for really new client in ClientFromEnvironment
_defaultApiClient = nil

t.Setenv("OLLAMA_HOST", v.value)
client, err := ClientFromEnvironment()
if err != v.err {
t.Fatalf("expected %s, got %s", v.err, err)
t.Fatalf("case %s: expected %s, got %s", k, v.err, err)
}

if client.base.String() != v.expect {
t.Fatalf("expected %s, got %s", v.expect, client.base.String())
t.Fatalf("case %s: expected %s, got %s", k, v.expect, client.base.String())
}
})
}

hostTestCases := map[string]*testCase{
// test with _defaultApiClient
for k, v := range testCases {
_defaultApiClient = nil
t.Setenv("OLLAMA_HOST", v.value)
client, err := ClientFromEnvironment()
if err != v.err {
t.Fatalf("case %s: expected %s, got %s", k, v.err, err)
}
if client.base.String() != v.expect {
t.Fatalf("case %s: expected %s, got %s", k, v.expect, client.base.String())
}
require.Equal(t, _defaultApiClient, client)

// call ClientFromEnvironment again, should return _defaultApiClient directly
client2, err := ClientFromEnvironment()
require.Nil(t, err)
require.NotNil(t, client2)
// address and fields of _defaultApiClient, client2 and client should equal
require.Equal(t, client, client2, fmt.Sprintf("case %s: expected %T, got %T", k, client, client2))
require.Equal(t, _defaultApiClient, client2, fmt.Sprintf("case %s: expected %T, got %T", k, _defaultApiClient, client2))
require.Equal(t, client.base.String(), client2.base.String(), fmt.Sprintf("case %s: expected %s, got %s", k, client.base.String(), client2.base.String()))
require.Equal(t, client.http, _defaultApiClient.http, fmt.Sprintf("case %s: expected %T, got %T", k, client.http, _defaultApiClient.http))
}
}

func TestGetOllamaHost(t *testing.T) {
hostTestCases := map[string]*struct {
value string
expect string
err error
}{
"empty": {value: "", expect: "127.0.0.1:11434"},
"only address": {value: "1.2.3.4", expect: "1.2.3.4:11434"},
"only port": {value: ":1234", expect: ":1234"},
Expand All @@ -73,7 +107,7 @@ func TestClientFromEnvironment(t *testing.T) {

oh, err := GetOllamaHost()
if err != v.err {
t.Fatalf("expected %s, got %s", v.err, err)
t.Fatalf("case %s: expected %s, got %s", k, v.err, err)
}

if err == nil {
Expand Down
7 changes: 3 additions & 4 deletions cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"time"

"github.com/containerd/console"

"github.com/olekukonko/tablewriter"
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
Expand Down Expand Up @@ -312,7 +311,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
interactive := true

opts := runOptions{
Model: args[0],
Model: name,
WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]interface{}{},
MultiModal: slices.Contains(show.Details.Families, "clip"),
Expand Down Expand Up @@ -982,7 +981,7 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
return nil
}

func versionHandler(cmd *cobra.Command, _ []string) {
func VersionHandler(cmd *cobra.Command, _ []string) {
client, err := api.ClientFromEnvironment()
if err != nil {
return
Expand Down Expand Up @@ -1028,7 +1027,7 @@ func NewCLI() *cobra.Command {
},
Run: func(cmd *cobra.Command, args []string) {
if version, _ := cmd.Flags().GetBool("version"); version {
versionHandler(cmd, args)
VersionHandler(cmd, args)
return
}

Expand Down

0 comments on commit 75c5080

Please sign in to comment.