diff --git a/src/auto/mod.rs b/src/auto/mod.rs
index 0a3025f..8b7b2c9 100644
--- a/src/auto/mod.rs
+++ b/src/auto/mod.rs
@@ -118,6 +118,7 @@ pub struct ParsedResponse<T> {
     raw: String
 }
 
+#[allow(dead_code)]
 pub fn try_parse_yaml<T : DeserializeOwned>(llm: &LLM, tries: usize, max_tokens: Option<u16>) -> Result<ParsedResponse<T>, Box<dyn Error>> {
     try_parse_base(llm, tries, max_tokens, "yml", |str| serde_yaml::from_str(str).map_err(|el| Box::new(el) as Box<dyn Error>))
 }
diff --git a/src/config/mod.rs b/src/config/mod.rs
index c176e45..fe1bd6e 100644
--- a/src/config/mod.rs
+++ b/src/config/mod.rs
@@ -4,7 +4,7 @@
 use colored::Colorize;
 use serde::{Serialize, Deserialize};
 use serde_json::Value;
-use crate::{CommandContext, LLM, Plugin, create_browse, create_google, create_filesystem, create_shutdown, create_wolfram, create_chatgpt, create_news, create_wikipedia, create_none, LLMProvider, create_model_chatgpt, Agents, LLMModel, create_model_llama, AgentInfo, MemoryProvider, create_memory_local, create_memory_qdrant, MemorySystem, create_memory_redis};
+use crate::{CommandContext, LLM, Plugin, create_browse, create_google, create_filesystem, create_shutdown, create_wolfram, create_chatgpt, create_news, create_wikipedia, create_none, LLMProvider, create_model_chatgpt, Agents, LLMModel, create_model_llama, AgentInfo, MemoryProvider, create_memory_local, create_memory_qdrant, MemorySystem, create_memory_redis, create_model_palm2};
 
 mod default;
 pub use default::*;
@@ -102,7 +102,8 @@ pub fn list_plugins() -> Vec<Plugin> {
 
 pub fn create_llm_providers() -> Vec<Box<dyn LLMProvider>> {
     vec![
         create_model_chatgpt(),
-        create_model_llama()
+        create_model_llama(),
+        create_model_palm2()
     ]
 }
diff --git a/src/llms/mod.rs b/src/llms/mod.rs
index cf6c049..5bab144 100644
--- a/src/llms/mod.rs
+++ b/src/llms/mod.rs
@@ -1,8 +1,10 @@
 mod chatgpt;
+mod palm2;
 mod local;
 
 pub use chatgpt::*;
 pub use local::*;
+pub use palm2::*;
 
 use tokio::runtime::Runtime;
 use std::{error::Error, fmt::Display};
diff --git a/src/llms/palm2/api.rs b/src/llms/palm2/api.rs
new file mode 100644
index 0000000..b03cee2
--- /dev/null
+++ b/src/llms/palm2/api.rs
@@ -0,0 +1,70 @@
+use reqwest::{Client};
+use std::error::Error;
+
+use crate::{CountTokensRequest, TokenCountResponse, EmbedTextRequest, EmbeddingResponse, Embedding, MessagePrompt, GenerateMessageResponse, GenerateTextRequest, GenerateTextResponse, GCPModel, ListModelResponse};
+
+pub struct ApiClient {
+    client: Client,
+    base_url: String,
+    api_key: String
+}
+
+impl ApiClient {
+    pub fn new(base_url: String, api_key: String) -> Self {
+        Self {
+            client: Client::new(),
+            base_url,
+            api_key
+        }
+    }
+
+    pub async fn count_message_tokens(&self, model: &str, message: CountTokensRequest) -> Result<TokenCountResponse, Box<dyn Error>> {
+        let url = format!("{}/v1beta2/models/{}:countMessageTokens?key={}", self.base_url, model, self.api_key);
+        let response = self.client.post(&url).json(&message).send().await?;
+
+        let token_count: TokenCountResponse = response.json().await?;
+        Ok(token_count)
+    }
+
+    pub async fn embed_text(&self, model: &str, message: EmbedTextRequest) -> Result<Vec<f32>, Box<dyn Error>> {
+        let url = format!("{}/v1beta2/models/{}:embedText?key={}", self.base_url, model, self.api_key);
+        let response = self.client.post(&url).json(&message).send().await?;
+
+        let embedding: EmbeddingResponse = response.json().await?;
+        Ok(embedding.embedding.unwrap_or(Embedding {
+            value: vec![]
+        }).value)
+    }
+
+    pub async fn generate_message(&self, model: &str, prompt: MessagePrompt) -> Result<GenerateMessageResponse, Box<dyn Error>> {
+        let url = format!("{}/v1beta2/models/{}:generateMessage?key={}", self.base_url, model, self.api_key);
+        let response = self.client.post(&url).json(&prompt).send().await?;
+
+        let message: GenerateMessageResponse = response.json().await?;
+        Ok(message)
+    }
+
+    pub async fn generate_text(&self, model: &str, message: GenerateTextRequest) -> Result<GenerateTextResponse, Box<dyn Error>> {
+        let url = format!("{}/v1beta2/models/{}:generateText?key={}", self.base_url, model, self.api_key);
+        let response = self.client.post(&url).json(&message).send().await?;
+
+        let text_response: GenerateTextResponse = response.json().await?;
+        Ok(text_response)
+    }
+
+    pub async fn get_model(&self, name: &str) -> Result<GCPModel, Box<dyn Error>> {
+        let url = format!("{}/v1beta2/models/{}?key={}", self.base_url, name, self.api_key);
+        let response = self.client.get(&url).send().await?;
+
+        let model: GCPModel = response.json().await?;
+        Ok(model)
+    }
+
+    pub async fn list_models(&self) -> Result<ListModelResponse, Box<dyn Error>> {
+        let url = format!("{}/v1beta2/models?key={}", self.base_url, self.api_key);
+        let response = self.client.get(&url).send().await?;
+
+        let models: ListModelResponse = response.json().await?;
+        Ok(models)
+    }
+}
\ No newline at end of file
diff --git a/src/llms/palm2/mod.rs b/src/llms/palm2/mod.rs
new file mode 100644
index 0000000..4ee459a
--- /dev/null
+++ b/src/llms/palm2/mod.rs
@@ -0,0 +1,7 @@
+mod system;
+mod api;
+mod types;
+
+pub use system::*;
+pub use api::*;
+pub use types::*;
\ No newline at end of file
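For context, a hedged sketch (not part of this commit) of how the new ApiClient could be smoke-tested on its own, using only the request and response types introduced above. The palm2_smoke_test helper, the prompt strings, and the hard-coded model name are illustrative; the crate-relative imports assume the pub use re-exports added in src/llms/mod.rs.

    // Hypothetical smoke test for the new client; not part of this commit.
    use std::error::Error;

    use crate::{ApiClient, CountTokensRequest, GenerateTextRequest, MessagePrompt, PALMMessage, TextPrompt};

    pub async fn palm2_smoke_test(api_key: String) -> Result<(), Box<dyn Error>> {
        let client = ApiClient::new("https://generativelanguage.googleapis.com".to_owned(), api_key);

        // Count tokens for a one-message prompt, mirroring what PALM2Provider::create does on startup.
        let count = client.count_message_tokens("text-bison-001", CountTokensRequest {
            prompt: MessagePrompt {
                context: "".to_string(),
                examples: vec![],
                messages: vec![PALMMessage {
                    author: None,
                    content: "Hello, PaLM".to_string(),
                    citation_metadata: None,
                }],
            },
        }).await?;
        println!("tokens: {:?}", count.token_count);

        // Generate text with the same sampling defaults the PALM2 model uses below.
        let text = client.generate_text("text-bison-001", GenerateTextRequest {
            prompt: TextPrompt { text: "Write a haiku about Rust.".to_string() },
            safety_settings: vec![],
            stop_sequences: vec![],
            temperature: 1.0,
            candidate_count: 1,
            max_output_tokens: 256,
            top_p: 0.95,
            top_k: 40,
        }).await?;
        println!("{:?}", text.candidates);

        Ok(())
    }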
diff --git a/src/llms/palm2/system.rs b/src/llms/palm2/system.rs
new file mode 100644
index 0000000..a962236
--- /dev/null
+++ b/src/llms/palm2/system.rs
@@ -0,0 +1,139 @@
+use std::{error::Error, thread::sleep, time::Duration};
+
+use async_trait::async_trait;
+use serde::{Serialize, Deserialize};
+use serde_json::Value;
+use tokio::runtime::Runtime;
+
+use crate::{LLMProvider, Message, LLMModel, ApiClient, MessagePrompt, PALMMessage, GenerateTextRequest, TextPrompt, EmbedTextRequest, CountTokensRequest, GenerateTextResponse};
+
+pub struct PALM2 {
+    pub model: String,
+    pub embedding_model: String,
+    pub client: ApiClient
+}
+
+#[async_trait]
+impl LLMModel for PALM2 {
+    async fn get_response(&self, messages: &[Message], max_tokens: Option<u16>, temperature: Option<f32>) -> Result<String, Box<dyn Error>> {
+        let palm_messages_string = messages
+            .iter()
+            .map(|el| el.content())
+            .collect::<Vec<_>>()
+            .join("\n");
+
+        let text_request = GenerateTextRequest {
+            prompt: TextPrompt {
+                text: palm_messages_string
+            },
+            safety_settings: vec![],
+            stop_sequences: vec![],
+            temperature: temperature.unwrap_or(1.0) as f64,
+            candidate_count: 1,
+            max_output_tokens: max_tokens.unwrap_or(1000) as i32,
+            top_p: 0.95,
+            top_k: 40,
+        };
+
+        let response_message: GenerateTextResponse = self.client.generate_text(&self.model, text_request).await?;
+
+        let response = response_message.candidates.unwrap_or(vec![]);
+
+        let response_text = response
+            .iter()
+            .map(|el| el.output.clone())
+            .collect::<Vec<String>>()
+            .join(" ");
+
+        Ok(response_text)
+    }
+
+    async fn get_base_embed(&self, text: &str) -> Result<Vec<f32>, Box<dyn Error>> {
+        let embedding_response = self.client.embed_text(&self.embedding_model,
+            EmbedTextRequest {
+                text: text.to_string()
+            }).await?;
+
+        Ok(embedding_response)
+    }
+
+    fn get_tokens_remaining(&self, messages: &[Message]) -> Result<usize, Box<dyn Error>> {
+        let all_messages: Vec<PALMMessage> = messages.iter().map(|el| PALMMessage {
+                author: None,
+                content: el.content().to_string(),
+                citation_metadata: None
+            }
+        ).collect::<Vec<_>>();
+
+        let count_tokens_request = CountTokensRequest {
+            prompt: MessagePrompt { context: "".to_string(), examples: vec![], messages: all_messages }
+        };
+
+        let runtime = tokio::runtime::Runtime::new()?;
+
+        let gcp_model = runtime.block_on(self.client.get_model(&self.model))?;
+        let token_count = runtime.block_on(self.client.count_message_tokens(&self.model, count_tokens_request))?;
+        let max_tokens = gcp_model.input_token_limit;
+
+        let tokens_remaining = max_tokens.checked_sub(token_count.token_count.unwrap_or(0) as i32)
+            .ok_or_else(|| "Token count exceeded the maximum limit.")?;
+
+        Ok(tokens_remaining as usize)
+    }
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct PALM2Config {
+    #[serde(rename = "api key")] pub api_key: String,
+    pub model: Option<String>,
+    #[serde(rename = "api base")] pub api_base: Option<String>,
+    #[serde(rename = "embedding model")] pub embedding_model: Option<String>,
+}
+
+pub struct PALM2Provider;
+
+#[async_trait]
+impl LLMProvider for PALM2Provider {
+    fn is_enabled(&self) -> bool {
+        true
+    }
+
+    fn get_name(&self) -> &str {
+        "palm2"
+    }
+
+    fn create(&self, value: Value) -> Result<Box<dyn LLMModel>, Box<dyn Error>> {
+        let rt = Runtime::new().expect("Failed to create Tokio runtime");
+
+        let config: PALM2Config = serde_json::from_value(value)?;
+
+        let client = ApiClient::new(config.api_base.unwrap_or("https://generativelanguage.googleapis.com".to_owned()), config.api_key);
+
+        let all_messages: Vec<PALMMessage> = vec![PALMMessage {
+            author: None,
+            content: "count my tokens please palm".to_string(),
+            citation_metadata: None
+        }];
+
+        // Easy way to immediately test the API call on startup
+        let models_response = rt.block_on(async {
+            client.count_message_tokens("text-bison-001", CountTokensRequest {
+                prompt: MessagePrompt { context: "".to_string(), examples: vec![], messages: all_messages }
+            }).await
+        })?;
+
+        println!("model: {:?}", models_response);
+        sleep(Duration::new(10, 0));
+
+        Ok(Box::new(PALM2 {
+            model: config.model.unwrap_or("text-bison-001".to_string()),
+            embedding_model: config.embedding_model.unwrap_or("embedding-gecko-001".to_string()),
+            client
+        }))
+    }
+}
+
+pub fn create_model_palm2() -> Box<dyn LLMProvider> {
+    Box::new(PALM2Provider)
+}
diff --git a/src/llms/palm2/types.rs b/src/llms/palm2/types.rs
new file mode 100644
index 0000000..395bcd0
--- /dev/null
+++ b/src/llms/palm2/types.rs
@@ -0,0 +1,213 @@
+use serde::{Serialize, Deserialize};
+
+use crate::Message;
+
+
+
+#[derive(Serialize, Deserialize, Debug)]
+#[serde(rename_all = "camelCase")]
+pub struct CitationSource {
+    pub start_index: i32,
+    pub end_index: i32,
+    pub uri: String,
+    pub license: String,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+#[serde(rename_all = "camelCase")]
+pub struct CitationMetadata {
+    pub citation_sources: Vec<CitationSource>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+#[serde(rename_all = "camelCase")]
+pub struct PALMMessage {
+    pub author: Option<String>,
+    pub content: String,
+    pub citation_metadata: Option<CitationMetadata>,
+}
+
+impl From<Message> for PALMMessage {
+    fn from(message: Message) -> Self {
+        let content = match message {
+            Message::User(content) | Message::Assistant(content) | Message::System(content) => content,
+        };
+
+        PALMMessage {
+            author: None,
+            content,
+            citation_metadata: None,
+        }
+    }
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Example {
+    pub input: PALMMessage,
+    pub output: PALMMessage,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct MessagePrompt {
+    pub context: String,
+    pub examples: Vec<Example>,
+    pub messages: Vec<PALMMessage>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+#[serde(rename_all = "camelCase")]
+pub struct GenerateMessageResponse {
+    pub prompt: MessagePrompt,
+    pub temperature: f64,
+    pub candidate_count: i32,
+    pub top_p: f64,
+    pub top_k: i32,
+}
+#[derive(Serialize, Deserialize, Debug)]
+pub struct CountTokensRequest {
+    pub prompt: MessagePrompt,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+#[serde(rename_all = "camelCase")]
+pub struct TextCompletion {
+    pub output: String,
+    pub safety_ratings: Vec<SafetyRating>,
+    pub citation_metadata: Option<CitationMetadata>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+pub enum BlockedReason {
+    BlockedReasonUnspecified,
+    Safety,
+    Other,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct ContentFilter {
+    pub reason: BlockedReason,
+    pub message: Option<String>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+pub enum HarmCategory {
+    HarmCategoryUnspecified,
+    HarmCategoryDerogatory,
+    HarmCategoryToxicity,
+    HarmCategoryViolence,
+    HarmCategorySexual,
+    HarmCategoryMedical,
+    HarmCategoryDangerous,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+pub enum HarmBlockThreshold {
+    HarmBlockThresholdUnspecified,
+    BlockLowAndAbove,
+    BlockMediumAndAbove,
+    BlockOnlyHigh,
+    BlockNone,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+pub enum HarmProbability {
+    HarmProbabilityUnspecified,
+    Negligible,
+    Low,
+    Medium,
+    High,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct SafetySetting {
+    pub category: HarmCategory,
+    pub threshold: HarmBlockThreshold,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct SafetyRating {
+    pub category: HarmCategory,
+    pub probability: HarmProbability,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct SafetyFeedback {
+    pub rating: SafetyRating,
+    pub setting: SafetySetting,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+#[serde(rename_all = "camelCase")]
+pub struct GenerateTextResponse {
+    pub candidates: Option<Vec<TextCompletion>>,
+    pub filters: Option<Vec<ContentFilter>>,
+    pub safety_feedback: Option<Vec<SafetyFeedback>>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct EmbedTextRequest {
+    pub text: String
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+#[serde(rename_all = "camelCase")]
+pub struct GCPModel {
+    pub name: String,
+    pub base_model_id: Option<String>,
+    pub version: String,
+    pub display_name: String,
+    pub description: String,
+    pub input_token_limit: i32,
+    pub output_token_limit: i32,
+    pub supported_generation_methods: Vec<String>,
+    pub temperature: Option<f64>,
+    pub top_p: Option<f64>,
+    pub top_k: Option<i32>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+#[serde(rename_all = "camelCase")]
+pub struct TokenCountResponse {
+    pub token_count: Option<i32>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct TextPrompt {
+    // Add the fields of TextPrompt object here
+    pub text: String
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+#[serde(rename_all = "camelCase")]
+pub struct GenerateTextRequest {
+    pub prompt: TextPrompt,
+    pub safety_settings: Vec<SafetySetting>,
+    pub stop_sequences: Vec<String>,
+    pub temperature: f64,
+    pub candidate_count: i32,
+    pub max_output_tokens: i32,
+    pub top_p: f64,
+    pub top_k: i32,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct Embedding {
+    pub value: Vec<f32>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct EmbeddingResponse {
+    pub embedding: Option<Embedding>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+#[serde(rename_all = "camelCase")]
+pub struct ListModelResponse {
+    pub models: Vec<GCPModel>,
+    pub next_page_token: Option<String>,
+}
\ No newline at end of file
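Before the final main.rs hunk, a hedged sketch of how the new provider might be wired up by hand. The build_palm2_model helper and the literal values are illustrative, not part of the commit; the JSON keys mirror the serde renames in PALM2Config.

    // Hypothetical wiring example; not part of this commit.
    // It shows the JSON shape PALM2Config expects (note the renamed keys).
    use std::error::Error;

    use serde_json::json;

    use crate::{LLMModel, LLMProvider, PALM2Provider};

    pub fn build_palm2_model(api_key: &str) -> Result<Box<dyn LLMModel>, Box<dyn Error>> {
        let provider = PALM2Provider;

        // Only "api key" is required; the other keys fall back to the defaults in create().
        let config = json!({
            "api key": api_key,
            "model": "text-bison-001",
            "api base": "https://generativelanguage.googleapis.com",
            "embedding model": "embedding-gecko-001"
        });

        provider.create(config)
    }

Note that create() as written spins up its own Tokio runtime, blocks on a countMessageTokens request, and then sleeps for ten seconds, so it should only be invoked from synchronous startup code.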
diff --git a/src/main.rs b/src/main.rs
index c582280..c68e867 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -31,6 +31,7 @@ pub struct NewEndGoal { #[serde(rename = "new end goal")] new_end_goal: String }
 
+#[allow(dead_code)]
 fn debug_yaml(results: &str) -> Result<(), Box<dyn Error>> {
     let json: Value = serde_json::from_str(&results)?;
     let mut yaml: String = serde_yaml::to_string(&json)?;
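Finally, a hedged end-to-end sketch (again not part of the commit) of driving the created model through the LLMModel trait. The ask_palm2 helper, the message contents, and the max_tokens/temperature values are illustrative; the Message variants come from the From<Message> impl in types.rs.

    // Hypothetical usage of the boxed model returned by PALM2Provider::create.
    use std::error::Error;

    use crate::{LLMModel, Message};

    pub async fn ask_palm2(model: &dyn LLMModel, question: &str) -> Result<String, Box<dyn Error>> {
        let messages = vec![
            Message::System("You are a concise assistant.".to_string()),
            Message::User(question.to_string()),
        ];

        // max_tokens and temperature are illustrative; passing None falls back
        // to the defaults inside PALM2::get_response.
        model.get_response(&messages, Some(256), Some(0.7)).await
    }

Since get_response is async it needs an executor, while get_tokens_remaining builds its own runtime internally and blocks, so the two should not be called from the same async context.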