From 7096021f6a9cfaa51a08213e8149c9ddfec22686 Mon Sep 17 00:00:00 2001
From: Richard Wooding
Date: Mon, 19 Jun 2023 10:50:10 +0200
Subject: [PATCH] LlamaCpp and Llama Go backends

---
Review note: hunk sizes and the diffstat below were recomputed after review
edits. The blob ids on the "index" lines are those of the original commit and
no longer match the edited files.

 client/llamacpp/llamacppclient.go         |  79 +++++++++++
 client/llamagoremote/jobresponse.go       |  18 +++
 client/llamagoremote/jobstatusresponse.go |   6 +
 client/llamagoremote/llamagoclient.go     | 165 ++++++++++++++++++++++
 client/llamagoremote/promptrequest.go     |   9 ++
 client/llamagoremote/promptresponse.go    |  14 ++
 client/llamagoremote/protocol.go          |  26 ++++
 client/llamagoremote/status.go            |  28 ++++
 client/openai/openaiclient.go             |  21 +--
 client/vertexai/accesstoken.go            |   5 +-
 client/vertexai/vertexaiclient.go         |   9 +-
 cmd/root.go                               |  12 ++
 cmd/run.go                                |   3 +-
 cmd/serve.go                              |   3 +-
 transport/debuggerroundtrip.go            |  10 +-
 15 files changed, 389 insertions(+), 19 deletions(-)
 create mode 100644 client/llamacpp/llamacppclient.go
 create mode 100644 client/llamagoremote/jobresponse.go
 create mode 100644 client/llamagoremote/jobstatusresponse.go
 create mode 100644 client/llamagoremote/llamagoclient.go
 create mode 100644 client/llamagoremote/promptrequest.go
 create mode 100644 client/llamagoremote/promptresponse.go
 create mode 100644 client/llamagoremote/protocol.go
 create mode 100644 client/llamagoremote/status.go

