Skip to content

Commit 8cc2b01

Browse files
committed
changing imports
1 parent 3226d1a commit 8cc2b01

File tree

5 files changed

+48
-13
lines changed

5 files changed

+48
-13
lines changed

gbrl/__init__.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,46 @@
88
##############################################################################
99
__version__ = "1.0.7"
1010

11-
from .ac_gbrl import (ActorCritic, GaussianActor, ContinuousCritic,
12-
DiscreteCritic, ParametricActor)
13-
from .gbt import GBRL
14-
from .gbrl_cpp import GBRL as GBRL_CPP
11+
import importlib.util
12+
import os
13+
import platform
14+
15+
_loaded_cpp_module = None
16+
17+
def load_cpp_module():
18+
global _loaded_cpp_module
19+
module_name = "gbrl_cpp"
20+
if platform.system() == "Windows":
21+
ext = ".pyd"
22+
elif platform.system() == "Darwin": # macOS
23+
ext = ".dylib"
24+
else: # Assume Linux/Unix
25+
ext = ".so"
26+
possible_paths = [
27+
os.path.join(os.path.dirname(__file__)), # Current directory
28+
os.path.join(os.path.dirname(__file__), "Release"), # Release folder
29+
]
30+
for dir_path in possible_paths:
31+
if os.path.exists(dir_path):
32+
# Scan for files that match the module name and extension
33+
for file_name in os.listdir(dir_path):
34+
if file_name.startswith(module_name) and file_name.endswith(ext):
35+
# Dynamically load the matching shared library
36+
file_path = os.path.join(dir_path, file_name)
37+
spec = importlib.util.spec_from_file_location(module_name, file_path)
38+
module = importlib.util.module_from_spec(spec)
39+
spec.loader.exec_module(module)
40+
_loaded_cpp_module = module = module
41+
return module
42+
43+
raise ImportError(f"Could not find {module_name}{ext} in any of the expected locations: {possible_paths}")
44+
45+
46+
# Load the C++ module dynamically
47+
_gbrl_cpp_module = load_cpp_module()
48+
49+
# Create a global alias for the GBRL class
50+
GBRL_CPP = _gbrl_cpp_module.GBRL
1551

1652
cuda_available = GBRL_CPP.cuda_available
1753

gbrl/ac_gbrl.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,9 @@
1111
import numpy as np
1212
import torch as th
1313

14-
from .gbrl_wrapper import (GBTWrapper, SeparateActorCriticWrapper,
15-
SharedActorCriticWrapper, )
16-
from .gbt import GBRL
17-
from .utils import (setup_optimizer, clip_grad_norm, numerical_dtype,
14+
from gbrl.gbrl_wrapper import GBTWrapper, SeparateActorCriticWrapper, SharedActorCriticWrapper,
15+
from gbrl.gbt import GBRL
16+
from gbrl.utils import (setup_optimizer, clip_grad_norm, numerical_dtype,
1817
concatenate_arrays, validate_array, constant_like,
1918
tensor_to_leaf)
2019

gbrl/gbrl_wrapper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
import numpy as np
1414
import torch as th
1515

16-
from .gbrl_cpp import GBRL as GBRL_CPP
17-
from .utils import (get_input_dim, get_poly_vectors,
16+
from gbrl import GBRL_CPP
17+
from gbrl.utils import (get_input_dim, get_poly_vectors,
1818
to_numpy,
1919
numerical_dtype,
2020
get_tensor_info,

gbrl/gbt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
import torch as th
1313

1414

15-
from .gbrl_wrapper import GBTWrapper
16-
from .utils import setup_optimizer, clip_grad_norm, validate_array
15+
from gbrl.gbrl_wrapper import GBTWrapper
16+
from gbrl.utils import setup_optimizer, clip_grad_norm, validate_array
1717

1818

1919
class GBRL:

gbrl/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import torch as th
1212
from scipy.special import binom
1313

14-
from .config import APPROVED_OPTIMIZERS, VALID_OPTIMIZER_ARGS
14+
from gbrl.config import APPROVED_OPTIMIZERS, VALID_OPTIMIZER_ARGS
1515

1616
import numpy as np
1717
# Define custom dtypes

0 commit comments

Comments
 (0)