Skip to content

Commit 55b4076

Browse files
committed
updating linking order
1 parent 9e29c69 commit 55b4076

File tree

7 files changed

+139
-32
lines changed

7 files changed

+139
-32
lines changed

.github/workflows/test.yml

+16-10
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,16 @@ on:
1111
schedule:
1212
- cron: "0 4 * * *"
1313

14+
env:
15+
ENV_FILE: "env.yml"
16+
1417
jobs:
1518
test:
1619
strategy:
1720
fail-fast: false
1821
matrix:
19-
os: [windows-latest] # ubuntu-latest, macos-latest,
20-
python-version: ["3.10", "3.11"] # -> Will re-enable support for py312 once pyg is released
22+
os: [windows-latest] # ubuntu-latest, macos-latest
23+
python-version: ["3.11"] # -> Will re-enable support for py312 once pyg is released, "3.10",
2124

2225
runs-on: ${{ matrix.os }}
2326
timeout-minutes: 30
@@ -27,29 +30,32 @@ jobs:
2730
shell: bash -l {0}
2831

2932
name: |
30-
regular_env -
31-
python=${{ matrix.python-version }} -
32-
os=${{ matrix.os }}
33+
regular_env -
34+
python=${{ matrix.python-version }} -
35+
os=${{ matrix.os }}
3336
3437
steps:
3538
- name: Checkout the code
3639
uses: actions/checkout@v3
40+
41+
- name: Set OS-specific environment file to use
42+
run: |
43+
if [ "${{ matrix.os }}" == "windows-latest" ]; then
44+
echo "ENV_FILE=env_windows.yml" >> $GITHUB_ENV;
45+
fi
3746
3847
- name: Setup mamba
3948
uses: mamba-org/setup-micromamba@v1
4049
with:
41-
environment-file: env.yml
50+
environment-file: ${{ env.ENV_FILE }}
4251
environment-name: graphium
4352
cache-environment: true
4453
cache-downloads: true
4554
create-args: >-
4655
python=${{ matrix.python-version }}
47-
condarc: |
48-
- conda-forge
49-
- pytorch
5056
5157
- name: Install library
52-
run: python -m pip install --no-deps --no-build-isolation -e . # `-e` required for correct `coverage` run.
58+
run: python -m pip install --no-deps --no-build-isolation -e . -v # `-e` required for correct `coverage` run.
5359

5460
- name: Run tests
5561
run: pytest

env.yml

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
channels:
22
- conda-forge
3-
- pytorch
4-
# - pyg # Add for Windows
53

64
dependencies:
75
- python >3.9,<3.12
@@ -35,7 +33,7 @@ dependencies:
3533
- lightning >=2.0
3634
- torchmetrics
3735
- ogb
38-
- pytorch_geometric >=2.0 # Use `pyg` for Windows instead of `pytorch_geometric`
36+
- pytorch_geometric >=2.0
3937
- wandb >=0.18.5
4038
- mup
4139
- pytorch_sparse >=0.6

env_windows.yml

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
channels:
2+
- conda-forge
3+
- pytorch
4+
- pyg
5+
6+
dependencies:
7+
- python >3.9,<3.12
8+
- pip
9+
- typer
10+
- loguru
11+
- omegaconf >=2.0.0
12+
- hydra-core >=1.3.2
13+
- tqdm
14+
- platformdirs
15+
16+
# scientific
17+
- numpy <=1.25.0
18+
- scipy >=1.4
19+
- pandas >=1.0
20+
- scikit-learn
21+
- fastparquet
22+
23+
# viz
24+
- matplotlib >=3.0.1
25+
- seaborn
26+
27+
# cloud IO
28+
- fsspec >=2021.6
29+
- s3fs >=2021.6
30+
- gcsfs >=2021.6
31+
32+
# ML packages
33+
- cuda-version # works also with CPU-only system.
34+
- pytorch >=1.12,<2.5
35+
- lightning >=2.0
36+
- torchmetrics
37+
- ogb
38+
- pyg >=2.0
39+
- wandb >=0.18.5
40+
- mup
41+
- pytorch-sparse >=0.6
42+
- pytorch-cluster >=1.5
43+
- pytorch-scatter >=2.0
44+
45+
# chemistry
46+
- rdkit <=2024.03.4
47+
- datamol >=0.10
48+
- boost # needed by rdkit
49+
50+
# Optional deps
51+
- sympy
52+
- tensorboard
53+
- pydantic <2 # because of lightning. See https://github.com/Lightning-AI/lightning/issues/18026 and https://github.com/Lightning-AI/lightning/pull/18022
54+
55+
# Dev
56+
- pytest >=6.0
57+
- pytest-xdist
58+
- pytest-cov
59+
- pytest-forked
60+
- nbconvert
61+
- black >=23
62+
- jupyterlab
63+
- ipywidgets
64+
65+
# Doc
66+
- mkdocs
67+
- mkdocs-material
68+
- mkdocs-material-extensions
69+
- mkdocstrings
70+
- mkdocstrings-python
71+
- mkdocs-jupyter
72+
- markdown-include
73+
- mike >=1.0.0
74+
- doxygen
75+
76+
# graphium_cpp build dependencies
77+
- pybind11
78+
- librdkit-dev
79+
80+
# Optional
81+
- pytdc
82+
83+
- pip:
84+
# Build deps
85+
- setuptools-scm
86+
- build

graphium/graphium_cpp/features.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,8 @@ template<> struct FeatureValues<double> {
195195
static_assert(std::is_floating_point_v<T>);
196196
return double(inputType);
197197
}
198-
199-
static constexpr bool is_finite(double v) {
198+
199+
static bool is_finite(double v) {
200200
return std::isfinite(v);
201201
}
202202

graphium/graphium_cpp/labels.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@
77

88

99
#include "labels.h"
10-
1110
#include "features.h"
1211

1312
// C++ standard library headers
1413
#include <filesystem>
1514
#include <thread>
1615
#include <unordered_map>
16+
#define NOMINMAX // Disabling MVSC-defined min/max macros
17+
#include <algorithm>
1718

1819
// RDKit headers
1920
#include <GraphMol/ROMol.h>

setup.py

+26-14
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,33 @@
1919

2020
python_version = str(sys.version_info[0]) + str(sys.version_info[1])
2121

22-
torch_dir = torch.__path__[0]
23-
rdkit_lib_index = rdkit.__path__[0].split("/").index("lib")
24-
rdkit_prefix = "/".join(rdkit.__path__[0].split("/")[:rdkit_lib_index])
22+
# Base variables required for compilation
23+
path_separator = "/"
24+
lib_folder_name = "lib"
25+
boost_include = "include/boost"
26+
package_compile_args = []
2527

28+
# Updating variables used during compilation based on OS
2629
system = platform.system()
27-
package_compile_args = [
28-
"-O3",
29-
"-Wall",
30-
"-Wmissing-field-initializers",
31-
"-Wmaybe-uninitialized",
32-
"-Wuninitialized",
33-
]
30+
if system == "Darwin" or system == "Linux":
31+
package_compile_args += ["-O3", "-Wall", "-Wmaybe-uninitialized", "-Wuninitialized"]
32+
3433
if system == "Darwin":
35-
package_compile_args.append("-mmacosx-version-min=10.15")
34+
package_compile_args += ["-mmacosx-version-min=10.15"]
3635
elif system == "Windows":
37-
pass
36+
path_separator = "\\"
37+
lib_folder_name = "Lib"
38+
package_compile_args += ["/Wall", "/O3"]
39+
40+
# Extracting paths to torch and rdkit dependencies
41+
torch_dir = torch.__path__[0]
42+
rdkit_lib_index = rdkit.__path__[0].split(path_separator).index(lib_folder_name) # breaks on windows
43+
rdkit_prefix = "/".join(rdkit.__path__[0].split(path_separator)[:rdkit_lib_index])
44+
45+
# Windows-specific changed to rdkit path
46+
if system == "Windows":
47+
rdkit_prefix += "/Library"
48+
boost_include = "include"
3849

3950
ext_modules = [
4051
Pybind11Extension(
@@ -57,10 +68,11 @@
5768
os.path.join(torch_dir, "include"),
5869
os.path.join(torch_dir, "include/torch/csrc/api/include"),
5970
os.path.join(rdkit_prefix, "include/rdkit"),
60-
os.path.join(rdkit_prefix, "include/boost"),
71+
os.path.join(rdkit_prefix, boost_include),
6172
numpy.get_include(),
6273
],
6374
libraries=[
75+
"RDKitRDGeneral",
6476
"RDKitAlignment",
6577
"RDKitDataStructs",
6678
"RDKitDistGeometry",
@@ -73,13 +85,13 @@
7385
"RDKitInchi",
7486
"RDKitRDInchiLib",
7587
"RDKitRDBoost",
76-
"RDKitRDGeneral",
7788
"RDKitRDGeometryLib",
7889
"RDKitRingDecomposerLib",
7990
"RDKitSmilesParse",
8091
"RDKitSubstructMatch",
8192
"torch_cpu",
8293
"torch_python",
94+
"c10",
8395
f"boost_python{python_version}",
8496
],
8597
library_dirs=[os.path.join(rdkit_prefix, "lib"), os.path.join(torch_dir, "lib")],

tests/test_datamodule.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ def test_ogb_datamodule(self):
8080
# test module
8181
assert ds.num_edge_feats == 5
8282
assert ds.num_node_feats == 50
83-
assert len(ds) == 642
83+
assert (
84+
len(ds) == 642 or len(ds) == 644
85+
) # Accounting for differences in csv file reads across Linux & OSX
8486

8587
# test batch loader
8688
batch = next(iter(ds.train_dataloader()))
@@ -183,7 +185,9 @@ def test_caching(self):
183185
# test module
184186
assert ds.num_edge_feats == 5
185187
assert ds.num_node_feats == 50
186-
assert len(ds) == 642
188+
assert (
189+
len(ds) == 642 or len(ds) == 644
190+
) # Accounting for differences in csv file reads across Linux & OSX
187191

188192
# test batch loader
189193
batch = next(iter(ds.train_dataloader()))

0 commit comments

Comments
 (0)