def _check_comparison_database(name):
    '''Validates a user-supplied database file name for the comparison view.

    Returns a ``(path, errors)`` tuple. ``path`` is the ``Path`` of the
    database inside the db/ directory when the name is acceptable, otherwise
    ``None``. ``errors`` is a (possibly empty) list of human-readable
    messages describing why the name was rejected.
    '''
    if not name:
        return None, ['A database name is required']

    db_dir = Path('db')
    path = Path('db/' + name)
    try:
        # The name comes straight from the query string, so make sure that
        # something like '../../etc/passwd' cannot escape the db/ directory.
        root = db_dir.resolve()
        resolved = path.resolve()
        is_inside = resolved == root or root in resolved.parents
    except (OSError, ValueError):
        # e.g. an embedded NUL byte or an unresolvable path
        is_inside = False

    if not is_inside or not path.exists():
        return None, [f'{name} does not exist in the db/ directory']
    if not path.is_file():
        return None, [f'{name} is not a file']
    return path, []


@api_routes_blueprint.route('/api/logs_comparison', methods=['GET'])
def logs_comparison():
    '''Returns one page of chat logs from two databases, side by side.

    Query parameters:
        page: 1-based page number. Missing, non-numeric or non-positive
            values fall back to the first page.
        database_a: file name of Database A inside the db/ directory.
        database_b: file name of Database B inside the db/ directory.

    On success the JSON response is ``{"success": true, "logs": [...]}``.
    When a database name is invalid the response is
    ``{"success": false, "errors": {...}}`` where ``errors`` maps
    ``'database-a'`` / ``'database-b'`` to a list of messages. Validation
    errors are still returned with the default 200 status, as before.
    '''
    # `type=int` makes Werkzeug return the default instead of raising when
    # the value cannot be converted.
    page = max(request.args.get('page', 1, type=int), 1)

    errors = {}
    paths = {}
    for error_key, arg_name in (('database-a', 'database_a'),
                                ('database-b', 'database_b')):
        path, messages = _check_comparison_database(
            request.args.get(arg_name))
        if messages:
            errors[error_key] = messages
        else:
            paths[error_key] = path
    if errors:
        return jsonify(success=False, errors=errors)

    per_page = 10
    offset = (page - 1) * per_page
    return jsonify(success=True,
                   logs=db.get_comparison_chatlogs_and_metrics(
                       str(paths['database-a']), str(paths['database-b']),
                       per_page, offset))
def _select_data(query: str,
                 params: Optional[Dict[str, Any]] = None,
                 database_url: str = DATABASE_URL) -> List[sqlite3.Row]:
    '''Runs a SQL SELECT query on the SQLite database.

    Args:
        query: The SELECT statement, optionally using named placeholders
            (e.g. ``:limit``).
        params: Values for the named placeholders in ``query``. Defaults to
            no parameters.
        database_url: Path of the SQLite database file to query. Defaults to
            the main application database.

    Returns:
        All matching rows as ``sqlite3.Row`` objects, so columns can be
        accessed by name.
    '''
    from contextlib import closing

    if params is None:
        params = {}
    # Using the connection itself as a context manager only commits or rolls
    # back the transaction; it does not close the connection. Wrap it in
    # `closing` so every query releases its file handle.
    with closing(sqlite3.connect(database_url)) as conn:
        conn.row_factory = sqlite3.Row
        cursor = conn.cursor()
        return cursor.execute(query, params).fetchall()
