diff --git a/database/index.js b/database/index.js
index 9b029af..dc55811 100644
--- a/database/index.js
+++ b/database/index.js
@@ -1,42 +1,38 @@
-import * as lancedb from "@lancedb/lancedb";
-import { get, post } from "../tools/request.js"
-import { Schema, Field, FixedSizeList, Int16, Float16, Utf8 } from "apache-arrow";
+import { connect } from "@lancedb/lancedb";
+import {
+ Schema, Field, FixedSizeList,
+ Float32, Utf8,
+ // eslint-disable-next-line
+ Table
+} from "apache-arrow";
+import { DATASET_TABLE, SYSTEM_TABLE } from "./types";
const uri = "/tmp/lancedb/";
-const db = await lancedb.connect(uri);
+const db = await connect(uri);
-const table = await db.createEmptyTable("rag_data", new Schema([
- new Field("id", new Int16()),
- new Field("vector", new FixedSizeList(384, new Field("item", new Float16(), true)), false),
- new Field("question", new Utf8()),
- new Field("answer", new Utf8())
-]), {
- // mode: "overwrite",
- existOk: true
-})
-
-export async function loadDataset(dataset_link) {
- const {rows, http_error} = await get('', {}, { URL: dataset_link })
- if(http_error) {
- return false;
- }
- await table.add(rows.map(({ row_id, row })=>{
- const { question, answer, question_embedding } = row;
- return { id: row_id, question, answer, vector: question_embedding }
- }))
- return true;
+export async function initDB(force = false) {
+ const open_options = force ? { mode: "overwrite" } : { existOk: true }
+ // create or re-open system table to store long-lasting data
+ await db.createEmptyTable(SYSTEM_TABLE, new Schema([
+ new Field("title", new Utf8()),
+ new Field("value", new Utf8())
+ ]), open_options)
+ // create or re-open dataset table
+ await db.createEmptyTable(DATASET_TABLE, new Schema([
+ new Field("vector", new FixedSizeList(384, new Field("item", new Float32(), true)), false),
+ new Field("dataset_name", new Utf8()),
+ new Field("question", new Utf8()),
+ new Field("answer", new Utf8())
+ ]), open_options)
}
-export async function searchByEmbedding(vector) {
- const record = await table.search(vector).limit(1).toArray();
- if(!record.length) return null;
- const { question, answer } = record[0];
- return { question, answer };
-}
+initDB();
-export async function searchByMessage(msg) {
- const { embedding } = await post('embedding', {body: {
- content: msg
- }}, { eng: "embedding" });
- return await searchByEmbedding(embedding);
+/**
+ * Open a table with table name
+ * @param {String} table_name table name to be opened
+ * @returns {Promise
} Promise containes the table object.
+ */
+export async function getTable(table_name) {
+ return await db.openTable(table_name)
}
\ No newline at end of file
diff --git a/database/rag-inference.js b/database/rag-inference.js
new file mode 100644
index 0000000..3e0d350
--- /dev/null
+++ b/database/rag-inference.js
@@ -0,0 +1,78 @@
+import { get, post } from "../tools/request.js";
+import { getTable } from "./index.js";
+import { DATASET_TABLE, SYSTEM_TABLE } from "./types.js";
+
+async function loadDatasetFromURL(dataset_name, dataset_url, system_table) {
+ system_table = system_table || await getTable(SYSTEM_TABLE);
+ const { rows, http_error } = await get('', {}, {URL: dataset_url});
+ if(http_error) return false;
+
+ await system_table.add([{ title: "loaded_dataset_name", value: dataset_name }]);
+
+ await (await getTable(DATASET_TABLE)).add(rows.map(({row})=>{
+ const { question, answer, question_embedding } = row;
+ return { question, answer, vector: question_embedding, dataset_name }
+ }))
+ return true;
+}
+
+/**
+ * Load a dataset from given url.
+ * * This will first check whether the dataset is loaded in database, if `force` not provided and it's loaded already, it won't load again.
+ * * The dataset format should be an array of object contains at least `question`, `answer` and `question_embedding` properties
+ * @param {String} dataset_name The dataset name to load
+ * @param {String} dataset_url The url of dataset to load
+ * @param {Boolean} force Specify whether to force load the dataset, default `false`.
+ * @returns {Promise} If cannot get the dataset, return `false`, otherwise return `true`
+ */
+export async function loadDataset(dataset_name, dataset_url, force = false) {
+ const system_table = await getTable(SYSTEM_TABLE)
+ if(!force) {
+ const loaded_dataset = await system_table.query()
+ .where(`title="loaded_dataset_name" AND value="${dataset_name}"`).toArray();
+ // check if the given dataset loaded, if not, load the dataset
+ return !!(loaded_dataset.length || await loadDatasetFromURL(dataset_name, dataset_url, system_table))
+ } else {
+ return await loadDatasetFromURL(dataset_name, dataset_url, system_table)
+ }
+}
+
+/**
+ * @typedef EmbeddingSearchResult
+ * @property {String} question The question from dataset
+ * @property {String} answer The answer from dataset
+ */
+
+/**
+ * Search in given dataset using provided embedding value to get Q/A pair
+ * @param {String} dataset_name The dataset name to be query from
+ * @param {Array} vector The embedding result to be searched
+ * @returns {Promise} If there's no result, returns null, otherwise returns the result
+ */
+export async function searchByEmbedding(dataset_name, vector) {
+ const embedding_result = (await (
+ await getTable(DATASET_TABLE)
+ ).search(vector).where(`dataset_name = "${dataset_name}"`)
+ .limit(1).toArray()).pop();
+
+ if(embedding_result) {
+ const { question, answer, _distance } = embedding_result;
+ return { question, answer, _distance }
+ }
+ return null;
+}
+
+/**
+ * Search in given dataset using provided message to get Q/A pair.
+ * This will firstly embedding the message and query use {@link searchByEmbedding}
+ * @param {String} dataset_name The dataset name to be query from
+ * @param {String} message The message to be searched
+ * @returns {Promise} If there's no result, returns null, otherwise returns the result
+ */
+export async function searchByMessage(dataset_name, message) {
+ const { embedding, http_error } = await post('embedding', {body: {
+ content: message
+ }}, { eng: "embedding" });
+
+ return http_error ? null : await searchByEmbedding(dataset_name, embedding);
+}
\ No newline at end of file
diff --git a/database/types.js b/database/types.js
new file mode 100644
index 0000000..54a49be
--- /dev/null
+++ b/database/types.js
@@ -0,0 +1,2 @@
+export const SYSTEM_TABLE = 'system';
+export const DATASET_TABLE = 'dataset';
\ No newline at end of file