Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion refact-agent/engine/src/background_tasks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ pub async fn start_background_tasks(gcx: Arc<ARwLock<GlobalContext>>) -> Backgro
tokio::spawn(crate::vecdb::vdb_highlev::vecdb_background_reload(gcx.clone())), // this in turn can create global_context::vec_db
tokio::spawn(crate::integrations::sessions::remove_expired_sessions_background_task(gcx.clone())),
tokio::spawn(crate::git::cleanup::git_shadow_cleanup_background_task(gcx.clone())),
tokio::spawn(crate::cloud::threads_sub::watch_threads_subscription(gcx.clone())),
]);
let ast = gcx.clone().read().await.ast_service.clone();
if let Some(ast_service) = ast {
Expand Down
2 changes: 1 addition & 1 deletion refact-agent/engine/src/cloud/subchat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ pub async fn subchat(
};

let thread_id = thread.ft_id.clone();
let connection_result = initialize_connection(&cmd_address_url, &api_key, &located_fgroup_id).await;
let connection_result = initialize_connection(&cmd_address_url, &api_key, &located_fgroup_id, &app_searchable_id).await;
let mut connection = match connection_result {
Ok(conn) => conn,
Err(err) => return Err(format!("Failed to initialize WebSocket connection: {}", err)),
Expand Down
12 changes: 0 additions & 12 deletions refact-agent/engine/src/cloud/threads_processing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -337,25 +337,13 @@ pub async fn process_thread_event(
basic_info: BasicStuff,
cmd_address_url: String,
api_key: String,
app_searchable_id: String,
located_fgroup_id: String,
) -> Result<(), String> {
if thread_payload.ft_need_tool_calls == -1
|| thread_payload.owner_fuser_id != basic_info.fuser_id
|| !thread_payload.ft_locked_by.is_empty() {
return Ok(());
}
if let Some(ft_app_searchable) = thread_payload.ft_app_searchable.clone() {
if ft_app_searchable != app_searchable_id {
info!("thread `{}` has different `app_searchable` id, skipping it: {} != {}",
thread_payload.ft_id, app_searchable_id, ft_app_searchable
);
return Ok(());
}
} else {
info!("thread `{}` doesn't have the `app_searchable` id, skipping it", thread_payload.ft_id);
return Ok(());
}
if let Some(error) = thread_payload.ft_error.as_ref() {
info!("thread `{}` has the error: `{}`. Skipping it", thread_payload.ft_id, error);
return Ok(());
Expand Down
35 changes: 19 additions & 16 deletions refact-agent/engine/src/cloud/threads_sub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,9 @@ pub struct BasicStuff {
pub workspaces: Vec<Value>,
}

// XXX use xxx_subs::filter for ft_app_capture
const THREADS_SUBSCRIPTION_QUERY: &str = r#"
subscription ThreadsPageSubs($located_fgroup_id: String!) {
threads_in_group(located_fgroup_id: $located_fgroup_id) {
subscription ThreadsPageSubs($located_fgroup_id: String!, $filter: [String!]) {
threads_in_group(located_fgroup_id: $located_fgroup_id, filter: $filter) {
news_action
news_payload_id
news_payload {
Expand Down Expand Up @@ -78,19 +77,23 @@ pub async fn watch_threads_subscription(gcx: Arc<ARwLock<GlobalContext>>) {
let restart_flag = gcx.read().await.threads_subscription_restart_flag.clone();
restart_flag.store(false, Ordering::SeqCst);
}
let located_fgroup_id = if let Some(located_fgroup_id) = gcx.read().await.active_group_id.clone() {
located_fgroup_id
} else {
warn!("no active group is set, skipping threads subscription");
tokio::time::sleep(Duration::from_secs(RECONNECT_DELAY_SECONDS)).await;
continue;
let (located_fgroup_id, app_searchable_id) = {
let gcx_locked = gcx.read().await;
let located_fgroup_id = if let Some(located_fgroup_id) = gcx_locked.active_group_id.clone() {
located_fgroup_id
} else {
warn!("no active group is set, skipping threads subscription");
tokio::time::sleep(Duration::from_secs(RECONNECT_DELAY_SECONDS)).await;
continue;
};
(located_fgroup_id, gcx_locked.app_searchable_id.clone())
};

info!(
"starting subscription for threads_in_group with fgroup_id=\"{}\"",
located_fgroup_id
"starting subscription for threads_in_group with fgroup_id=\"{}\" and app_searchable_id=\"{}\"",
located_fgroup_id, app_searchable_id
);
let connection_result = initialize_connection(&address_url, &api_key, &located_fgroup_id).await;
let connection_result = initialize_connection(&address_url, &api_key, &located_fgroup_id, &app_searchable_id).await;
let mut connection = match connection_result {
Ok(conn) => conn,
Err(err) => {
Expand Down Expand Up @@ -129,6 +132,7 @@ pub async fn initialize_connection(
cmd_address_url: &str,
api_key: &str,
located_fgroup_id: &str,
app_searchable_id: &str,
) -> Result<
futures::stream::SplitStream<
tokio_tungstenite::WebSocketStream<
Expand Down Expand Up @@ -208,7 +212,8 @@ pub async fn initialize_connection(
"payload": {
"query": THREADS_SUBSCRIPTION_QUERY,
"variables": {
"located_fgroup_id": located_fgroup_id
"located_fgroup_id": located_fgroup_id,
"filter": [format!("ft_app_searchable:eq:{}", app_searchable_id)]
}
}
});
Expand All @@ -231,7 +236,6 @@ async fn actual_subscription_loop(
located_fgroup_id: &str,
) -> Result<(), String> {
info!("cloud threads subscription started, waiting for events...");
let app_searchable_id = gcx.read().await.app_searchable_id.clone();
let basic_info = get_basic_info(cmd_address_url, api_key).await?;
while let Some(msg) = connection.next().await {
if gcx.clone().read().await.shutdown_flag.load(Ordering::SeqCst) {
Expand Down Expand Up @@ -266,11 +270,10 @@ async fn actual_subscription_loop(
let basic_info_clone = basic_info.clone();
let cmd_address_url_clone = cmd_address_url.to_string();
let api_key_clone = api_key.to_string();
let app_searchable_id_clone = app_searchable_id.clone();
let located_fgroup_id_clone = located_fgroup_id.to_string();
tokio::spawn(async move {
crate::cloud::threads_processing::process_thread_event(
gcx_clone, payload_clone, basic_info_clone, cmd_address_url_clone, api_key_clone, app_searchable_id_clone, located_fgroup_id_clone
gcx_clone, payload_clone, basic_info_clone, cmd_address_url_clone, api_key_clone, located_fgroup_id_clone
).await
});
} else {
Expand Down
78 changes: 57 additions & 21 deletions refact-agent/engine/src/constants.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,56 @@
use std::net::IpAddr;

use tracing::info;
use url::Url;

const BASE_REFACT_URL: &str = "app.refact.ai";

/// Extracts the host (and optional port) from a URL string, e.g.:
/// ws://app.refact.ai/v1/graphql -> app.refact.ai
/// https://example.com:8080/path -> example.com:8080
/// app.refact.ai -> app.refact.ai
fn extract_base_host(address: &str) -> String {
/// Extracts the host (and optional port) from a URL string and determines if the protocol is secure, e.g.:
/// ws://app.refact.ai/v1/graphql -> (app.refact.ai, Some(false))
/// https://example.com:8080/path -> (example.com:8080, Some(true))
/// app.refact.ai -> (app.refact.ai, None)
fn get_host_and_is_protocol_secure(address: &str) -> (String, Option<bool>) {
if let Ok(url) = Url::parse(address) {
if let Some(host) = url.host_str() {
return if let Some(port) = url.port() {
let host_with_port = if let Some(port) = url.port() {
format!("{}:{}", host, port)
} else {
host.to_string()
}
};

let is_secure = match url.scheme() {
"https" | "wss" => Some(true),
"http" | "ws" => Some(false),
_ => None,
};

return (host_with_port, is_secure);
}
}

let mut address = address;
for prefix in ["ws://", "wss://", "http://", "https://"] {
let mut is_secure = None;

for (prefix, secure) in [
("https://", Some(true)),
("wss://", Some(true)),
("http://", Some(false)),
("ws://", Some(false)),
] {
if let Some(stripped) = address.strip_prefix(prefix) {
address = stripped;
is_secure = secure;
break;
}
}
let address = address;
if let Some(idx) = address.find('/') {

let host = if let Some(idx) = address.find('/') {
address[..idx].to_string()
} else {
address.to_string()
}
};

(host, is_secure)
}

fn is_localhost(address: &str) -> bool {
Expand All @@ -50,8 +71,11 @@ fn is_localhost(address: &str) -> bool {
address.to_string()
}
};
if let Ok(ip) = address.parse::<IpAddr>() {
return ip.is_loopback();
}
match address.to_ascii_lowercase().as_str() {
"localhost" | "127.0.0.1" | "::1" | "[::1]" => true,
"localhost" | "127.0.0.1" | "::1" | "[::1]" | "host.docker.internal" => true,
_ => false,
}
}
Expand All @@ -60,9 +84,13 @@ pub fn get_cloud_url(cmd_address_url: &str) -> String {
let final_address = if cmd_address_url.to_lowercase() == "refact" {
format!("https://{}/v1", BASE_REFACT_URL)
} else {
let base_part = extract_base_host(cmd_address_url);
let protocol = if is_localhost(&base_part) { "http" } else { "https" };
format!("{}://{}/v1", protocol, base_part)
let (host, is_secure) = get_host_and_is_protocol_secure(cmd_address_url);
let protocol = match is_secure {
Some(true) => "https",
Some(false) => "http",
None => if is_localhost(&host) { "http" } else { "https" },
};
format!("{}://{}/v1", protocol, host)
};
info!("resolved cloud url: {}", final_address);
final_address
Expand All @@ -72,9 +100,13 @@ pub fn get_graphql_ws_url(cmd_address_url: &str) -> String {
let final_address = if cmd_address_url.to_lowercase() == "refact" {
format!("ws://{}/v1/graphql", BASE_REFACT_URL)
} else {
let base_part = extract_base_host(cmd_address_url);
let protocol = if is_localhost(&base_part) { "ws" } else { "wss" };
format!("{}://{}/v1/graphql", protocol, base_part)
let (host, is_secure) = get_host_and_is_protocol_secure(cmd_address_url);
let protocol = match is_secure {
Some(true) => "wss",
Some(false) => "ws",
None => if is_localhost(&host) { "ws" } else { "wss" },
};
format!("{}://{}/v1/graphql", protocol, host)
};
info!("resolved graphql ws url: {}", final_address);
final_address
Expand All @@ -84,9 +116,13 @@ pub fn get_graphql_url(cmd_address_url: &str) -> String {
let final_address = if cmd_address_url.to_lowercase() == "refact" {
format!("https://{}/v1/graphql", BASE_REFACT_URL)
} else {
let base_part = extract_base_host(cmd_address_url);
let protocol = if is_localhost(&base_part) { "http" } else { "https" };
format!("{}://{}/v1/graphql", protocol, base_part)
let (host, is_secure) = get_host_and_is_protocol_secure(cmd_address_url);
let protocol = match is_secure {
Some(true) => "https",
Some(false) => "http",
None => if is_localhost(&host) { "http" } else { "https" },
};
format!("{}://{}/v1/graphql", protocol, host)
};
info!("resolved graphql url: {}", final_address);
final_address
Expand Down
15 changes: 11 additions & 4 deletions refact-agent/engine/src/files_in_workspace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -696,11 +696,18 @@ pub async fn on_workspaces_init(gcx: Arc<ARwLock<GlobalContext>>) -> i32
{
// Called from lsp and lsp_like
// Not called from main.rs as part of initialization
let folders = gcx.read().await.documents_state.workspace_folders.lock().unwrap().clone();
let old_app_searchable_id = gcx.read().await.app_searchable_id.clone();
let new_app_searchable_id = get_app_searchable_id(&folders);
let (folders, old_app_searchable_id, cmdline_app_searchable_id) = {
let gcx_locked = gcx.read().await;
let folders = gcx_locked.documents_state.workspace_folders.lock().unwrap().clone();
(
folders,
gcx_locked.app_searchable_id.clone(),
gcx_locked.cmdline.app_searchable_id.clone(),
)
};
let new_app_searchable_id = get_app_searchable_id(&folders, &cmdline_app_searchable_id);
if old_app_searchable_id != new_app_searchable_id {
gcx.write().await.app_searchable_id = get_app_searchable_id(&folders);
gcx.write().await.app_searchable_id = new_app_searchable_id;
crate::cloud::threads_sub::trigger_threads_subscription_restart(gcx.clone()).await;
}
watcher_init(gcx.clone()).await;
Expand Down
Loading