def get_comparison_chatlogs_and_metrics(database_a_url: str,
                                        database_b_url: str, limit: int,
                                        offset: int) -> List[dict]:
    '''
    Returns one page of chat logs from Database A, each paired with the
    response, source and metrics of the log in Database B that has the same
    request text.

    Args:
        database_a_url: Path of the SQLite file for Database A. Pagination
            (``limit`` / ``offset``, newest first) is applied to this
            database.
        database_b_url: Path of the SQLite file for Database B.
        limit: Maximum number of Database A chat logs to return.
        offset: Number of Database A chat logs to skip.

    Returns:
        A list with one dictionary per Database A chat log. Every chat_log
        column of Database A appears with an "_a" suffix, followed by the
        Database B fields:
        {
            "request_a": "...",
            "response_a": "...",
            ...
            "response_b": "..." or None,
            "source_b": "..." or None,
            "metrics_a": {
                "<metric name>": {"metric_value": ..., "explanation": "..."},
                ...
            },
            "metrics_b": {
                "<metric name>": {"metric_value": ..., "explanation": "..."},
                ...
            }
        }
        "response_b" and "source_b" are None and "metrics_b" is empty when
        Database B has no log with the same request.
    '''
    query_a = '''
        SELECT chat_log.*, metric.metric_name, metric.metric_value, metric.explanation
        FROM (
            SELECT * FROM chat_log
            ORDER BY timestamp DESC
            LIMIT :limit OFFSET :offset
        ) AS chat_log
        LEFT JOIN metric ON chat_log.id = metric.log_id
    '''
    a_rows = _select_data(query_a,
                          params={
                              'limit': limit,
                              'offset': offset
                          },
                          database_url=database_a_url)
    # NOTE: this reads every log of Database B on each call; only the rows
    # whose request matches a log on the current page of Database A are used.
    query_b = '''
        SELECT chat_log.*, metric.metric_name, metric.metric_value, metric.explanation
        FROM (
            SELECT * FROM chat_log
            ORDER BY timestamp DESC
        ) AS chat_log
        LEFT JOIN metric ON chat_log.id = metric.log_id
    '''
    b_rows = _select_data(query_b, database_url=database_b_url)
    metric_columns = ('metric_name', 'metric_value', 'explanation')

    def add_metric(metrics: dict, row) -> None:
        # The LEFT JOIN yields a row of NULL metric columns for a chat log
        # that has no metrics; that row must not become a metric entry.
        if row['metric_name'] is None:
            return
        metrics[row['metric_name']] = {
            'metric_value': row['metric_value'],
            'explanation': row['explanation']
        }

    # Each row corresponds to a single metric. Group together all the metrics
    # that belong to the same chat log.
    logs_by_id = {}
    request_to_log_id = {}
    for row in a_rows:
        log_id = row['id']
        if log_id not in logs_by_id:
            # Append '_a' to the keys to distinguish them from the fields
            # that come from Database B.
            entry = {
                f'{k}_a': row[k]
                for k in row.keys() if k not in metric_columns
            }
            # Always present, so consumers see the same shape whether or not
            # Database B contains a matching request.
            entry['response_b'] = None
            entry['source_b'] = None
            entry['metrics_a'] = {}
            entry['metrics_b'] = {}
            logs_by_id[log_id] = entry
            # Remember which Database A log each request belongs to.
            request_to_log_id[row['request']] = log_id
        add_metric(logs_by_id[log_id]['metrics_a'], row)

    for row in b_rows:
        request_b = row['request']
        # Ignore rows whose request does not match any Database A log on
        # this page.
        if request_b not in request_to_log_id:
            continue
        entry = logs_by_id[request_to_log_id[request_b]]

        # These may be assigned several times (one row per metric), but the
        # values are the same for every row of a given Database B log.
        entry['response_b'] = row['response']
        entry['source_b'] = row['source']
        add_metric(entry['metrics_b'], row)
    return list(logs_by_id.values())
for _ in data.keys()]) diff --git a/db/evaluation_results_a.db b/db/evaluation_results_a.db new file mode 100644 index 0000000..2066b9a Binary files /dev/null and b/db/evaluation_results_a.db differ diff --git a/db/evaluation_results_b.db b/db/evaluation_results_b.db new file mode 100644 index 0000000..db3932a Binary files /dev/null and b/db/evaluation_results_b.db differ diff --git a/static/logs_comparison.html b/static/logs_comparison.html new file mode 100644 index 0000000..951e145 --- /dev/null +++ b/static/logs_comparison.html @@ -0,0 +1,84 @@ + + + + + + + Q&A Logs + + + + + + + + + + +
+

Comparison Logs