diff --git a/client/llamacpp/llamacppclient.go b/client/llamacpp/llamacppclient.go
new file mode 100644
index 0000000..0887079
--- /dev/null
+++ b/client/llamacpp/llamacppclient.go
@@ -0,0 +1,79 @@
+// Package llamacpp implements a language-model client that shells out to a
+// local llama.cpp binary.
+package llamacpp
+
+import (
+	"fmt"
+	"github.com/spandigitial/codeassistant/client"
+	"github.com/spandigitial/codeassistant/model"
+	"os/exec"
+)
+
+// Client runs prompts through a llama.cpp executable.
+type Client struct {
+	binaryPath        string
+	modelPath         string
+	promptContextSize int
+	extraArguments    []string
+}
+
+// Option customises a Client during construction.
+type Option func(client *Client)
+
+// New returns a Client that runs binaryPath against the model at modelPath.
+// promptContextSize is forwarded to the binary as its "-n" value; options are
+// applied in order after the required fields are set.
+func New(binaryPath string, modelPath string, promptContextSize int, options ...Option) *Client {
+	c := &Client{
+		binaryPath:        binaryPath,
+		modelPath:         modelPath,
+		promptContextSize: promptContextSize,
+	}
+
+	for _, option := range options {
+		option(c)
+	}
+
+	return c
+}
+
+// WithExtraArguments sets additional command-line arguments that are appended
+// to every invocation of the binary.
+func WithExtraArguments(arguments ...string) Option {
+	return func(client *Client) {
+		client.extraArguments = arguments
+	}
+}
+
+// Models closes the channel without sending anything: the client is configured
+// with a single model path, so there is nothing to enumerate.
+func (c *Client) Models(models chan<- client.LanguageModel) error {
+	close(models)
+	return nil
+}
+
+// Completion runs the binary once with the joined prompts and emits its
+// standard output as a single "Part" between "Start" and "Done" markers. The
+// channel is closed after a successful run; on error nothing is sent and the
+// channel is left open.
+func (c *Client) Completion(ci *model.CommandInstance, messageParts chan<- client.MessagePart) error {
+	args := []string{
+		"-m", c.modelPath,
+		"-n", fmt.Sprintf("%d", c.promptContextSize),
+		// Without "-p" the binary never receives the prompt.
+		"-p", ci.JoinedPromptsContent("\n\n"),
+	}
+	args = append(args, c.extraArguments...)
+
+	out, err := exec.Command(c.binaryPath, args...).Output()
+	if err != nil {
+		return err
+	}
+
+	messageParts <- client.MessagePart{Delta: "", Type: "Start"}
+	messageParts <- client.MessagePart{Delta: string(out), Type: "Part"}
+	messageParts <- client.MessagePart{Delta: "", Type: "Done"}
+	close(messageParts)
+
+	return nil
+}
diff --git a/client/llamagoremote/jobresponse.go b/client/llamagoremote/jobresponse.go
new file mode 100644
index 0000000..bb76597
--- /dev/null
+++ b/client/llamagoremote/jobresponse.go
@@ -0,0 +1,18 @@
+package llamagoremote
+
+import (
+	"github.com/google/uuid"
+	"time"
+)
+
+// jobResponse is decoded from the body of GET /jobs/{id}. Every field is
+// exported because encoding/json silently skips unexported fields.
+type jobResponse struct {
+	ID      uuid.UUID `json:"id"`
+	Prompt  string    `json:"prompt"`
+	Output  string    `json:"output"`
+	Created time.Time `json:"created"`
+	Started time.Time `json:"started"`
+	Model   string    `json:"model"`
+	Status  status    `json:"status"`
+}
diff --git a/client/llamagoremote/jobstatusresponse.go b/client/llamagoremote/jobstatusresponse.go
new file mode 100644
index 0000000..ab38b83
--- /dev/null
+++ b/client/llamagoremote/jobstatusresponse.go
@@ -0,0 +1,6 @@
+package llamagoremote
+
+// jobStatusResponse is decoded from the body of GET /jobs/status/{id}.
+type jobStatusResponse struct {
+	Status status `json:"status"`
+}
diff --git a/client/llamagoremote/llamagoclient.go b/client/llamagoremote/llamagoclient.go
new file mode 100644
index 0000000..87826f8
--- /dev/null
+++ b/client/llamagoremote/llamagoclient.go
@@ -0,0 +1,165 @@
+package llamagoremote
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"github.com/google/uuid"
+	"github.com/spandigitial/codeassistant/client"
+	"github.com/spandigitial/codeassistant/client/debugger"
+	"github.com/spandigitial/codeassistant/model"
+	"github.com/spandigitial/codeassistant/transport"
+	"io"
+	"net/http"
+	"time"
+)
+
+// Client submits prompts to a remote llama.go server and polls its job API
+// until the result is ready.
+type Client struct {
+	protocol     Protocol
+	host         string
+	port         int
+	debugger     *debugger.Debugger
+	httpClient   *http.Client
+	userAgent    *string
+	pollDuration time.Duration
+}
+
+// Option customises a Client during construction.
+type Option func(client *Client)
+
+// New returns a Client for the server at protocol://host:port. pollDuration is
+// the pause between job status checks. Unless WithHttpClient supplies another
+// client, requests are sent with the settings of http.DefaultClient.
+func New(protocol Protocol, host string, port int, pollDuration time.Duration, debugger *debugger.Debugger, options ...Option) *Client {
+	c := &Client{
+		protocol:     protocol,
+		host:         host,
+		port:         port,
+		debugger:     debugger,
+		pollDuration: pollDuration,
+	}
+
+	for _, option := range options {
+		option(c)
+	}
+
+	if c.httpClient == nil {
+		c.httpClient = http.DefaultClient
+	}
+
+	// Wrap the transport on a copy: assigning to the caller's client (or to
+	// the process-wide http.DefaultClient) would leak the debugger transport
+	// into unrelated code and stack another wrapper on every call to New.
+	httpClient := *c.httpClient
+	httpClient.Transport = transport.New(httpClient.Transport, c.debugger)
+	c.httpClient = &httpClient
+
+	return c
+}
+
+// WithHttpClient makes the Client send its requests with httpClient's settings.
+func WithHttpClient(httpClient *http.Client) Option {
+	return func(client *Client) {
+		client.httpClient = httpClient
+	}
+}
+
+// Models closes the channel without sending anything; this client does not
+// query the server for a model list.
+func (c *Client) Models(models chan<- client.LanguageModel) error {
+	close(models)
+	return nil
+}
+
+// Completion posts the joined prompts as a new job, polls the job status every
+// pollDuration while the server reports it as processing, then fetches the
+// finished job and emits its output as a single "Part" between "Start" and
+// "Done" markers. The channel is closed after a successful run; on error
+// nothing is sent and the channel is left open.
+func (c *Client) Completion(ci *model.CommandInstance, messageParts chan<- client.MessagePart) error {
+	jobsURL := fmt.Sprintf("%s://%s:%d/jobs", c.protocol, c.host, c.port)
+	jobID := uuid.New()
+	requestTime := time.Now()
+
+	c.debugger.Message("request-time", fmt.Sprintf("%v", requestTime))
+
+	request := promptRequest{
+		ID:     jobID,
+		Prompt: ci.JoinedPromptsContent("\n\n"),
+	}
+
+	c.debugger.Message("sent-prompt", request.Prompt)
+
+	requestBytes, err := json.Marshal(request)
+	if err != nil {
+		return err
+	}
+
+	req, err := http.NewRequest(http.MethodPost, jobsURL, bytes.NewReader(requestBytes))
+	if err != nil {
+		return err
+	}
+	req.Header.Set("Content-Type", "application/json")
+
+	var submitted promptResponse
+	if err := c.doJSON(req, &submitted); err != nil {
+		return err
+	}
+
+	if submitted.Status == processing {
+		statusURL := fmt.Sprintf("%s/status/%s", jobsURL, jobID.String())
+		for {
+			var current jobStatusResponse
+			if err := c.getJSON(statusURL, &current); err != nil {
+				return err
+			}
+			if current.Status == finished {
+				break
+			}
+			time.Sleep(c.pollDuration)
+		}
+	}
+
+	var job jobResponse
+	if err := c.getJSON(fmt.Sprintf("%s/%s", jobsURL, jobID.String()), &job); err != nil {
+		return err
+	}
+
+	messageParts <- client.MessagePart{Delta: "", Type: "Start"}
+	messageParts <- client.MessagePart{Delta: job.Output, Type: "Part"}
+	messageParts <- client.MessagePart{Delta: "", Type: "Done"}
+	close(messageParts)
+
+	return nil
+}
+
+// getJSON issues a GET for url and decodes the JSON response body into v.
+func (c *Client) getJSON(url string, v interface{}) error {
+	req, err := http.NewRequest(http.MethodGet, url, nil)
+	if err != nil {
+		return err
+	}
+	return c.doJSON(req, v)
+}
+
+// doJSON sends req, always closes the response body, rejects non-2xx statuses
+// and decodes the JSON response body into v.
+func (c *Client) doJSON(req *http.Request, v interface{}) error {
+	resp, err := c.httpClient.Do(req)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+
+	responseBytes, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return err
+	}
+	if resp.StatusCode < 200 || resp.StatusCode > 299 {
+		return fmt.Errorf("%s %s: unexpected status %s", req.Method, req.URL, resp.Status)
+	}
+
+	return json.Unmarshal(responseBytes, v)
+}
diff --git a/client/llamagoremote/promptrequest.go b/client/llamagoremote/promptrequest.go
new file mode 100644
index 0000000..86894c3
--- /dev/null
+++ b/client/llamagoremote/promptrequest.go
@@ -0,0 +1,9 @@
+package llamagoremote
+
+import "github.com/google/uuid"
+
+// promptRequest is the JSON body of POST /jobs.
+type promptRequest struct {
+	ID     uuid.UUID `json:"id"`
+	Prompt string    `json:"prompt"`
+}
diff --git a/client/llamagoremote/promptresponse.go b/client/llamagoremote/promptresponse.go
new file mode 100644
index 0000000..39f815d
--- /dev/null
+++ b/client/llamagoremote/promptresponse.go
@@ -0,0 +1,14 @@
+package llamagoremote
+
+import (
+	"github.com/google/uuid"
+	"time"
+)
+
+// promptResponse is decoded from the server's reply to POST /jobs.
+type promptResponse struct {
+	ID      uuid.UUID `json:"id"`
+	Prompt  string    `json:"prompt"`
+	Created time.Time `json:"created"`
+	Status  status    `json:"status"`
+}
diff --git a/client/llamagoremote/protocol.go b/client/llamagoremote/protocol.go
new file mode 100644
index 0000000..7ff87c8
--- /dev/null
+++ b/client/llamagoremote/protocol.go
@@ -0,0 +1,26 @@
+package llamagoremote
+
+import "fmt"
+
+// Protocol is the URL scheme used to reach the llama.go server.
+type Protocol string
+
+// Protocols accepted by ParseProtocol.
+const (
+	HttpProtocol  Protocol = "http"
+	HttpsProtocol Protocol = "https"
+)
+
+var protocolMap = map[string]Protocol{
+	"http":  HttpProtocol,
+	"https": HttpsProtocol,
+}
+
+// ParseProtocol maps protocolStr ("http" or "https") to its Protocol. Any
+// other value yields an empty Protocol and an error.
+func ParseProtocol(protocolStr string) (Protocol, error) {
+	if protocol, found := protocolMap[protocolStr]; found {
+		return protocol, nil
+	}
+	return "", fmt.Errorf("protocol: '%s' not found", protocolStr)
+}
diff --git a/client/llamagoremote/status.go b/client/llamagoremote/status.go
new file mode 100644
index 0000000..9096791
--- /dev/null
+++ b/client/llamagoremote/status.go
@@ -0,0 +1,28 @@
+package llamagoremote
+
+import (
+	"fmt"
+)
+
+// status is the job state reported by the llama.go server.
+type status string
+
+// Job states this client acts on.
+const (
+	processing status = "processing"
+	finished   status = "finished"
+)
+
+var statusMap = map[string]status{
+	"processing": processing,
+	"finished":   finished,
+}
+
+// parseStatus maps statusStr to its status. Any value other than "processing"
+// or "finished" yields an empty status and an error.
+func parseStatus(statusStr string) (status, error) {
+	if parsed, found := statusMap[statusStr]; found {
+		return parsed, nil
+	}
+	return "", fmt.Errorf("status: '%s' not found", statusStr)
+}
diff --git a/client/openai/openaiclient.go b/client/openai/openaiclient.go
index d1678bc..782b9e7 100644
--- a/client/openai/openaiclient.go
+++ b/client/openai/openaiclient.go
@@ -9,6 +9,7 @@ import (
 	"github.com/spandigitial/codeassistant/client"
 	"github.com/spandigitial/codeassistant/client/debugger"
 	"github.com/spandigitial/codeassistant/model"
+	"github.com/spandigitial/codeassistant/transport"
 	"github.com/spf13/viper"
 	"golang.org/x/time/rate"
 	"io"
@@ -17,7 +18,7 @@ import (
 	"time"
 )
 
-type OpenAiClient struct {
+type Client struct {
 	apiKey      string
 	debugger    *debugger.Debugger
 	rateLimiter *rate.Limiter
@@ -26,10 +27,10 @@ type OpenAiClient struct {
 	userAgent   *string
 }
 
-type Option func(client *OpenAiClient)
+type Option func(client *Client)
 
-func New(apiKey string, debugger *debugger.Debugger, options ...Option) *OpenAiClient {
-	c := &OpenAiClient{
+func New(apiKey string, debugger *debugger.Debugger, options ...Option) *Client {
+	c := &Client{
 		apiKey:   apiKey,
 		debugger: debugger,
 	}
@@ -42,30 +43,32 @@ func New(apiKey string, debugger *debugger.Debugger, options ...Option) *OpenAiC
 		c.httpClient = http.DefaultClient
 	}
 
+	c.httpClient.Transport = transport.New(c.httpClient.Transport, c.debugger)
+
 	return c
 }
 
 func WithHttpClient(httpClient *http.Client) Option {
-	return func(client *OpenAiClient) {
+	return func(client *Client) {
 		client.httpClient = httpClient
 	}
 }
 
 func WithUser(user string) Option {
-	return func(client *OpenAiClient) {
+	return func(client *Client) {
 		client.user = &user
 	}
 }
 
 func WithUserAgent(userAgent string) Option {
-	return func(client *OpenAiClient) {
+	return func(client *Client) {
 		client.userAgent = &userAgent
 	}
 }
 
 var dataRegex = regexp.MustCompile("data: (\\{.+\\})\\w?")
 
-func (c *OpenAiClient) Models(models chan<- client.LanguageModel) error {
+func (c *Client) Models(models chan<- client.LanguageModel) error {
 	url := "https://api.openai.com/v1/models"
 
 	requestTime := time.Now()
@@ -112,7 +115,7 @@ func (c *OpenAiClient) Models(models chan<- client.LanguageModel) error {
 	return nil
 }
 
-func (c *OpenAiClient) Completion(commandInstance *model.CommandInstance, messageParts chan<- client.MessagePart) error {
+func (c *Client) Completion(commandInstance *model.CommandInstance, messageParts chan<- client.MessagePart) error {
 	url := "https://api.openai.com/v1/chat/completions"
 
 	for _, prompt := range commandInstance.Prompts {
diff --git a/client/vertexai/accesstoken.go b/client/vertexai/accesstoken.go
index b2ac908..c6ee714 100644
--- a/client/vertexai/accesstoken.go
+++ b/client/vertexai/accesstoken.go
@@ -1,13 +1,12 @@
 package vertexai
 
 import (
-	"github.com/spf13/viper"
 	"os/exec"
 	"strings"
 )
 
-func generateAccessToken() (string, error) {
-	out, err := exec.Command(viper.GetString("gcloudBinary"), "auth", "print-access-token").Output()
+func generateAccessToken(gcloudBinaryPath string) (string, error) {
+	out, err := exec.Command(gcloudBinaryPath, "auth", "print-access-token").Output()
 	if err != nil {
 		return "", err
 	}
diff --git a/client/vertexai/vertexaiclient.go b/client/vertexai/vertexaiclient.go
index bb7f09c..6a95970 100644
--- a/client/vertexai/vertexaiclient.go
+++ b/client/vertexai/vertexaiclient.go
@@ -7,6 +7,7 @@ import (
 	"github.com/spandigitial/codeassistant/client"
 	"github.com/spandigitial/codeassistant/client/debugger"
 	"github.com/spandigitial/codeassistant/model"
+	"github.com/spandigitial/codeassistant/transport"
 	"io"
 	"net/http"
 	"time"
@@ -23,8 +24,8 @@ type Client struct {
 
 type Option func(client *Client)
 
-func New(projectId string, location string, model string, debugger *debugger.Debugger, options ...Option) *Client {
-	accessToken, _ := generateAccessToken()
+func New(gcloudBinaryPath string, projectId string, location string, model string, debugger *debugger.Debugger, options ...Option) *Client {
+	accessToken, _ := generateAccessToken(gcloudBinaryPath)
 	c := &Client{
 		accessToken: accessToken,
 		projectId:   projectId,
@@ -41,6 +42,8 @@ func New(projectId string, location string, model string, debugger *debugger.Deb
 		c.httpClient = http.DefaultClient
 	}
 
+	c.httpClient.Transport = transport.New(c.httpClient.Transport, c.debugger)
+
 	return c
 }
 
@@ -76,7 +79,7 @@ func (c *Client) Completion(commandInstance *model.CommandInstance, messageParts
 
 	request := predictRequest{
 		Instances: []instance{{
-			Content: commandInstance.JoinedPromptsContent("\n"),
+			Content: commandInstance.JoinedPromptsContent("\n\n"),
 		}},
 		Parameters: parameters,
 	}
diff --git a/cmd/root.go b/cmd/root.go
index 42ed1f1..a9dceac 100644
--- a/cmd/root.go
+++ b/cmd/root.go
@@ -57,6 +57,18 @@ func init() {
 	// will be global for your application.
 
 	rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.codeassistant.yaml)")
+	rootCmd.PersistentFlags().String("llamaCppBinaryPath", "", "Path to LlamaCpp binary")
+	if err := viper.BindPFlag("llamaCppBinaryPath", rootCmd.PersistentFlags().Lookup("llamaCppBinaryPath")); err != nil {
+		log.Fatal("Unable to find flag llamaCppBinaryPath", err)
+	}
+	rootCmd.PersistentFlags().String("llamaCppModelPath", "", "Path to LlamaCpp model")
+	if err := viper.BindPFlag("llamaCppModelPath", rootCmd.PersistentFlags().Lookup("llamaCppModelPath")); err != nil {
+		log.Fatal("Unable to find flag llamaCppModelPath", err)
+	}
+	rootCmd.PersistentFlags().Int("llamaCppContextSize", 256, "LlamaCpp context size")
+	if err := viper.BindPFlag("llamaCppContextSize", rootCmd.PersistentFlags().Lookup("llamaCppContextSize")); err != nil {
+		log.Fatal("Unable to find flag llamaCppContextSize", err)
+	}
 	rootCmd.PersistentFlags().String("openAiApiKey", "", "OpenAI API Key")
 	if err := viper.BindPFlag("openAiApiKey", rootCmd.PersistentFlags().Lookup("openAiApiKey")); err != nil {
 		log.Fatal("Unable to find flag openAiApiKey", err)
diff --git a/cmd/run.go b/cmd/run.go
index 40dbf26..e788e3c 100644
--- a/cmd/run.go
+++ b/cmd/run.go
@@ -37,10 +37,11 @@ var runPromptsCmd = &cobra.Command{
 			}
 			llmClient = openai.New(openAiApiKey, debugger, openai.WithUser(user), openai.WithUserAgent(userAgent))
 		case "vertexai":
+			gcloudBinaryPath := viper.GetString("gcloudBinary")
 			vertexAiProjectId := viper.GetString("vertexAiProjectId")
 			vertexAiLocation := viper.GetString("vertexAiLocation")
 			vertexAiModel := viper.GetString("vertexAiModel")
-			llmClient = vertexai.New(vertexAiProjectId, vertexAiLocation, vertexAiModel, debugger)
+			llmClient = vertexai.New(gcloudBinaryPath, vertexAiProjectId, vertexAiLocation, vertexAiModel, debugger)
 		}
 		f := bufio.NewWriter(os.Stdout)
 		defer f.Flush()
diff --git a/cmd/serve.go b/cmd/serve.go
index ddcb1ba..15c2f72 100644
--- a/cmd/serve.go
+++ b/cmd/serve.go
@@ -104,10 +104,11 @@ to quickly create a Cobra application.`,
 			}
 			llmClient = openai.New(openAiApiKey, debugger, openai.WithUser(user), openai.WithUserAgent(userAgent))
 		case "vertexai":
+			gcloudBinaryPath := viper.GetString("gcloudBinary")
 			vertexAiProjectId := viper.GetString("vertexAiProjectId")
 			vertexAiLocation := viper.GetString("vertexAiLocation")
 			vertexAiModel := viper.GetString("vertexAiModel")
-			llmClient = vertexai.New(vertexAiProjectId, vertexAiLocation, vertexAiModel, debugger)
+			llmClient = vertexai.New(gcloudBinaryPath, vertexAiProjectId, vertexAiLocation, vertexAiModel, debugger)
 		}
 		uuid := uuid.New()
 		messageParts := make(chan client.MessagePart)
diff --git a/transport/debuggerroundtrip.go b/transport/debuggerroundtrip.go
index 868cc48..5141145 100644
--- a/transport/debuggerroundtrip.go
+++ b/transport/debuggerroundtrip.go
@@ -11,9 +11,15 @@ type DebuggerRoundtrip struct {
 	debugger  *debugger.Debugger
 }
 
-func New(debugger *debugger.Debugger) *DebuggerRoundtrip {
+// New returns a DebuggerRoundtrip that holds transport and debugger. A nil
+// transport is replaced by http.DefaultTransport, matching how http.Client
+// treats a nil Transport field.
+func New(transport http.RoundTripper, debugger *debugger.Debugger) *DebuggerRoundtrip {
+	if transport == nil {
+		transport = http.DefaultTransport
+	}
 	return &DebuggerRoundtrip{
-		transport: http.DefaultTransport,
+		transport: transport,
 		debugger:  debugger,
 	}
 }