# Tabular classification on the Titanic dataset with Lightning Flash.
#
# The script is organised as jupytext "percent" cells: it loads the training
# CSV, builds a Flash data module and classifier, trains with a CSV logger,
# plots the logged metrics and finally predicts survival for the test CSV.
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from flash import Trainer
from flash.tabular import TabularClassificationData, TabularClassifier
from pytorch_lightning.loggers import CSVLogger

# %% [markdown]
# ## 1. Create the DataModule
#
# ### Variable & Definition
#
# - Survived: Survival (0 = No, 1 = Yes)
# - Pclass: Ticket class (1 = 1st, 2 = 2nd, 3 = 3rd)
# - Sex: Sex
# - Age: Age in years
# - SibSp: Number of siblings / spouses aboard the Titanic
# - Parch: Number of parents / children aboard the Titanic
# - Ticket: Ticket number
# - Fare: Passenger fare
# - Cabin: Cabin number
# - Embarked: Port of embarkation

# %%
# The dataset root can be overridden through the PATH_DATASETS environment
# variable; it falls back to a local "_datasets" directory.
data_path = os.environ.get("PATH_DATASETS", "_datasets")
path_titanic = os.path.join(data_path, "titanic")
csv_train = os.path.join(path_titanic, "train.csv")
csv_test = os.path.join(path_titanic, "test.csv")

# Quick look at the balance of the two target classes.
df_train = pd.read_csv(csv_train)
df_train["Survived"].hist(bins=2)

# %%
# 10% of the training file is held out for validation.
datamodule = TabularClassificationData.from_csv(
    categorical_fields=["Sex", "Embarked", "Cabin"],
    numerical_fields=["Fare", "Age", "Pclass", "SibSp", "Parch"],
    target_fields="Survived",
    train_file=csv_train,
    val_split=0.1,
    batch_size=8,
)

# %% [markdown]
# ## 2. Build the task

# %%
# NOTE(review): `n_a` and `gamma` are presumably forwarded to the underlying
# tabular backbone as hyper-parameters - confirm against the Flash docs.
model = TabularClassifier.from_data(
    datamodule,
    learning_rate=0.1,
    optimizer="Adam",
    n_a=8,
    gamma=0.3,
)

# %% [markdown]
# ## 3. Create the trainer and train the model

# %%
# Metrics are written as CSV under "logs/" so they can be plotted below.
logger = CSVLogger(save_dir="logs/")
trainer = Trainer(
    max_epochs=10,
    # Use every visible CUDA device; this evaluates to 0 on a CPU-only machine.
    gpus=torch.cuda.device_count(),
    logger=logger,
    accumulate_grad_batches=12,
    gradient_clip_val=0.1,
)

# %%

trainer.fit(model, datamodule=datamodule)

# %%

# Load the metrics logged during training, index them by step and drop the
# "epoch" column so that only the metric curves are plotted.
metrics_file = os.path.join(trainer.logger.log_dir, "metrics.csv")
metrics = pd.read_csv(metrics_file).set_index("step").drop(columns="epoch")
sns.relplot(data=metrics, kind="line")
plt.gca().set_ylim([0, 1.25])
plt.gcf().set_size_inches(10, 5)

# %% [markdown]
# ## 4. Generate predictions from a CSV

# %%
df_test = pd.read_csv(csv_test)

predictions = model.predict(csv_test)
print(predictions[0])

# %%
# Fail loudly (even under `python -O`, where asserts are stripped) if the
# model did not return exactly one prediction per test row.
if len(df_test) != len(predictions):
    raise ValueError(
        f"Expected one prediction per test row, got {len(predictions)} "
        f"predictions for {len(df_test)} rows."
    )

# assumes each prediction is a vector of per-class scores, so the argmax over
# the last axis is the predicted class index - TODO confirm output format.
df_test["Survived"] = np.argmax(predictions, axis=-1)
df_test = df_test.set_index("PassengerId")
df_test["Survived"].hist(bins=5)