4
4
"""
5
5
6
6
from .. import nn
7
- from .base import LayerRef , Layer
8
7
9
8
10
- def relu (x : LayerRef ) -> Layer :
9
+ def relu (x : nn . LayerRef ) -> nn . Layer :
11
10
"""ReLU"""
12
11
return _activation (x , activation = "relu" )
13
12
14
13
15
- def elu (x : LayerRef ) -> Layer :
14
+ def elu (x : nn . LayerRef ) -> nn . Layer :
16
15
"""ELU https://arxiv.org/abs/1511.07289"""
17
16
return _activation (x , activation = "elu" )
18
17
19
18
20
- def selu (x : LayerRef ) -> Layer :
19
+ def selu (x : nn . LayerRef ) -> nn . Layer :
21
20
"""SELU https://arxiv.org/abs/1706.02515"""
22
21
return _activation (x , activation = "selu" )
23
22
24
23
25
- def gelu (x : LayerRef ) -> Layer :
24
+ def gelu (x : nn . LayerRef ) -> nn . Layer :
26
25
"""GELU https://arxiv.org/abs/1606.08415"""
27
26
return _activation (x , activation = "gelu" )
28
27
29
28
30
- def exp (x : LayerRef ) -> Layer :
29
+ def exp (x : nn . LayerRef ) -> nn . Layer :
31
30
"""exp"""
32
31
return _activation (x , activation = "exp" )
33
32
34
33
35
- def log (x : LayerRef ) -> Layer :
34
+ def log (x : nn . LayerRef ) -> nn . Layer :
36
35
"""log"""
37
36
return _activation (x , activation = "log" )
38
37
39
38
40
- def tanh (x : LayerRef ) -> Layer :
39
+ def tanh (x : nn . LayerRef ) -> nn . Layer :
41
40
"""tanh"""
42
41
return _activation (x , activation = "tanh" )
43
42
44
43
45
- def sigmoid (x : LayerRef ) -> Layer :
44
+ def sigmoid (x : nn . LayerRef ) -> nn . Layer :
46
45
"""sigmoid"""
47
46
return _activation (x , activation = "sigmoid" )
48
47
49
48
50
- def log_sigmoid (x : LayerRef ) -> Layer :
49
+ def log_sigmoid (x : nn . LayerRef ) -> nn . Layer :
51
50
"""log sigmoid"""
52
51
return _activation (x , activation = "log_sigmoid" )
53
52
54
53
55
- def swish (x : LayerRef ) -> Layer :
54
+ def swish (x : nn . LayerRef ) -> nn . Layer :
56
55
"""swish"""
57
56
return _activation (x , activation = "swish" )
58
57
@@ -61,14 +60,14 @@ def swish(x: LayerRef) -> Layer:
61
60
softmax = nn .softmax
62
61
63
62
64
- def log_softmax (x : LayerRef , ** kwargs ) -> Layer :
63
+ def log_softmax (x : nn . LayerRef , ** kwargs ) -> nn . Layer :
65
64
"""
66
65
Wraps :func:`nn.softmax` with log_space=True.
67
66
"""
68
67
return nn .softmax (x , log_space = True , ** kwargs )
69
68
70
69
71
- def _activation (x : LayerRef , activation : str ) -> Layer :
70
+ def _activation (x : nn . LayerRef , activation : str ) -> nn . Layer :
72
71
"""
73
72
RETURNN ActivationLayer.
74
73
Only for internal use.
0 commit comments