Skip to content

Commit

Permalink
Add database version file for katana-db (#1338)
Browse files Browse the repository at this point in the history
* add db version file

* add more test
  • Loading branch information
kariy authored and Larkooo committed Jan 2, 2024
1 parent b48657e commit f899716
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 4 deletions.
94 changes: 90 additions & 4 deletions crates/katana/storage/db/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,32 +1,49 @@
//! Code adapted from Paradigm's [`reth`](https://github.com/paradigmxyz/reth/tree/main/crates/storage/db) DB implementation.
use std::fs;
use std::path::Path;

use anyhow::Context;
use anyhow::{anyhow, Context};

pub mod codecs;
pub mod error;
pub mod mdbx;
pub mod models;
pub mod tables;
pub mod utils;
pub mod version;

use mdbx::{DbEnv, DbEnvKind};
use utils::is_database_empty;
use version::{check_db_version, create_db_version_file, DatabaseVersionError, CURRENT_DB_VERSION};

/// Initialize the database at the given path and returning a handle to the its
/// environment.
///
/// This will create the default tables, if necessary.
pub fn init_db<P: AsRef<Path>>(path: P) -> anyhow::Result<DbEnv> {
if is_database_empty(path.as_ref()) {
// TODO: create dir if it doesn't exist and insert db version file
std::fs::create_dir_all(path.as_ref()).with_context(|| {
fs::create_dir_all(&path).with_context(|| {
format!("Creating database directory at path {}", path.as_ref().display())
})?;
create_db_version_file(&path, CURRENT_DB_VERSION).with_context(|| {
format!("Inserting database version file at path {}", path.as_ref().display())
})?
} else {
// TODO: check if db version file exists and if it's compatible
match check_db_version(&path) {
Ok(_) => {}
Err(DatabaseVersionError::FileNotFound) => {
create_db_version_file(&path, CURRENT_DB_VERSION).with_context(|| {
format!(
"No database version file found. Inserting version file at path {}",
path.as_ref().display()
)
})?
}
Err(err) => return Err(anyhow!(err)),
}
}

let env = open_db(path)?;
env.create_tables()?;
Ok(env)
Expand All @@ -38,3 +55,72 @@ pub fn open_db<P: AsRef<Path>>(path: P) -> anyhow::Result<DbEnv> {
format!("Opening database in read-write mode at path {}", path.as_ref().display())
})
}

#[cfg(test)]
mod tests {

use std::fs;

use crate::init_db;
use crate::version::{default_version_file_path, get_db_version, CURRENT_DB_VERSION};

#[test]
fn initialize_db_in_empty_dir() {
let path = tempfile::tempdir().unwrap();
init_db(path.path()).unwrap();

let version_file = fs::File::open(default_version_file_path(path.path())).unwrap();
let actual_version = get_db_version(path.path()).unwrap();

assert!(
version_file.metadata().unwrap().permissions().readonly(),
"version file should set to read-only"
);
assert_eq!(actual_version, CURRENT_DB_VERSION);
}

#[test]
fn initialize_db_in_existing_db_dir() {
let path = tempfile::tempdir().unwrap();

init_db(path.path()).unwrap();
let version = get_db_version(path.path()).unwrap();

init_db(path.path()).unwrap();
let same_version = get_db_version(path.path()).unwrap();

assert_eq!(version, same_version);
}

#[test]
fn initialize_db_with_malformed_version_file() {
let path = tempfile::tempdir().unwrap();
let version_file_path = default_version_file_path(path.path());
fs::write(version_file_path, b"malformed").unwrap();

let err = init_db(path.path()).unwrap_err();
assert!(err.to_string().contains("Malformed database version file"));
}

#[test]
fn initialize_db_with_mismatch_version() {
let path = tempfile::tempdir().unwrap();
let version_file_path = default_version_file_path(path.path());
fs::write(version_file_path, 99u32.to_be_bytes()).unwrap();

let err = init_db(path.path()).unwrap_err();
assert!(err.to_string().contains("Database version mismatch"));
}

#[test]
fn initialize_db_with_missing_version_file() {
let path = tempfile::tempdir().unwrap();
init_db(path.path()).unwrap();

fs::remove_file(default_version_file_path(path.path())).unwrap();

init_db(path.path()).unwrap();
let actual_version = get_db_version(path.path()).unwrap();
assert_eq!(actual_version, CURRENT_DB_VERSION);
}
}
76 changes: 76 additions & 0 deletions crates/katana/storage/db/src/version.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
use std::array::TryFromSliceError;
use std::fs::{self};
use std::io::{Read, Write};
use std::mem;
use std::path::{Path, PathBuf};

/// Current version of the database.
pub const CURRENT_DB_VERSION: u32 = 0;

/// Name of the version file.
const DB_VERSION_FILE_NAME: &str = "db.version";

#[derive(Debug, thiserror::Error)]
pub enum DatabaseVersionError {
#[error("Database version file not found.")]
FileNotFound,
#[error(transparent)]
Io(#[from] std::io::Error),
#[error("Malformed database version file: {0}")]
MalformedContent(#[from] TryFromSliceError),
#[error("Database version mismatch. Expected version {expected}, found version {found}.")]
MismatchVersion { expected: u32, found: u32 },
}

/// Insert a version file at the given `path` with the specified `version`. If the `path` is a
/// directory, the version file will be created inside it. Otherwise, the version file will be
/// created exactly at `path`.
///
/// Ideally the version file should be included in the database directory.
///
/// # Errors
///
/// Will fail if all the directories in `path` has not already been created.
pub(super) fn create_db_version_file(
path: impl AsRef<Path>,
version: u32,
) -> Result<(), DatabaseVersionError> {
let path = path.as_ref();
let path = if path.is_dir() { default_version_file_path(path) } else { path.to_path_buf() };

let mut file = fs::File::create(path)?;
let mut permissions = file.metadata()?.permissions();
permissions.set_readonly(true);

file.set_permissions(permissions)?;
file.write_all(&version.to_be_bytes()).map_err(DatabaseVersionError::Io)
}

/// Check the version of the database at the given `path`.
///
/// Returning `Ok` if the version matches with [`CURRENT_DB_VERSION`], otherwise `Err` is returned.
pub(super) fn check_db_version(path: impl AsRef<Path>) -> Result<(), DatabaseVersionError> {
let version = get_db_version(path)?;
if version != CURRENT_DB_VERSION {
Err(DatabaseVersionError::MismatchVersion { expected: CURRENT_DB_VERSION, found: version })
} else {
Ok(())
}
}

/// Get the version of the database at the given `path`.
pub(super) fn get_db_version(path: impl AsRef<Path>) -> Result<u32, DatabaseVersionError> {
let path = path.as_ref();
let path = if path.is_dir() { default_version_file_path(path) } else { path.to_path_buf() };

let mut file = fs::File::open(path).map_err(|_| DatabaseVersionError::FileNotFound)?;
let mut buf: Vec<u8> = Vec::new();
file.read_to_end(&mut buf)?;

let bytes = <[u8; mem::size_of::<u32>()]>::try_from(buf.as_slice())?;
Ok(u32::from_be_bytes(bytes))
}

pub(super) fn default_version_file_path(path: &Path) -> PathBuf {
path.join(DB_VERSION_FILE_NAME)
}

0 comments on commit f899716

Please sign in to comment.