diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1800114 --- /dev/null +++ b/.gitignore @@ -0,0 +1,174 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. 
+#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc \ No newline at end of file diff --git a/README.md b/README.md index 4fcb44e..ef86686 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,11 @@ from sccello.src.model_prototype_contrastive import PrototypeContrastiveForMaske model = PrototypeContrastiveForMaskedLM.from_pretrained("katarinayuan/scCello-zeroshot", output_hidden_states=True) ``` +or run the following to verify that the model loads properly +``` +python ./sccello/script/run_load_model.py +``` + * for linear probing tasks (see details in sccello/script/run_cell_type_classification.py) ``` from sccello.src.model_prototype_contrastive import PrototypeContrastiveForSequenceClassification diff --git a/sccello/script/run_cell_type_clustering.py b/sccello/script/run_cell_type_clustering.py index 8eba8ea..94c8847 100644 --- a/sccello/script/run_cell_type_clustering.py +++ b/sccello/script/run_cell_type_clustering.py @@ -36,6 +36,8 @@ def parse_args(): parser.add_argument("--model_source", type=str, default="model_prototype_contrastive") parser.add_argument("--indist", type=int, default=0) + parser.add_argument('--sample_outdist', action='store_true', help='Samples the out of distribution') + parser.add_argument('--num_samples', type=int, default=10, help='If sample_outdist is set, this sets the number of samples') parser.add_argument("--normalize", type=int, default=0) parser.add_argument("--pass_cell_cls", type=int, default=0) @@ -59,7 +61,7 @@ def parse_args(): return args def solve_clustering(args, all_datasets): - trainset, test_data1, test_data2, label_dict = all_datasets + _, test_data1, test_data2, _ = all_datasets args.output_dir = helpers.create_downstream_output_dir(args) @@ -119,10 +121,17 @@ def solve_clustering(args, all_datasets): names = CellTypeClassificationDataset.subsets["frac"] if args.indist: names = [names[0]] + + if args.sample_outdist and not args.indist: + names = [names[1]] for name in names: # every data is tested under the same seeded 
setting helpers.set_seed(args.seed) args.data_source = f"frac_{name}" - all_datasets = data_loading.get_fracdata(name, args.data_branch, args.indist, False) - solve_clustering(args, all_datasets) \ No newline at end of file + all_datasets = ( + data_loading.get_fracdata_sample(name, num_samples=args.num_samples) + if args.sample_outdist else + data_loading.get_fracdata(name, args.data_branch, args.indist, False) + ) + solve_clustering(args, all_datasets) diff --git a/sccello/script/run_load_model.py b/sccello/script/run_load_model.py new file mode 100644 index 0000000..c6a2bce --- /dev/null +++ b/sccello/script/run_load_model.py @@ -0,0 +1,8 @@ +import os +import sys +EXC_DIR = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) +sys.path.append(EXC_DIR) + +from sccello.src.model_prototype_contrastive import PrototypeContrastiveForMaskedLM + +model = PrototypeContrastiveForMaskedLM.from_pretrained("katarinayuan/scCello-zeroshot", output_hidden_states=True) diff --git a/sccello/script/run_novel_cell_type_classification.py b/sccello/script/run_novel_cell_type_classification.py index c360ed5..87ea84d 100644 --- a/sccello/script/run_novel_cell_type_classification.py +++ b/sccello/script/run_novel_cell_type_classification.py @@ -65,7 +65,13 @@ def get_cell_type_labelid2nodeid(cell_type_idmap, clid2nodeid): # e.g., {'placental pericyte': CL:2000078, ...} name2clid = {v.lower(): k for k, v in clid2name.items()} - cell_type2nodeid = dict([(k, clid2nodeid[name2clid[cell_type2name[k]]]) if cell_type2name[k] in name2clid else (k, -1) for k in cell_type2name]) + cell_type2nodeid = { + k: ( + clid2nodeid[name2clid[cell_type2name[k]]] + if cell_type2name[k] in name2clid else -1 + ) + for k in cell_type2name + } return cell_type2nodeid def load_cell_type_representation(args, model): diff --git a/sccello/src/data/dataset.py b/sccello/src/data/dataset.py index 4089dee..32b91d0 100644 --- a/sccello/src/data/dataset.py +++ b/sccello/src/data/dataset.py @@ -32,8 
+32,8 @@ class CellTypeClassificationDataset(): @classmethod def create_dataset(cls, subset_name="celltype"): assert subset_name in cls.subsets["frac"] - valid_data = load_dataset(f"katarinayuan/scCello_ood_{subset_name}_data1")["train"] - test_data = load_dataset(f"katarinayuan/scCello_ood_{subset_name}_data2")["train"] + valid_data = load_dataset(f"katarinayuan/scCello_ood_{subset_name}_data1", split="train") + test_data = load_dataset(f"katarinayuan/scCello_ood_{subset_name}_data2", split="train") valid_data = valid_data.rename_column("cell_type", "label") test_data = test_data.rename_column("cell_type", "label") diff --git a/sccello/src/utils/data_loading.py b/sccello/src/utils/data_loading.py index a9a9709..e0913df 100644 --- a/sccello/src/utils/data_loading.py +++ b/sccello/src/utils/data_loading.py @@ -112,6 +112,20 @@ def get_prestored_data(data_file_name): else: raise NotImplementedError +def get_fracdata_sample(name, num_proc=12, num_samples=10): + + from sccello.src.data.dataset import CellTypeClassificationDataset + data1, data2 = CellTypeClassificationDataset.create_dataset(name) + data1 = data1.rename_column("gene_token_ids", "input_ids") + data2 = data2.rename_column("gene_token_ids", "input_ids") + + data1, data2 = data1.select(range(num_samples)), data2.select(range(num_samples)) + + data1, _ = helpers.process_label_type(data1, num_proc, "label") + data2, _ = helpers.process_label_type(data2, num_proc, "label") + + return None, data1, data2, None + def get_fracdata(name, data_branch, indist, batch_effect, num_proc=12): from sccello.src.data.dataset import CellTypeClassificationDataset diff --git a/sccello/src/utils/helpers.py b/sccello/src/utils/helpers.py index 539f784..54aaa68 100644 --- a/sccello/src/utils/helpers.py +++ b/sccello/src/utils/helpers.py @@ -11,7 +11,6 @@ EXC_DIR = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) from sccello.src.utils import logging_util -from 
sccello.src.model_prototype_contrastive import PrototypeContrastiveForMaskedLM def set_seed(seed): @@ -77,6 +76,8 @@ def create_downstream_output_dir(args): def load_model_inference(args): + from sccello.src.model_prototype_contrastive import PrototypeContrastiveForMaskedLM + model = eval(args.model_class).from_pretrained(args.pretrained_ckpt, output_hidden_states=True).to("cuda") for param in model.bert.parameters(): param.requires_grad = False