diff --git a/rust/src/pgsql/pgsql.rs b/rust/src/pgsql/pgsql.rs index 1873be5a7532..4099143df562 100644 --- a/rust/src/pgsql/pgsql.rs +++ b/rust/src/pgsql/pgsql.rs @@ -39,6 +39,8 @@ static mut PGSQL_MAX_TX: usize = 1024; pub enum PgsqlTransactionState { Init = 0, RequestReceived, + RequestDone, + ResponseReceived, ResponseDone, FlushedOut, } @@ -97,28 +99,30 @@ impl PgsqlTransaction { #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum PgsqlStateProgress { IdleState, +// Related to Frontend-received messages // SSLRequestReceived, - SSLRejectedReceived, StartupMessageReceived, - SASLAuthenticationReceived, SASLInitialResponseReceived, - // SSPIAuthenticationReceived, // TODO implement SASLResponseReceived, + PasswordMessageReceived, + SimpleQueryReceived, + CancelRequestReceived, + ConnectionTerminated, +// Related to Backend-received messages // + SSLRejectedReceived, + // SSPIAuthenticationReceived, // TODO implement + SASLAuthenticationReceived, SASLAuthenticationContinueReceived, SASLAuthenticationFinalReceived, SimpleAuthenticationReceived, - PasswordMessageReceived, AuthenticationOkReceived, ParameterSetup, BackendKeyReceived, ReadyForQueryReceived, - SimpleQueryReceived, RowDescriptionReceived, DataRowReceived, CommandCompletedReceived, ErrorMessageReceived, - CancelRequestReceived, - ConnectionTerminated, #[cfg(test)] UnknownState, Finished, @@ -244,16 +248,29 @@ impl PgsqlState { /// /// As Pgsql transactions are bidirectional and may be comprised of several /// responses, we must track State progress to decide on tx completion - fn is_tx_completed(&self) -> bool { - if let PgsqlStateProgress::ReadyForQueryReceived - | PgsqlStateProgress::SSLRejectedReceived - | PgsqlStateProgress::SimpleAuthenticationReceived - | PgsqlStateProgress::SASLAuthenticationReceived - | PgsqlStateProgress::SASLAuthenticationContinueReceived - | PgsqlStateProgress::SASLAuthenticationFinalReceived - | PgsqlStateProgress::ConnectionTerminated - | PgsqlStateProgress::Finished = self.state_progress + fn is_tx_completed(&self, direction: Direction) -> bool { + if direction == Direction::ToClient { + if let PgsqlStateProgress::ReadyForQueryReceived + | PgsqlStateProgress::SSLRejectedReceived + | PgsqlStateProgress::SimpleAuthenticationReceived + | PgsqlStateProgress::SASLAuthenticationReceived + | PgsqlStateProgress::SASLAuthenticationContinueReceived + | PgsqlStateProgress::SASLAuthenticationFinalReceived + | PgsqlStateProgress::Finished = self.state_progress + { + true + } else { + false + } + } else if let PgsqlStateProgress::SSLRequestReceived + | PgsqlStateProgress::StartupMessageReceived + | PgsqlStateProgress::SimpleQueryReceived + | PgsqlStateProgress::PasswordMessageReceived + | PgsqlStateProgress::SASLInitialResponseReceived + | PgsqlStateProgress::SASLResponseReceived + | PgsqlStateProgress::CancelRequestReceived + | PgsqlStateProgress::ConnectionTerminated = self.state_progress { true } else { false @@ -341,16 +358,22 @@ impl PgsqlState { ); match PgsqlState::state_based_req_parsing(self.state_progress, start) { Ok((rem, request)) => { - sc_app_layer_parser_trigger_raw_stream_reassembly(flow, Direction::ToServer as i32); start = rem; + let mut temp_state = PgsqlStateProgress::IdleState; if let Some(state) = PgsqlState::request_next_state(&request) { self.state_progress = state; + temp_state = state; }; - let tx_completed = self.is_tx_completed(); + let tx_completed = self.is_tx_completed(Direction::ToServer); if let Some(tx) = self.find_or_create_tx() { tx.request = Some(request); if tx_completed { - tx.tx_state = PgsqlTransactionState::ResponseDone; + if temp_state == PgsqlStateProgress::ConnectionTerminated || temp_state == PgsqlStateProgress::CancelRequestReceived { + tx.tx_state = PgsqlTransactionState::ResponseDone; + } else { + tx.tx_state = PgsqlTransactionState::RequestDone; + } + sc_app_layer_parser_trigger_raw_stream_reassembly(flow, Direction::ToServer as i32); } } else { // If there isn't a new transaction, we'll consider Suri should move on @@ -471,15 +494,17 @@ impl PgsqlState { while !start.is_empty() { match PgsqlState::state_based_resp_parsing(self.state_progress, start) { Ok((rem, response)) => { - sc_app_layer_parser_trigger_raw_stream_reassembly(flow, Direction::ToClient as i32); start = rem; SCLogDebug!("Response is {:?}", &response); if let Some(state) = self.response_process_next_state(&response, flow) { self.state_progress = state; }; - let tx_completed = self.is_tx_completed(); + let tx_completed = self.is_tx_completed(Direction::ToClient); let curr_state = self.state_progress; if let Some(tx) = self.find_or_create_tx() { + if tx.tx_state == PgsqlTransactionState::Init { + tx.tx_state = PgsqlTransactionState::ResponseReceived; + } if curr_state == PgsqlStateProgress::DataRowReceived { tx.incr_row_cnt(); } else if curr_state == PgsqlStateProgress::CommandCompletedReceived @@ -498,6 +523,7 @@ impl PgsqlState { tx.responses.push(response); if tx_completed { tx.tx_state = PgsqlTransactionState::ResponseDone; + sc_app_layer_parser_trigger_raw_stream_reassembly(flow, Direction::ToClient as i32); } } } else { @@ -746,7 +772,7 @@ pub unsafe extern "C" fn rs_pgsql_register_parser() { parse_tc: rs_pgsql_parse_response, get_tx_count: rs_pgsql_state_get_tx_count, get_tx: rs_pgsql_state_get_tx, - tx_comp_st_ts: PgsqlTransactionState::RequestReceived as i32, + tx_comp_st_ts: PgsqlTransactionState::RequestDone as i32, tx_comp_st_tc: PgsqlTransactionState::ResponseDone as i32, tx_get_progress: rs_pgsql_tx_get_alstate_progress, get_eventinfo: None,