+ +
+
+
+
+
Database A
+
+ +
+
+
+
+
+
+
Database B
+
+ +
+
+
+
+ +
+
+ + + + + + + + + + + + + + + + + +
User MessageBot Message ABot Message BReferenceMetric AMetric BSource ASource B
+ +
+ + + +
+ + +
+ + + + + + + + + \ No newline at end of file diff --git a/static/logs_comparison.js b/static/logs_comparison.js new file mode 100644 index 0000000..4007bb2 --- /dev/null +++ b/static/logs_comparison.js @@ -0,0 +1,126 @@ +let currentPage = 1; + +function loadLogs(direction) { + if (direction === 'next') { + currentPage += 1; + } else if (direction === 'prev' && currentPage > 1) { + currentPage -= 1; + } + $('#qa-table tr:not(:first)').remove(); // Remove all rows except headers + const logsComparisonUrl = `/api/logs_comparison?page=${currentPage}&database_a=${$('#database-a').val()}&database_b=${$('#database-b').val()}`; + $.get(logsComparisonUrl, function(response) { + if (response.success === false) { + const errors = Object.keys(response.errors); + errors.forEach(function(error) { + $(`*[data-error-for~="${error}"]`).addClass('is-invalid'); + $(`*[data-error-message-for="${error}"]`).html(`

${response.errors[error]}

`); + }); + return; + } + response.logs.forEach(log => { + // Construct the rows of the metrics table. `log` has the fields + // `metrics_a` and `metrics_b`, which are JSON objects with the + // metric names as keys and their values as + // {'metric_value': value, 'explanation': explanation}. We iterate + // over this object and construct the rows of the table. + function constructMetricRows(metrics) { + return Object.entries(metrics).map(([metricName, metricData]) => { + if (metricData.explanation !== null) { + const title = escapeHTML(metricData.explanation); + return ` + ${metricName} + + + + + ${round(metricData.metric_value, 4)} + `; + } else { + return `${metricName}${round(metricData.metric_value, 4)}`; + } + }); + } + const metricRowsA = constructMetricRows(log.metrics_a); + const metricRowsB = constructMetricRows(log.metrics_b); + + // Construct the metrics table + function constructMetricsTable(metricRows) { + return ` + + + + + + + + ${metricRows.join('')} + +
MetricValue
`; + } + + $('#qa-table').append( + ` + ${log.request_a} + ${log.response_a} + ${log.response_b} + ${log.reference_a == null ? '' : log.reference_a} + + ${constructMetricsTable(metricRowsA)} + + + ${constructMetricsTable(metricRowsB)} + + +
+ ${log.source_a.substring(0, 300)}... + Show +
+
${log.source_a}
+ + +
+ ${log.source_b.substring(0, 300)}... + Show +
+
${log.source_b}
+ + ` + ); + }); + $('#pageIndicator').text(currentPage); + feather.replace(); + $('[data-toggle="tooltip"]').tooltip(); + }, 'json'); +} + +$(document).ready(function() { + loadLogs(); // Load initial logs on page load + $('#prevButton').click(function() { loadLogs('prev'); }); + $('#nextButton').click(function() { loadLogs('next'); }); + $('#database-names-form').submit(function(e) { + e.preventDefault(); // Prevent the form from being submitted normally + // Clear the error messages + $('*[data-error-for]').removeClass('is-invalid'); + $('*[data-error-message-for]').empty(); + loadLogs(); + }); +}); + +$('body').on('click', '.show-source', showSource); +function showSource(e) { + var link = $(e.currentTarget); + var input_preview = link.prev(); + var source = link.parent().next(); + + if (link.text() === 'Show ') { + input_preview.css('visibility', 'hidden'); + input_preview.css('height', '0px'); + source.show(); + link.html('Hide ' + feather.icons['minimize-2'].toSvg()); + } else { + input_preview.css('visibility', 'visible'); + source.hide(); + link.html('Show ' + feather.icons['maximize-2'].toSvg()); + } + e.preventDefault(); +} \ No newline at end of file diff --git a/static/style.css b/static/style.css index 97c487a..80dffaf 100644 --- a/static/style.css +++ b/static/style.css @@ -140,4 +140,9 @@ summary { font-size: 1.15rem; font-weight: bold; color: #495057; +} + +.database-name-error-message { + padding-left: 125px; /* Hack to align the error message with the input box */ + box-sizing: border-box; } \ No newline at end of file