diff --git a/R/nn.R b/R/nn.R index 7b03918aee..080360f154 100644 --- a/R/nn.R +++ b/R/nn.R @@ -1,6 +1,16 @@ #' @include utils-data.R NULL +default_parent_env = function() { + env = parent.frame(2) + pe = parent.env(env) + if (all(sapply(c(".__active__", "self", "private", "super"), function(field) exists(field, pe)))) { + get("inherit", pe)$parent_env + } else { + env + } +} + get_inherited_classes <- function(inherit) { inherit_class <- inherit$public_fields$.classes # Filter out classes that we eventually add in our normal flow. @@ -496,7 +506,7 @@ is_nn_module <- function(x) { #' @export nn_module <- function(classname = NULL, inherit = nn_Module, ..., private = NULL, active = NULL, - parent_env = parent.frame()) { + parent_env = default_parent_env()) { if (inherits(inherit, "nn_module")) { inherit <- attr(inherit, "module") } diff --git a/_pkgdown.yml b/_pkgdown.yml index 2a8c126d26..9b2dc7a175 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -7,7 +7,7 @@ template: primary: "#7e1f77" params: ganalytics: G-9ZJSKW3L0N - + development: mode: auto @@ -62,11 +62,11 @@ navbar: - text: basic-nn-module href: articles/examples/basic-nn-module.html - text: dataset - href: articles/examples/dataset.html - -reference: + href: articles/examples/dataset.html + +reference: - title: Tensor creation utilities - contents: + contents: - torch_empty - torch_arange - torch_eye @@ -136,7 +136,7 @@ reference: contents: - starts_with("lr_") - title: Datasets - contents: + contents: - starts_with("dataset") - iterable_dataset - starts_with("dataloader") @@ -179,7 +179,7 @@ reference: contents: - starts_with("cuda_") - title: JIT - contents: + contents: - starts_with("jit") - title: Backends contents: