
Commit 6c0c3cb

iclementine, Jiang Bin, and FatJhon authored

add inplace pointwise operators (#224)

* add inplace pointwise operators
* add inplace autograd function (pay attention to save_for_backward & mark_dirty); add tests for inplace unary pointwise operations
* add test cases for inplace binary pointwise operators
* check out size && is_tensor list
* change broadcast
* check out_tensor overlapping
* modify tag for overlapping
* test internal_overlapping for only pointwise operation

Co-authored-by: Jiang Bin <[email protected]>
Co-authored-by: FatJhon <[email protected]>

1 parent 88c81ec · commit 6c0c3cb

28 files changed: +1172 −26 lines
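The second bullet of the commit message calls out the subtlety of combining save_for_backward with mark_dirty when an in-place kernel is wrapped in a torch.autograd.Function. A minimal sketch of that pattern, using a stock PyTorch in-place op as a stand-in for a Triton kernel (illustrative only, not the FlagGems implementation):

import torch

class InplaceRelu(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        x.relu_()                  # stand-in for an in-place pointwise kernel
        ctx.mark_dirty(x)          # tell autograd the input was mutated
        ctx.save_for_backward(x)   # x now holds the output, which relu's backward needs
        return x

    @staticmethod
    def backward(ctx, grad_output):
        (out,) = ctx.saved_tensors
        return grad_output * (out > 0).to(grad_output.dtype)

# In-place ops are only legal on non-leaf tensors when grads are tracked.
x = torch.randn(8, requires_grad=True)
y = InplaceRelu.apply(x * 2)
y.sum().backward()

mark_dirty tells autograd that the input was modified in place, so the graph treats the output as the new version of that tensor; forward must then return the dirty tensor itself.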

src/flag_gems/__init__.py (+46 lines)
@@ -21,58 +21,93 @@ def enable(lib=aten_lib, unused=None, registrar=registrar):
     current_work_registrar = registrar(
         (
            ("abs", abs, Autograd.disable),
+            ("abs_", abs_, Autograd.disable),
             ("add.Tensor", add, Autograd.disable),
+            ("add_.Tensor", add_, Autograd.disable),
             ("addmm", addmm, Autograd.disable),
             ("arange.start_step", arange_start, Autograd.disable),
             ("arange.start", arange_start, Autograd.disable),
             ("arange", arange, Autograd.disable),
             ("batch_norm", batch_norm, Autograd.enable),
             ("bitwise_and.Tensor", bitwise_and_tensor, Autograd.disable),
+            ("bitwise_and_.Tensor", bitwise_and_tensor_, Autograd.disable),
             ("bitwise_and.Scalar", bitwise_and_scalar, Autograd.disable),
+            ("bitwise_and_.Scalar", bitwise_and_scalar_, Autograd.disable),
             ("bitwise_and.Scalar_Tensor", bitwise_and_scalar_tensor, Autograd.disable),
             ("bitwise_not", bitwise_not, Autograd.disable),
+            ("bitwise_not_", bitwise_not_, Autograd.disable),
             ("bitwise_or.Tensor", bitwise_or_tensor, Autograd.disable),
+            ("bitwise_or_.Tensor", bitwise_or_tensor_, Autograd.disable),
             ("bitwise_or.Scalar", bitwise_or_scalar, Autograd.disable),
+            ("bitwise_or_.Scalar", bitwise_or_scalar_, Autograd.disable),
             ("bitwise_or.Scalar_Tensor", bitwise_or_scalar_tensor, Autograd.disable),
             ("bmm", bmm, Autograd.disable),
             ("clamp", clamp, Autograd.disable),
+            ("clamp_", clamp_, Autograd.disable),
             ("clamp.Tensor", clamp_tensor, Autograd.disable),
+            ("clamp_.Tensor", clamp_tensor_, Autograd.disable),
             ("cos", cos, Autograd.disable),
+            ("cos_", cos_, Autograd.disable),
             ("pad", pad, Autograd.disable),
             ("constant_pad_nd", constant_pad_nd, Autograd.disable),
             ("cumsum", cumsum, Autograd.disable),
             ("cummin", cummin, Autograd.disable),
             ("div.Tensor", true_divide, Autograd.disable),
+            ("div_.Tensor", true_divide_, Autograd.disable),
             ("div.Scalar", true_divide, Autograd.disable),
+            ("div_.Scalar", true_divide_, Autograd.disable),
             ("div.Tensor_mode", div_mode, Autograd.disable),
+            ("div_.Tensor_mode", div_mode_, Autograd.disable),
             ("div.Scalar_mode", div_mode, Autograd.disable),
+            ("div_.Scalar_mode", div_mode_, Autograd.disable),
             (
                 "divide.Tensor",
                 true_divide,
                 Autograd.disable,
             ),  # divide, an alias for div
+            (
+                "divide_.Tensor",
+                true_divide_,
+                Autograd.disable,
+            ),  # divide_, an alias for div_
             ("divide.Scalar", true_divide, Autograd.disable),
+            ("divide_.Scalar", true_divide_, Autograd.disable),
             ("divide.Tensor_mode", div_mode, Autograd.disable),
+            ("divide_.Tensor_mode", div_mode_, Autograd.disable),
             ("divide.Scalar_mode", div_mode, Autograd.disable),
+            ("divide_.Scalar_mode", div_mode_, Autograd.disable),
             (
                 "true_divide.Tensor",
                 true_divide,
                 Autograd.disable,
             ),  # true_divide, an alias for div
+            (
+                "true_divide_.Tensor",
+                true_divide_,
+                Autograd.disable,
+            ),  # true_divide_, an alias for div_
             ("true_divide.Scalar", true_divide, Autograd.disable),
+            ("true_divide_.Scalar", true_divide_, Autograd.disable),
             ("floor_divide", floor_divide, Autograd.disable),
             ("floor_divide.Scalar", floor_divide, Autograd.disable),
             ("remainder.Tensor", remainder, Autograd.disable),
+            ("remainder_.Tensor", remainder_, Autograd.disable),
+            ("remainder.Scalar", remainder, Autograd.disable),
+            ("remainder_.Scalar", remainder_, Autograd.disable),
+            ("remainder.Scalar_Tensor", remainder, Autograd.disable),
             ("native_dropout", native_dropout, Autograd.enable),
             ("erf", erf, Autograd.disable),
+            ("erf_", erf_, Autograd.disable),
             ("embedding", embedding, Autograd.enable),
             ("eq.Tensor", eq, Autograd.disable),
             ("eq.Scalar", eq_scalar, Autograd.disable),
             ("exp", exp, Autograd.disable),
+            ("exp_", exp_, Autograd.disable),
             ("exponential_", exponential_, Autograd.disable),
             ("ge.Tensor", ge, Autograd.disable),
             ("ge.Scalar", ge_scalar, Autograd.disable),
             ("gelu", gelu, Autograd.enable),
+            ("gelu_", gelu_, Autograd.enable),
             ("native_group_norm", group_norm, Autograd.enable),
             ("_weight_norm_interface", weight_norm_interface, Autograd.enable),
             ("_weight_norm", weight_norm, Autograd.enable),
("_weight_norm", weight_norm, Autograd.enable),
@@ -118,19 +153,30 @@ def enable(lib=aten_lib, unused=None, registrar=registrar):
             ("ne.Tensor", ne, Autograd.disable),
             ("ne.Scalar", ne_scalar, Autograd.disable),
             ("neg", neg, Autograd.disable),
+            ("neg_", neg_, Autograd.disable),
             ("pow.Scalar", pow_scalar, Autograd.disable),
             ("pow.Tensor_Scalar", pow_tensor_scalar, Autograd.disable),
+            ("pow_.Scalar", pow_tensor_scalar_, Autograd.disable),
             ("pow.Tensor_Tensor", pow_tensor_tensor, Autograd.disable),
+            ("pow_.Tensor", pow_tensor_tensor_, Autograd.disable),
             ("reciprocal", reciprocal, Autograd.disable),
+            ("reciprocal_", reciprocal_, Autograd.disable),
             ("relu", relu, Autograd.enable),
+            ("relu_", relu_, Autograd.enable),
             ("rsqrt", rsqrt, Autograd.disable),
+            ("rsqrt_", rsqrt_, Autograd.disable),
             ("sigmoid", sigmoid, Autograd.enable),
+            ("sigmoid_", sigmoid_, Autograd.enable),
             ("silu", silu, Autograd.enable),
+            ("silu_", silu_, Autograd.enable),
             ("sin", sin, Autograd.disable),
+            ("sin_", sin_, Autograd.disable),
             ("softmax.int", softmax, Autograd.enable),
             ("sort", sort, Autograd.disable),
             ("sub.Tensor", sub, Autograd.disable),
+            ("sub_.Tensor", sub_, Autograd.disable),
             ("tanh", tanh, Autograd.enable),
+            ("tanh_", tanh_, Autograd.enable),
             ("triu", triu, Autograd.disable),
             # ("topk", topk, Autograd.disable),
             ("var_mean.correction", var_mean, Autograd.disable),

src/flag_gems/ops/__init__.py (+73 −21 lines)
@@ -1,5 +1,5 @@
-from .abs import abs
-from .add import add
+from .abs import abs, abs_
+from .add import add, add_
 from .addmm import addmm
 from .all import all, all_dim, all_dims
 from .amax import amax
@@ -11,40 +11,57 @@
 from .batch_norm import batch_norm
 from .bitwise_and import (
     bitwise_and_scalar,
+    bitwise_and_scalar_,
     bitwise_and_scalar_tensor,
     bitwise_and_tensor,
+    bitwise_and_tensor_,
+)
+from .bitwise_not import bitwise_not, bitwise_not_
+from .bitwise_or import (
+    bitwise_or_scalar,
+    bitwise_or_scalar_,
+    bitwise_or_scalar_tensor,
+    bitwise_or_tensor,
+    bitwise_or_tensor_,
 )
-from .bitwise_not import bitwise_not
-from .bitwise_or import bitwise_or_scalar, bitwise_or_scalar_tensor, bitwise_or_tensor
 from .bmm import bmm
 from .cat import cat
-from .clamp import clamp, clamp_tensor
+from .clamp import clamp, clamp_, clamp_tensor, clamp_tensor_
 from .conv1d import conv1d
 from .conv2d import conv2d
 from .conv_depthwise2d import _conv_depthwise2d
-from .cos import cos
+from .cos import cos, cos_
 from .count_nonzero import count_nonzero
 from .cross_entropy_loss import cross_entropy_loss
 from .cummin import cummin
 from .cumsum import cumsum, normed_cumsum
 from .diag import diag
 from .diag_embed import diag_embed
 from .diagonal import diagonal_backward
-from .div import div_mode, floor_divide, remainder, true_divide
+from .div import (
+    div_mode,
+    div_mode_,
+    floor_divide,
+    floor_divide_,
+    remainder,
+    remainder_,
+    true_divide,
+    true_divide_,
+)
 from .dropout import native_dropout
 from .elu import elu
 from .embedding import embedding
 from .eq import eq, eq_scalar
-from .erf import erf
-from .exp import exp
+from .erf import erf, erf_
+from .exp import exp, exp_
 from .exponential_ import exponential_
 from .fill import fill_scalar, fill_tensor
 from .flip import flip
 from .full import full
 from .full_like import full_like
 from .gather import gather, gather_backward
 from .ge import ge, ge_scalar
-from .gelu import gelu
+from .gelu import gelu, gelu_
 from .groupnorm import group_norm
 from .gt import gt, gt_scalar
 from .hstack import hstack
@@ -76,11 +93,11 @@
 from .minimum import minimum
 from .mm import mm
 from .mse_loss import mse_loss
-from .mul import mul
+from .mul import mul, mul_
 from .multinomial import multinomial
 from .mv import mv
 from .ne import ne, ne_scalar
-from .neg import neg
+from .neg import neg, neg_
 from .nllloss import (
     nll_loss2d_backward,
     nll_loss2d_forward,
@@ -93,16 +110,22 @@
 from .ones_like import ones_like
 from .outer import outer
 from .pad import constant_pad_nd, pad
-from .pow import pow_scalar, pow_tensor_scalar, pow_tensor_tensor
+from .pow import (
+    pow_scalar,
+    pow_tensor_scalar,
+    pow_tensor_scalar_,
+    pow_tensor_tensor,
+    pow_tensor_tensor_,
+)
 from .prod import prod, prod_dim
 from .quantile import quantile
 from .rand import rand
 from .rand_like import rand_like
 from .randn import randn
 from .randn_like import randn_like
 from .randperm import randperm
-from .reciprocal import reciprocal
-from .relu import relu
+from .reciprocal import reciprocal, reciprocal_
+from .relu import relu, relu_
 from .repeat import repeat
 from .repeat_interleave import (
     repeat_interleave_self_int,
@@ -112,19 +135,19 @@
 from .resolve_conj import resolve_conj
 from .resolve_neg import resolve_neg
 from .rms_norm import rms_norm
-from .rsqrt import rsqrt
+from .rsqrt import rsqrt, rsqrt_
 from .scatter import scatter
 from .select_scatter import select_scatter
-from .sigmoid import sigmoid
-from .silu import silu
-from .sin import sin
+from .sigmoid import sigmoid, sigmoid_
+from .silu import silu, silu_
+from .sin import sin, sin_
 from .slice_scatter import slice_scatter
 from .softmax import softmax
 from .sort import sort
 from .stack import stack
-from .sub import sub
+from .sub import sub, sub_
 from .sum import sum, sum_dim
-from .tanh import tanh
+from .tanh import tanh, tanh_
 from .tile import tile
 from .topk import topk
 from .triu import triu
@@ -151,22 +174,32 @@
     "any_dim",
     "any_dims",
     "add",
+    "add_",
     "abs",
+    "abs_",
     "addmm",
     "arange",
     "arange_start",
     "batch_norm",
     "bitwise_and_tensor",
+    "bitwise_and_tensor_",
     "bitwise_and_scalar",
+    "bitwise_and_scalar_",
     "bitwise_and_scalar_tensor",
     "bitwise_not",
+    "bitwise_not_",
     "bitwise_or_tensor",
+    "bitwise_or_tensor_",
     "bitwise_or_scalar",
+    "bitwise_or_scalar_",
     "bitwise_or_scalar_tensor",
     "bmm",
     "clamp",
+    "clamp_",
     "clamp_tensor",
+    "clamp_tensor_",
     "cos",
+    "cos_",
     "count_nonzero",
     "diag",
     "diag_embed",
@@ -178,18 +211,24 @@
     "cumsum",
     "normed_cumsum",
     "true_divide",
+    "true_divide_",
     "div_mode",
+    "div_mode_",
     "floor_divide",
+    "floor_divide_",
     "remainder",
+    "remainder_",
     "zeros",
     "ones",
     "full",
     "native_dropout",
     "erf",
+    "erf_",
     "embedding",
     "eq",
     "eq_scalar",
     "exp",
+    "exp_",
     "fill_scalar",
     "fill_tensor",
     "exponential_",
@@ -202,6 +241,7 @@
     "ge",
     "ge_scalar",
     "gelu",
+    "gelu_",
     "group_norm",
     "gt",
     "gt_scalar",
@@ -224,6 +264,7 @@
     "mean_dim",
     "mm",
     "mul",
+    "mul_",
     "multinomial",
     "maximum",
     "minimum",
@@ -242,19 +283,30 @@
     "ne",
     "ne_scalar",
     "neg",
+    "neg_",
     "pow_scalar",
     "pow_tensor_scalar",
     "pow_tensor_tensor",
+    "pow_tensor_scalar_",
+    "pow_tensor_tensor_",
     "reciprocal",
+    "reciprocal_",
     "relu",
+    "relu_",
     "rsqrt",
+    "rsqrt_",
     "scatter",
     "sigmoid",
+    "sigmoid_",
     "silu",
+    "silu_",
     "sin",
+    "sin_",
     "softmax",
     "sub",
+    "sub_",
     "tanh",
+    "tanh_",
     "tile",
     "triu",
     "topk",

src/flag_gems/ops/abs.py (+6 lines)
@@ -15,3 +15,9 @@ def abs_func(x):
 def abs(A):
     logging.debug("GEMS ABS")
     return abs_func(A)
+
+
+def abs_(A):
+    logging.debug("GEMS ABS_")
+    abs_func(A, out0=A)
+    return A
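The wrapper reuses the out-of-place kernel and merely redirects the output buffer: passing out0=A makes the pointwise kernel write its result over the input, and returning A preserves the ATen convention that in-place ops return self. Behavior sketch, assuming a CUDA device and that flag_gems.enable() has been called:

import torch
import flag_gems

flag_gems.enable()
x = torch.tensor([-1.0, 2.0, -3.0], device="cuda")
y = x.abs_()   # routed to the Gems abs_ above
assert y is x  # in-place ops return self
assert x.tolist() == [1.0, 2.0, 3.0]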

src/flag_gems/ops/add.py (+12 lines)
@@ -38,3 +38,15 @@ def add(A, B, *, alpha=1):
         return add_func_scalar_tensor(A, B, alpha)
     else:
         return torch.tensor(A + B * alpha)
+
+
+def add_(A, B, *, alpha=1):
+    logging.debug("GEMS ADD_")
+    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
+        return add_func(A, B, alpha, out0=A)
+    elif isinstance(A, torch.Tensor):
+        return add_func_tensor_scalar(A, B, alpha, out0=A)
+    # elif isinstance(B, torch.Tensor):
+    #     return add_func_scalar_tensor(A, B, alpha, out0=A)
+    else:
+        raise ValueError("Unreachable.")
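The scalar–tensor branch is deliberately left commented out: add_ must write its result into A, so A has to be a tensor with storage, and no in-place overload exists with a scalar self, hence the "Unreachable" guard. The commit message's overlap bullets ("check out_tensor overlapping", "test internal_overlapping") guard the related hazard of writing in place through a self-overlapping view. A stock-PyTorch illustration of the failure mode such checks catch (not FlagGems code):

import torch

# expand() creates a view whose rows alias the same storage (stride 0 on dim 0),
# so an in-place pointwise write would hit one memory location multiple times.
x = torch.zeros(1, 3).expand(4, 3)
try:
    x.add_(1.0)
except RuntimeError as e:
    print("refused:", e)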
