diff --git a/Cargo.toml b/Cargo.toml index b4408fe..6afa5f7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,28 +1,32 @@ [package] name = "pg_summarize" -version = "0.0.0" +version = "0.0.1" edition = "2021" [lib] -crate-type = ["cdylib"] +crate-type = ["cdylib", "rlib"] + +[[bin]] +name = "pgrx_embed_pg_summarize" +path = "src/bin/pgrx_embed.rs" [features] -default = ["pg13"] -pg11 = ["pgrx/pg11", "pgrx-tests/pg11" ] -pg12 = ["pgrx/pg12", "pgrx-tests/pg12" ] +default = ["pg18"] pg13 = ["pgrx/pg13", "pgrx-tests/pg13" ] pg14 = ["pgrx/pg14", "pgrx-tests/pg14" ] pg15 = ["pgrx/pg15", "pgrx-tests/pg15" ] pg16 = ["pgrx/pg16", "pgrx-tests/pg16" ] +pg17 = ["pgrx/pg17", "pgrx-tests/pg17" ] +pg18 = ["pgrx/pg18", "pgrx-tests/pg18" ] pg_test = [] [dependencies] -pgrx = "=0.11.4" +pgrx = "=0.16.1" reqwest = { version = "0.12.4", features = ["json", "blocking"] } serde_json = "1.0.117" [dev-dependencies] -pgrx-tests = "=0.11.4" +pgrx-tests = "=0.16.1" [profile.dev] panic = "unwind" @@ -32,3 +36,4 @@ panic = "unwind" opt-level = 3 lto = "fat" codegen-units = 1 + diff --git a/src/bin/pgrx_embed.rs b/src/bin/pgrx_embed.rs new file mode 100644 index 0000000..2216cee --- /dev/null +++ b/src/bin/pgrx_embed.rs @@ -0,0 +1 @@ +pgrx::pgrx_embed!(); \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index fece950..225b698 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,7 +3,10 @@ use reqwest::blocking::Client; use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE}; use serde_json::json; -pgrx::pg_module_magic!(); +// Re-export pgrx for the binary +pub use pgrx; + +pg_module_magic!(); #[pg_extern] fn hello_pg_summarize() -> &'static str { @@ -12,29 +15,48 @@ fn hello_pg_summarize() -> &'static str { #[pg_extern] fn summarize(input: &str) -> String { - let api_key = Spi::get_one::<&str>("SELECT current_setting('pg_summarizer.api_key', true)") - .expect("failed to get 'pg_summarizer.api_key' setting") - .expect("got null for 'pg_summarizer.api_key' setting"); - - let model = match Spi::get_one::<&str>("SELECT current_setting('pg_summarizer.model', true)") { - Ok(Some(model_name)) => model_name, - _ => "gpt-3.5-turbo", - }; - - let prompt = match Spi::get_one::<&str>("SELECT current_setting('pg_summarizer.prompt', true)") - { - Ok(Some(prompt_str)) => prompt_str, - _ => { - "You are an AI summarizing tool. \ - Your purpose is to summarize the tag, \ - not to engage in conversation or discussion. \ - Please read the carefully. \ - Then, summarize the key points. \ - Focus on capturing the most important information as concisely as possible." - } - }; + let api_key = Spi::connect(|client| { + client + .select("SELECT current_setting('pg_summarizer.api_key', true)", None, &[])? + .first() + .get::(1) + .ok() + .flatten() + .ok_or(pgrx::spi::Error::InvalidPosition) + }) + .expect("failed to get 'pg_summarizer.api_key' setting"); + + let model = Spi::connect(|client| -> Result { + Ok(client + .select("SELECT current_setting('pg_summarizer.model', true)", None, &[])? + .first() + .get::(1) + .ok() + .flatten() + .unwrap_or_else(|| "gpt-3.5-turbo".to_string())) + }) + .expect("failed to get 'pg_summarizer.model' setting"); - match make_api_call(input, &api_key, model, prompt) { + let prompt = Spi::connect(|client| -> Result { + Ok(client + .select("SELECT current_setting('pg_summarizer.prompt', true)", None, &[])? + .first() + .get::(1) + .ok() + .flatten() + .unwrap_or_else(|| { + "You are an AI summarizing tool. \ + Your purpose is to summarize the tag, \ + not to engage in conversation or discussion. \ + Please read the carefully. \ + Then, summarize the key points. \ + Focus on capturing the most important information as concisely as possible." + .to_string() + })) + }) + .expect("failed to get 'pg_summarizer.prompt' setting"); + + match make_api_call(input, &api_key, &model, &prompt) { Ok(summary) => summary, Err(e) => panic!("Error: {}", e), } @@ -110,3 +132,4 @@ pub mod pg_test { vec![] } } +