-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4 from sgrvinod/0.2.0
0.2.0
- Loading branch information
Showing
98 changed files
with
219,289 additions
and
145,195 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,7 +2,7 @@ | |
*.pyc | ||
*.egg-info | ||
|
||
logs | ||
.vscode | ||
__pycache__ | ||
chess_transformers/checkpoints | ||
checkpoints | ||
logs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# Change Log | ||
|
||
## Unreleased (v0.2.0) | ||
|
||
### Added | ||
|
||
* **`ChessTransformerEncoderFT`** is an encoder-only transformer that predicts source (*From*) and destination squares (*To*) squares for the next half-move, instead of the half-move in UCI notation. | ||
* [*CT-EFT-20*](https://github.com/sgrvinod/chess-transformers#ct-eft-20) is a new trained model of this type with about 20 million parameters. | ||
* **`ChessDatasetFT`** is a PyTorch dataset class for this model type. | ||
* [**`chess_transformer.data.levels`**](https://github.com/sgrvinod/chess-transformers/blob/main/chess_transformers/data/levels.py) provides a standardized vocabulary (with indices) for oft-used categorical variables. All models and datasets will hereon use this standard vocabulary instead of a dataset-specific vocabulary. | ||
|
||
### Changed | ||
|
||
* The [*LE1222*](https://github.com/sgrvinod/chess-transformers#le1222) and [*LE1222x*](https://github.com/sgrvinod/chess-transformers#le1222x) datasets no longer have their own vocabularies or vocabulary files. Instead, they use the standard vocabulary from **`chess_transformer.data.levels`**. | ||
* The [*LE1222*](https://github.com/sgrvinod/chess-transformers#le1222) and [*LE1222x*](https://github.com/sgrvinod/chess-transformers#le1222x) datasets have been re-encoded with indices corresponding to the standard vocabulary. Earlier versions or downloads of these datasets are no longer valid for use with this library. | ||
* The row index at which the validation split begins in each dataset is now stored as an attribute of the **`encoded_data`** table in the corresponding H5 file, instead of in a separate JSON file. | ||
* Models [*CT-E-20*](https://github.com/sgrvinod/chess-transformers#ct-e-20) and [*CT-ED-45*](https://github.com/sgrvinod/chess-transformers#ct-ed-45) already trained with a non-standard, dataset-specific vocabulary have been refactored for use with the standard vocabulary. Earlier versions or downloads of these models are no longer valid for use with this library. | ||
* The field **`move_sequence`** in the H5 tables has now been renamed to **`moves`**. | ||
* The field **`move_sequence_length`** in the H5 tables has now been renamed to **`length`**. | ||
* The **`load_assets()`** function has been renamed to **`load_model()`** and it no longer returns a vocabulary — only the model. | ||
* The **`chess_transformers/eval`** folder has been renamed to [**`chess_transformers/evaluate`**](https://github.com/sgrvinod/chess-transformers/tree/main/chess_transformers/evaluate). | ||
* The Python notebook **`lichess_eval.ipynb`** has been converted to a Python script [**`evaluate.py`**](https://github.com/sgrvinod/chess-transformers/blob/main/chess_transformers/evaluate/evaluation.py), which runs much faster for evaluation. | ||
* Fairy Stockfish is now run on 8 threads and with a hash table of size 8 GB during evaluation instead of 1 thread and 16 MB respectively, which makes it a more challenging opponent. | ||
* Evaluation results have been recomputed for [*CT-E-20*](https://github.com/sgrvinod/chess-transformers#ct-e-20) and [*CT-ED-45*](https://github.com/sgrvinod/chess-transformers#ct-ed-45) against this stronger Fairy Stockfish — they naturally fare worse. | ||
|
||
### Removed | ||
|
||
* The environment variable **`CT_LOGS_FOLDER`** no longer needs to be set before training a model. Training logs will now always be saved to **`chess_transformers/training/logs`**. | ||
* The environment variable **`CT_CHECKPOINT_FOLDER`** no longer needs to be set before training a model. Checkpoints will now always be saved to **`chess_transformers/checkpoints`**. | ||
* The environment variable **`CT_EVAL_GAMES_FOLDER`** no longer needs to be set before evaluating a model. Evaluation games will now always be saved to **`chess_transformers/evaluate/games`**. | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,6 @@ | |
"configs", | ||
"data", | ||
"train", | ||
"eval", | ||
"evaluate", | ||
"play", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
import torch | ||
import pathlib | ||
|
||
from chess_transformers.train.utils import get_lr | ||
from chess_transformers.configs.data.LE1222 import * | ||
from chess_transformers.configs.other.stockfish import * | ||
from chess_transformers.train.datasets import ChessDatasetFT | ||
from chess_transformers.configs.other.fairy_stockfish import * | ||
from chess_transformers.transformers.criteria import LabelSmoothedCE | ||
from chess_transformers.data.levels import TURN, PIECES, UCI_MOVES, BOOL | ||
from chess_transformers.transformers.models import ChessTransformerEncoderFT | ||
|
||
|
||
############################### | ||
############ Name ############# | ||
############################### | ||
|
||
NAME = "CT-EFT-20" # name and identifier for this configuration | ||
|
||
############################### | ||
######### Dataloading ######### | ||
############################### | ||
|
||
DATASET = ChessDatasetFT # custom PyTorch dataset | ||
BATCH_SIZE = 512 # batch size | ||
NUM_WORKERS = 8 # number of workers to use for dataloading | ||
PREFETCH_FACTOR = 2 # number of batches to prefetch per worker | ||
PIN_MEMORY = False # pin to GPU memory when dataloading? | ||
|
||
############################### | ||
############ Model ############ | ||
############################### | ||
|
||
VOCAB_SIZES = { | ||
"moves": len(UCI_MOVES), | ||
"turn": len(TURN), | ||
"white_kingside_castling_rights": len(BOOL), | ||
"white_queenside_castling_rights": len(BOOL), | ||
"black_kingside_castling_rights": len(BOOL), | ||
"black_queenside_castling_rights": len(BOOL), | ||
"board_position": len(PIECES), | ||
} # vocabulary sizes | ||
D_MODEL = 512 # size of vectors throughout the transformer model | ||
N_HEADS = 8 # number of heads in the multi-head attention | ||
D_QUERIES = 64 # size of query vectors (and also the size of the key vectors) in the multi-head attention | ||
D_VALUES = 64 # size of value vectors in the multi-head attention | ||
D_INNER = 2048 # an intermediate size in the position-wise FC | ||
N_LAYERS = 6 # number of layers in the Encoder and Decoder | ||
DROPOUT = 0.1 # dropout probability | ||
N_MOVES = 1 # expected maximum length of move sequences in the model, <= MAX_MOVE_SEQUENCE_LENGTH | ||
DISABLE_COMPILATION = False # disable model compilation? | ||
COMPILATION_MODE = "default" # mode of model compilation (see torch.compile()) | ||
DYNAMIC_COMPILATION = True # expect tensors with dynamic shapes? | ||
SAMPLING_K = 1 # k in top-k sampling model predictions during play | ||
MODEL = ChessTransformerEncoderFT # custom PyTorch model to train | ||
|
||
############################### | ||
########### Training ########## | ||
############################### | ||
|
||
BATCHES_PER_STEP = ( | ||
4 # perform a training step, i.e. update parameters, once every so many batches | ||
) | ||
PRINT_FREQUENCY = 1 # print status once every so many steps | ||
N_STEPS = 100000 # number of training steps | ||
WARMUP_STEPS = 8000 # number of warmup steps where learning rate is increased linearly; twice the value in the paper, as in the official transformer repo. | ||
STEP = 1 # the step number, start from 1 to prevent math error in the next line | ||
LR = get_lr( | ||
step=STEP, d_model=D_MODEL, warmup_steps=WARMUP_STEPS | ||
) # see utils.py for learning rate schedule; twice the schedule in the paper, as in the official transformer repo. | ||
START_EPOCH = 0 # start at this epoch | ||
BETAS = (0.9, 0.98) # beta coefficients in the Adam optimizer | ||
EPSILON = 1e-9 # epsilon term in the Adam optimizer | ||
LABEL_SMOOTHING = 0.1 # label smoothing co-efficient in the Cross Entropy loss | ||
BOARD_STATUS_LENGTH = 70 # total length of input sequence | ||
USE_AMP = True # use automatic mixed precision training? | ||
CRITERION = LabelSmoothedCE # training criterion (loss) | ||
OPTIMIZER = torch.optim.Adam # optimizer | ||
LOGS_FOLDER = str( | ||
pathlib.Path(__file__).parent.parent.parent.resolve() / "train" / "logs" / NAME | ||
) # logs folder | ||
|
||
############################### | ||
######### Checkpoints ######### | ||
############################### | ||
|
||
CHECKPOINT_FOLDER = str( | ||
pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints" / NAME | ||
) # folder containing checkpoints | ||
TRAINING_CHECKPOINT = ( | ||
NAME + ".pt" | ||
) # path to model checkpoint to resume training, None if none | ||
CHECKPOINT_AVG_PREFIX = ( | ||
"step" # prefix to add to checkpoint name when saving checkpoints for averaging | ||
) | ||
CHECKPOINT_AVG_SUFFIX = ( | ||
".pt" # checkpoint end string to match checkpoints saved for averaging | ||
) | ||
FINAL_CHECKPOINT = ( | ||
"averaged_" + NAME + ".pt" | ||
) # final checkpoint to be used for eval/inference | ||
FINAL_CHECKPOINT_GDID = ( | ||
"1OHtg336ujlOjp5Kp0KjE1fAPF74aZpZD" # Google Drive ID for download | ||
) | ||
|
||
|
||
################################ | ||
########## Evaluation ########## | ||
################################ | ||
|
||
EVAL_GAMES_FOLDER = str( | ||
pathlib.Path(__file__).parent.parent.parent.resolve() / "eval" / "games" / NAME | ||
) # folder where evaluation games are saved in PGN files |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
__all__ = ["CT-E-19", "CT-ED-45"] | ||
__all__ = ["CT-E-19", "CT-ED-45", "CT-EFT-20"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
__all__ = ["prep", "utils"] | ||
__all__ = ["prep", "utils", "levels"] | ||
|
||
from chess_transformers.data.prep import prepare_data |
Oops, something went wrong.