class GenomeDataset(Dataset):
    """Dataset of fixed-length token blocks built from the records of a FASTA file.

    Each record's sequence is upper-cased, split into space-separated
    3-character chunks (the final chunk may be shorter), wrapped in
    ``[START]`` / ``[END]`` markers and tokenized into blocks of at most
    ``block_size`` tokens. The last block of every record is right-padded
    with the ``[PAD]`` token id up to ``block_size``.
    """

    def __init__(
        self, fasta_file: str, block_size: int, tokenizer: PreTrainedTokenizerFast
    ):
        """Parse ``fasta_file`` and tokenize every record up front.

        Args:
            fasta_file: Path to a FASTA file, parsed with ``SeqIO.parse(..., "fasta")``.
            block_size: Number of tokens per block.
            tokenizer: Tokenizer used to encode each record; its vocabulary must
                contain ``[PAD]`` (looked up when a record's last block is short).
        """
        self.tokenized_sequences = []
        # Consume the parser lazily rather than materializing every record:
        # only one record's sequence needs to be held in memory at a time.
        for record in SeqIO.parse(fasta_file, "fasta"):
            self.tokenized_sequences.extend(
                self.create_token_set_from_record(
                    record, tokenizer=tokenizer, block_size=block_size
                )
            )

    @staticmethod
    def _to_codon_text(sequence):
        """Return ``sequence`` as space-separated 3-char chunks wrapped in [START]/[END]."""
        codons = " ".join(sequence[i : i + 3] for i in range(0, len(sequence), 3))
        return "[START] " + codons + " [END]"

    def create_token_set_from_record(self, s, tokenizer, block_size=512):
        """Tokenize one FASTA record into a list of token-id blocks.

        Args:
            s: A sequence record exposing ``.seq`` with an ``.upper()`` method.
            tokenizer: Tokenizer whose ``encode`` is called with
                ``max_length=block_size`` and ``return_overflowing_tokens=True``.
            block_size: Maximum block length; the last block is padded to it.

        Returns:
            A list of token-id lists. The last block is right-padded with the
            ``[PAD]`` id to ``block_size``; earlier blocks are returned as the
            tokenizer produced them (presumably already ``block_size`` long
            because they overflowed — verify for the tokenizer in use).
        """
        text = self._to_codon_text(str(s.seq.upper()))
        # NOTE(review): assumes a fast tokenizer, which returns one list of ids
        # per overflow window when return_overflowing_tokens=True — confirm.
        blocks = tokenizer.encode(
            text, max_length=block_size, return_overflowing_tokens=True
        )
        if blocks and len(blocks[-1]) != block_size:
            last = list(blocks[-1])
            pad_id = tokenizer.vocab["[PAD]"]
            # Plain-list padding keeps every element a Python int; np.pad
            # would turn only this block into numpy scalars.
            last.extend([pad_id] * (block_size - len(last)))
            # Build a new list instead of mutating the tokenizer's output.
            blocks = [*blocks[:-1], last]
        return blocks

    def __len__(self) -> int:
        """Return the total number of token blocks across all records."""
        return len(self.tokenized_sequences)

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return block ``idx`` as an int64 tensor."""
        return torch.tensor(self.tokenized_sequences[idx], dtype=torch.long)  # type:ignore[no-any-return]