Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions cot-macros/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ impl ModelBuilder {
let fields_as_get_values = &self.fields_as_get_values;

quote! {
#[#crate_ident::__private::async_trait]
#[automatically_derived]
#[#orm_ident::async_trait]
impl #orm_ident::Model for #name {
type Fields = #fields_struct_name;
type PrimaryKey = #pk_type;
Expand Down Expand Up @@ -225,11 +225,11 @@ impl ModelBuilder {
}

async fn get_by_primary_key<DB: #orm_ident::DatabaseBackend>(
db: &DB,
mut db: DB,
pk: Self::PrimaryKey,
) -> #orm_ident::Result<Option<Self>> {
#orm_ident::query!(Self, $#pk_field_name == pk)
.get(db)
.get(&mut db)
.await
}
}
Expand Down
2 changes: 1 addition & 1 deletion cot/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ subtle = { workspace = true, features = ["std"] }
swagger-ui-redist = { workspace = true, optional = true }
thiserror.workspace = true
time.workspace = true
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "signal", "fs", "io-util"] }
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "signal", "fs", "io-util", "sync"] }
toml = { workspace = true, features = ["parse", "serde"] }
tower = { workspace = true, features = ["util"] }
tower-livereload = { workspace = true, optional = true }
Expand Down
30 changes: 16 additions & 14 deletions cot/src/auth/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ impl DatabaseUser {
/// # }
/// ```
pub async fn create_user<DB: DatabaseBackend, T: Into<String>, U: Into<Password>>(
db: &DB,
mut db: DB,
username: T,
password: U,
) -> Result<Self> {
Expand All @@ -110,7 +110,9 @@ impl DatabaseUser {
})?;

let mut user = Self::new(Auto::auto(), username, &password.into());
user.insert(db).await.map_err(AuthError::backend_error)?;
user.insert(&mut db)
.await
.map_err(AuthError::backend_error)?;

Ok(user)
}
Expand Down Expand Up @@ -155,9 +157,9 @@ impl DatabaseUser {
/// # Ok(())
/// # }
/// ```
pub async fn get_by_id<DB: DatabaseBackend>(db: &DB, id: i64) -> Result<Option<Self>> {
pub async fn get_by_id<DB: DatabaseBackend>(mut db: DB, id: i64) -> Result<Option<Self>> {
let db_user = query!(DatabaseUser, $id == id)
.get(db)
.get(&mut db)
.await
.map_err(AuthError::backend_error)?;

Expand Down Expand Up @@ -201,14 +203,14 @@ impl DatabaseUser {
/// # }
/// ```
pub async fn get_by_username<DB: DatabaseBackend>(
db: &DB,
mut db: DB,
username: &str,
) -> Result<Option<Self>> {
let username = LimitedString::<MAX_USERNAME_LENGTH>::new(username).map_err(|_| {
AuthError::backend_error(CreateUserError::UsernameTooLong(username.len()))
})?;
let db_user = query!(DatabaseUser, $username == username)
.get(db)
.get(&mut db)
.await
.map_err(AuthError::backend_error)?;

Expand All @@ -221,7 +223,7 @@ impl DatabaseUser {
///
/// Returns an error if there was an error querying the database.
pub async fn authenticate<DB: DatabaseBackend>(
db: &DB,
mut db: DB,
credentials: &DatabaseUserCredentials,
) -> Result<Option<Self>> {
let username = credentials.username();
Expand All @@ -230,7 +232,7 @@ impl DatabaseUser {
AuthError::backend_error(CreateUserError::UsernameTooLong(username.len()))
})?;
let user = query!(DatabaseUser, $username == username_limited)
.get(db)
.get(&mut db)
.await
.map_err(AuthError::backend_error)?;

Expand All @@ -240,7 +242,7 @@ impl DatabaseUser {
PasswordVerificationResult::Ok => Ok(Some(user)),
PasswordVerificationResult::OkObsolete(new_hash) => {
user.password = new_hash;
user.save(db).await.map_err(AuthError::backend_error)?;
user.save(&mut db).await.map_err(AuthError::backend_error)?;
Ok(Some(user))
}
PasswordVerificationResult::Invalid => Ok(None),
Expand Down Expand Up @@ -624,7 +626,7 @@ mod tests {
let username = "testuser".to_string();
let password = Password::new("password123");

let user = DatabaseUser::create_user(&mock_db, username.clone(), &password)
let user = DatabaseUser::create_user(&mut mock_db, username.clone(), &password)
.await
.unwrap();
assert_eq!(user.username(), username);
Expand All @@ -644,7 +646,7 @@ mod tests {
.expect_get::<DatabaseUser>()
.returning(move |_| Ok(Some(user.clone())));

let result = DatabaseUser::get_by_id(&mock_db, 1).await.unwrap();
let result = DatabaseUser::get_by_id(&mut mock_db, 1).await.unwrap();
assert!(result.is_some());
assert_eq!(result.unwrap().username(), "testuser");
}
Expand All @@ -665,7 +667,7 @@ mod tests {

let credentials =
DatabaseUserCredentials::new("testuser".to_string(), Password::new("password123"));
let result = DatabaseUser::authenticate(&mock_db, &credentials)
let result = DatabaseUser::authenticate(&mut mock_db, &credentials)
.await
.unwrap();
assert!(result.is_some());
Expand All @@ -683,7 +685,7 @@ mod tests {

let credentials =
DatabaseUserCredentials::new("testuser".to_string(), Password::new("password123"));
let result = DatabaseUser::authenticate(&mock_db, &credentials)
let result = DatabaseUser::authenticate(&mut mock_db, &credentials)
.await
.unwrap();
assert!(result.is_none());
Expand All @@ -705,7 +707,7 @@ mod tests {

let credentials =
DatabaseUserCredentials::new("testuser".to_string(), Password::new("invalid"));
let result = DatabaseUser::authenticate(&mock_db, &credentials)
let result = DatabaseUser::authenticate(&mut mock_db, &credentials)
.await
.unwrap();
assert!(result.is_none());
Expand Down
Loading
Loading