diff --git a/brainscore_core/submission/endpoints.py b/brainscore_core/submission/endpoints.py index 440c228b..676cdedd 100644 --- a/brainscore_core/submission/endpoints.py +++ b/brainscore_core/submission/endpoints.py @@ -4,6 +4,7 @@ import logging import os import json +import pickle import random import smtplib import string @@ -204,9 +205,37 @@ class RunScoringEndpoint: def __init__(self, domain_plugins: DomainPlugins, db_secret: str): self.domain_plugins = domain_plugins + self._db_secret = db_secret logger.info(f"Connecting to db using secret '{db_secret}'") connect_db(db_secret=db_secret) + def _is_production(self) -> bool: + return 'sqlite3' not in self._db_secret + + def _upload_score_to_s3(self, score_result: Score, model_identifier: str, + benchmark_identifier: str, domain: str) -> None: + if not self._is_production(): + return + + s3_bucket = 'brainscore-storage' + s3_key = f'brainscore-{domain}/score-obj/{model_identifier}__{benchmark_identifier}.pkl' + + try: + import boto3 + import tempfile + + with tempfile.NamedTemporaryFile(suffix='.pkl', delete=True) as tmp: + pickle.dump(score_result, tmp, protocol=pickle.HIGHEST_PROTOCOL) + tmp.flush() + tmp.seek(0) + s3_client = boto3.client('s3') + s3_client.upload_file(tmp.name, s3_bucket, s3_key) + + logger.info(f'Uploaded score object to s3://{s3_bucket}/{s3_key}') + except Exception as e: + logger.warning(f'Failed to upload score object to S3 for ' + f'{model_identifier} on {benchmark_identifier}: {e}') + def __call__(self, domain: str, jenkins_id: int, model_identifier: str, benchmark_identifier: str, user_id: int, model_type: str, public: bool, competition: Union[None, str]): """ @@ -278,6 +307,7 @@ def _score_model_on_benchmark(self, model_identifier: str, benchmark_identifier: # store in database logger.info(f'Score from running {model_identifier} on {benchmark_identifier}: {score_result}') update_score(score_result, score_entry) + self._upload_score_to_s3(score_result, model_identifier, benchmark_identifier, domain) except Exception as e: stacktrace = traceback.format_exc() error_message = f'Model {model_identifier} could not run on benchmark {benchmark_identifier}: ' \