Skip to content

Commit ae679fb

Browse files
Add template with flash & titanic (#123)
* move sample * titanic * fix parse name Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5c1c103 commit ae679fb

File tree

7 files changed

+124
-9
lines changed

7 files changed

+124
-9
lines changed

.actions/assistant.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -622,15 +622,17 @@ def update_env_details(dir_path: str):
622622
req += meta.get("requirements", [])
623623
req = [r.strip() for r in req]
624624

625-
def _parse(pkg: str, keys: str = " <=>[]") -> str:
625+
def _parse_package_name(pkg: str, keys: str = " <=>[]@", egg_name: str = "#egg=") -> str:
626626
"""Parsing just the package name."""
627+
if egg_name in pkg:
628+
pkg = pkg[pkg.index(egg_name) + len(egg_name) :]
627629
if any(c in pkg for c in keys):
628630
ix = min(pkg.index(c) for c in keys if c in pkg)
629631
pkg = pkg[:ix]
630632
return pkg
631633

632-
require = {_parse(r) for r in req if r}
633-
env = {_parse(p): p for p in freeze.freeze()}
634+
require = {_parse_package_name(r) for r in req if r}
635+
env = {_parse_package_name(p): p for p in freeze.freeze()}
634636
meta["environment"] = [env[r] for r in require]
635637
meta["published"] = datetime.now().isoformat()
636638

sample-template/.meta.yml templates/simple/.meta.yml

-6
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,3 @@ requirements:
1111
accelerator:
1212
- CPU
1313
- GPU
14-
datasets:
15-
web:
16-
# starting with http is downloaded
17-
- https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
18-
kaggle:
19-
- titanic
File renamed without changes.
File renamed without changes.
File renamed without changes.

templates/titanic/.meta.yml

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
title: Solving Titanic dataset with Lightning Flash
2+
author: PL team
3+
created: 2021-10-15
4+
updated: 2021-12-10
5+
license: CC
6+
build: 0
7+
description: |
8+
This is a template to show how to contribute a tutorial.
9+
requirements:
10+
- https://github.com/PyTorchLightning/lightning-flash/archive/refs/tags/0.5.2.zip#egg=lightning-flash[tabular]
11+
- matplotlib
12+
- seaborn
13+
accelerator:
14+
- CPU
15+
- GPU
16+
datasets:
17+
kaggle:
18+
- titanic

templates/titanic/tutorial.py

+101
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import os
2+
3+
import matplotlib.pyplot as plt
4+
import pandas as pd
5+
import seaborn as sns
6+
import torch
7+
from flash import Trainer
8+
from flash.tabular import TabularClassificationData, TabularClassifier
9+
10+
# %% [markdown]
11+
# ## 1. Create the DataModule
12+
#
13+
# ### Variable & Definition
14+
#
15+
# - survival: Survival (0 = No, 1 = Yes)
16+
# - pclass: Ticket class (1 = 1st, 2 = 2nd, 3 = 3rd)
17+
# - sex: Sex
18+
# - Age: Age in years
19+
# - sibsp: number of siblings / spouses aboard the Titanic
20+
# - parch: number of parents / children aboard the Titanic
21+
# - ticket: Ticket number
22+
# - fare: Passenger fare
23+
# - cabin: Cabin number
24+
# - embarked: Port of Embarkation
25+
26+
# %%
27+
data_path = os.environ.get("PATH_DATASETS", "_datasets")
28+
path_titanic = os.path.join(data_path, "titanic")
29+
csv_train = os.path.join(path_titanic, "train.csv")
30+
csv_test = os.path.join(path_titanic, "test.csv")
31+
32+
df_train = pd.read_csv(csv_train)
33+
df_train["Survived"].hist(bins=2)
34+
35+
# %%
36+
datamodule = TabularClassificationData.from_csv(
37+
categorical_fields=["Sex", "Embarked", "Cabin"],
38+
numerical_fields=["Fare", "Age", "Pclass", "SibSp", "Parch"],
39+
target_fields="Survived",
40+
train_file=csv_train,
41+
val_split=0.1,
42+
batch_size=8,
43+
)
44+
45+
# %% [markdown]
46+
# ## 2. Build the task
47+
48+
# %%
49+
model = TabularClassifier.from_data(
50+
datamodule,
51+
learning_rate=0.1,
52+
optimizer="Adam",
53+
n_a=8,
54+
gamma=0.3,
55+
)
56+
57+
# %% [markdown]
58+
# ## 3. Create the trainer and train the model
59+
60+
# %%
61+
from pytorch_lightning.loggers import CSVLogger # noqa: E402]
62+
63+
logger = CSVLogger(save_dir="logs/")
64+
trainer = Trainer(
65+
max_epochs=10,
66+
gpus=torch.cuda.device_count(),
67+
logger=logger,
68+
accumulate_grad_batches=12,
69+
gradient_clip_val=0.1,
70+
)
71+
72+
# %%
73+
74+
trainer.fit(model, datamodule=datamodule)
75+
76+
# %%
77+
78+
metrics = pd.read_csv(f"{trainer.logger.log_dir}/metrics.csv")
79+
metrics.set_index("step", inplace=True)
80+
del metrics["epoch"]
81+
sns.relplot(data=metrics, kind="line")
82+
plt.gca().set_ylim([0, 1.25])
83+
plt.gcf().set_size_inches(10, 5)
84+
85+
# %% [markdown]
86+
# ## 4. Generate predictions from a CSV
87+
88+
# %%
89+
df_test = pd.read_csv(csv_test)
90+
91+
predictions = model.predict(csv_test)
92+
print(predictions[0])
93+
94+
# %%
95+
import numpy as np # noqa: E402]
96+
97+
assert len(df_test) == len(predictions)
98+
99+
df_test["Survived"] = np.argmax(predictions, axis=-1)
100+
df_test.set_index("PassengerId", inplace=True)
101+
df_test["Survived"].hist(bins=5)

0 commit comments

Comments
 (0)