forked from tensorflow/tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy patharray_ops.py
5864 lines (4756 loc) · 203 KB
/
array_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Tests for this file live in python/kernel_tests/array_ops_test.py
"""Support for manipulating tensors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numbers
import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
# 'Constant' gets imported in the module 'array_ops'.
from tensorflow.python.framework.constant_op import constant
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_array_ops import *
from tensorflow.python.ops.gen_array_ops import reverse_v2 as reverse # pylint: disable=unused-import
from tensorflow.python.types import core
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.tf_export import tf_export
# pylint: enable=wildcard-import
# Index value that requests a new axis of size 1 when slicing a tensor.
# It is plain `None`; `_slice_helper` recognizes it with an identity check
# (`s is newaxis`).
newaxis = None
tf_export("newaxis").export_constant(__name__, "newaxis")

# The name `slice` is taken over by the "slice" op in this module (per the
# original note; the overriding definition is outside this view), so keep a
# handle on Python's builtin `slice` type. `_slice_helper` relies on it for
# its `isinstance` check on slice objects.
_BaseSlice = slice
@tf_export("reshape", v1=["reshape", "manip.reshape"])
@dispatch.add_dispatch_support
def reshape(tensor, shape, name=None):  # pylint: disable=redefined-outer-name
  r"""Reshapes a tensor.

  Returns a new `tf.Tensor` that holds the values of `tensor`, in their
  original order, arranged into the shape given by `shape`.

  >>> t1 = [[1, 2, 3],
  ...       [4, 5, 6]]
  >>> print(tf.shape(t1).numpy())
  [2 3]
  >>> t2 = tf.reshape(t1, [6])
  >>> t2
  <tf.Tensor: shape=(6,), dtype=int32,
    numpy=array([1, 2, 3, 4, 5, 6], dtype=int32)>
  >>> tf.reshape(t2, [3, 2])
  <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
    array([[1, 2],
           [3, 4],
           [5, 6]], dtype=int32)>

  `tf.reshape` changes neither the order nor the total number of elements in
  the tensor, so it can reuse the underlying data buffer. That makes it a fast
  operation, independent of how big a tensor it is operating on.

  >>> tf.reshape([1, 2, 3], [2, 2])
  Traceback (most recent call last):
  ...
  InvalidArgumentError: Input to reshape is a tensor with 3 values, but the
  requested shape has 4

  To reorder the data so that the dimensions themselves are rearranged, use
  `tf.transpose` instead.

  >>> t = [[1, 2, 3],
  ...      [4, 5, 6]]
  >>> tf.reshape(t, [3, 2]).numpy()
  array([[1, 2],
         [3, 4],
         [5, 6]], dtype=int32)
  >>> tf.transpose(t, perm=[1, 0]).numpy()
  array([[1, 4],
         [2, 5],
         [3, 6]], dtype=int32)

  If one component of `shape` is the special value -1, the size of that
  dimension is computed so that the total size stays constant. In particular,
  a `shape` of `[-1]` flattens into 1-D. At most one component of `shape` can
  be -1.

  >>> t = [[1, 2, 3],
  ...      [4, 5, 6]]
  >>> tf.reshape(t, [-1])
  <tf.Tensor: shape=(6,), dtype=int32,
    numpy=array([1, 2, 3, 4, 5, 6], dtype=int32)>
  >>> tf.reshape(t, [3, -1])
  <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
    array([[1, 2],
           [3, 4],
           [5, 6]], dtype=int32)>
  >>> tf.reshape(t, [-1, 2])
  <tf.Tensor: shape=(3, 2), dtype=int32, numpy=
    array([[1, 2],
           [3, 4],
           [5, 6]], dtype=int32)>

  `tf.reshape(t, [])` reshapes a tensor `t` with one element to a scalar.

  >>> tf.reshape([7], []).numpy()
  7

  More examples:

  >>> t = [1, 2, 3, 4, 5, 6, 7, 8, 9]
  >>> print(tf.shape(t).numpy())
  [9]
  >>> tf.reshape(t, [3, 3])
  <tf.Tensor: shape=(3, 3), dtype=int32, numpy=
    array([[1, 2, 3],
           [4, 5, 6],
           [7, 8, 9]], dtype=int32)>

  >>> t = [[[1, 1], [2, 2]],
  ...      [[3, 3], [4, 4]]]
  >>> print(tf.shape(t).numpy())
  [2 2 2]
  >>> tf.reshape(t, [2, 4])
  <tf.Tensor: shape=(2, 4), dtype=int32, numpy=
    array([[1, 1, 2, 2],
           [3, 3, 4, 4]], dtype=int32)>

  >>> t = [[[1, 1, 1],
  ...       [2, 2, 2]],
  ...      [[3, 3, 3],
  ...       [4, 4, 4]],
  ...      [[5, 5, 5],
  ...       [6, 6, 6]]]
  >>> print(tf.shape(t).numpy())
  [3 2 3]
  >>> # Pass '[-1]' to flatten 't'.
  >>> tf.reshape(t, [-1])
  <tf.Tensor: shape=(18,), dtype=int32,
    numpy=array([1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6],
    dtype=int32)>
  >>> # -- Using -1 to infer the shape --
  >>> # Here -1 is inferred to be 9:
  >>> tf.reshape(t, [2, -1])
  <tf.Tensor: shape=(2, 9), dtype=int32, numpy=
    array([[1, 1, 1, 2, 2, 2, 3, 3, 3],
           [4, 4, 4, 5, 5, 5, 6, 6, 6]], dtype=int32)>
  >>> # -1 is inferred to be 2:
  >>> tf.reshape(t, [-1, 9])
  <tf.Tensor: shape=(2, 9), dtype=int32, numpy=
    array([[1, 1, 1, 2, 2, 2, 3, 3, 3],
           [4, 4, 4, 5, 5, 5, 6, 6, 6]], dtype=int32)>
  >>> # -1 is inferred to be 3:
  >>> tf.reshape(t, [ 2, -1, 3])
  <tf.Tensor: shape=(2, 3, 3), dtype=int32, numpy=
    array([[[1, 1, 1],
            [2, 2, 2],
            [3, 3, 3]],
           [[4, 4, 4],
            [5, 5, 5],
            [6, 6, 6]]], dtype=int32)>

  Args:
    tensor: A `Tensor`.
    shape: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      Defines the shape of the output tensor.
    name: Optional string. A name for the operation.

  Returns:
    A `Tensor`. Has the same type as `tensor`.
  """
  reshaped = gen_array_ops.reshape(tensor, shape, name)
  # Hand the requested `shape` to tensor_util as well, so it can (where it is
  # able to) attach static shape information to the op's output.
  tensor_util.maybe_set_static_shape(reshaped, shape)
  return reshaped
@tf_export("fill")
@dispatch.add_dispatch_support
def fill(dims, value, name=None):
  r"""Creates a tensor filled with a scalar value.

  See also `tf.ones`, `tf.zeros`, `tf.one_hot`, `tf.eye`.

  The result has shape `dims` and every element is set to `value`.

  For example:

  >>> tf.fill([2, 3], 9)
  <tf.Tensor: shape=(2, 3), dtype=int32, numpy=
  array([[9, 9, 9],
         [9, 9, 9]], dtype=int32)>

  `tf.fill` evaluates at graph runtime and supports dynamic shapes based on
  other runtime `tf.Tensors`, unlike `tf.constant(value, shape=dims)`, which
  embeds the value as a `Const` node.

  Args:
    dims: A 1-D sequence of non-negative numbers. Represents the shape of the
      output `tf.Tensor`. Entries should be of type: `int32`, `int64`.
    value: A value to fill the returned `tf.Tensor`.
    name: Optional string. The name of the output `tf.Tensor`.

  Returns:
    A `tf.Tensor` with shape `dims` and the same dtype as `value`.

  Raises:
    InvalidArgumentError: `dims` contains negative entries.
    NotFoundError: `dims` contains non-integer entries.

  @compatibility(numpy)
  Similar to `np.full`. In `numpy`, more parameters are supported. Passing a
  number argument as the shape (`np.full(5, value)`) is valid in `numpy` for
  specifying a 1-D shaped result, while TensorFlow does not support this syntax.
  @end_compatibility
  """
  filled = gen_array_ops.fill(dims, value, name=name)
  # Pass `dims` along so tensor_util can (where it is able to) attach static
  # shape information to the op's output.
  tensor_util.maybe_set_static_shape(filled, dims)
  return filled
@tf_export("identity")
@dispatch.add_dispatch_support
def identity(input, name=None):  # pylint: disable=redefined-builtin
  r"""Return a Tensor with the same shape and contents as input.

  The returned Tensor is a different object from the original but holds the
  same values. This operation is fast when used on the same device.

  For example:

  >>> a = tf.constant([0.78])
  >>> a_identity = tf.identity(a)
  >>> a.numpy()
  array([0.78], dtype=float32)
  >>> a_identity.numpy()
  array([0.78], dtype=float32)

  Calling `tf.identity` on a variable produces a Tensor that represents the
  value of that variable at the time of the call. This is equivalent to
  calling `<variable>.read_value()`.

  >>> a = tf.Variable(5)
  >>> a_identity = tf.identity(a)
  >>> a.assign_add(1)
  <tf.Variable ... shape=() dtype=int32, numpy=6>
  >>> a.numpy()
  6
  >>> a_identity.numpy()
  5

  Composite tensors are handled by applying `identity` to each of their
  component tensors.

  Args:
    input: A `Tensor`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  if isinstance(input, composite_tensor.CompositeTensor):
    # Recurse into the components of the composite.
    return nest.map_structure(identity, input, expand_composites=True)

  if context.executing_eagerly():
    if not hasattr(input, "graph"):
      # Make sure we get an input with handle data attached from resource
      # variables. Variables have correct handle data when graph building.
      input = ops.convert_to_tensor(input)

  output = gen_array_ops.identity(input, name=name)

  # Propagate handle data for happier shape inference for resource variables.
  if hasattr(input, "_handle_data"):
    output._handle_data = input._handle_data  # pylint: disable=protected-access
  return output
# pylint: disable=redefined-builtin,protected-access
@tf_export(v1=["expand_dims"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None, "Use the `axis` argument instead", "dim")
def expand_dims(input, axis=None, name=None, dim=None):
  """Returns a tensor with a length 1 axis inserted at index `axis`.

  Given a tensor `input`, this operation inserts a dimension of length 1 at
  the dimension index `axis` of `input`'s shape. The dimension index follows
  Python indexing rules: it is zero-based, and a negative index is counted
  backward from the end.

  This operation is useful to:

  * Add an outer "batch" dimension to a single element.
  * Align axes for broadcasting.
  * Add an inner vector length axis to a tensor of scalars.

  For example:

  If you have a single image of shape `[height, width, channels]`:

  >>> image = tf.zeros([10,10,3])

  You can add an outer `batch` axis by passing `axis=0`:

  >>> tf.expand_dims(image, axis=0).shape.as_list()
  [1, 10, 10, 3]

  The new axis location matches Python `list.insert(axis, 1)`:

  >>> tf.expand_dims(image, axis=1).shape.as_list()
  [10, 1, 10, 3]

  Following standard Python indexing rules, a negative `axis` counts from the
  end so `axis=-1` adds an inner most dimension:

  >>> tf.expand_dims(image, -1).shape.as_list()
  [10, 10, 3, 1]

  This operation requires that `axis` is a valid index for `input.shape`,
  following Python indexing rules:

  ```
  -1-tf.rank(input) <= axis <= tf.rank(input)
  ```

  This operation is related to:

  * `tf.squeeze`, which removes dimensions of size 1.
  * `tf.reshape`, which provides more flexible reshaping capability.
  * `tf.sparse.expand_dims`, which provides this functionality for
    `tf.SparseTensor`

  Args:
    input: A `Tensor`.
    axis: 0-D (scalar). Specifies the dimension index at which to expand the
      shape of `input`. Must be in the range `[-rank(input) - 1, rank(input)]`.
    name: The name of the output `Tensor` (optional).
    dim: 0-D (scalar). Equivalent to `axis`, to be deprecated.

  Returns:
    A `Tensor` with the same data as `input`, but its shape has an additional
    dimension of size 1 added.

  Raises:
    ValueError: if either both or neither of `dim` and `axis` are specified.
  """
  # Fold the deprecated `dim` spelling into `axis`.
  resolved_axis = deprecation.deprecated_argument_lookup(
      "axis", axis, "dim", dim)
  if resolved_axis is not None:
    return expand_dims_v2(input, resolved_axis, name)
  raise ValueError("Must specify an axis argument to tf.expand_dims()")
@tf_export("expand_dims", v1=[])
@dispatch.add_dispatch_support
def expand_dims_v2(input, axis, name=None):
  """Returns a tensor with a length 1 axis inserted at index `axis`.

  Given a tensor `input`, this operation inserts a dimension of length 1 at
  the dimension index `axis` of `input`'s shape. The dimension index follows
  Python indexing rules: it is zero-based, and a negative index is counted
  backward from the end.

  This operation is useful to:

  * Add an outer "batch" dimension to a single element.
  * Align axes for broadcasting.
  * Add an inner vector length axis to a tensor of scalars.

  For example:

  If you have a single image of shape `[height, width, channels]`:

  >>> image = tf.zeros([10,10,3])

  You can add an outer `batch` axis by passing `axis=0`:

  >>> tf.expand_dims(image, axis=0).shape.as_list()
  [1, 10, 10, 3]

  The new axis location matches Python `list.insert(axis, 1)`:

  >>> tf.expand_dims(image, axis=1).shape.as_list()
  [10, 1, 10, 3]

  Following standard Python indexing rules, a negative `axis` counts from the
  end so `axis=-1` adds an inner most dimension:

  >>> tf.expand_dims(image, -1).shape.as_list()
  [10, 10, 3, 1]

  This operation requires that `axis` is a valid index for `input.shape`,
  following Python indexing rules:

  ```
  -1-tf.rank(input) <= axis <= tf.rank(input)
  ```

  This operation is related to:

  * `tf.squeeze`, which removes dimensions of size 1.
  * `tf.reshape`, which provides more flexible reshaping capability.
  * `tf.sparse.expand_dims`, which provides this functionality for
    `tf.SparseTensor`

  Args:
    input: A `Tensor`.
    axis: Integer specifying the dimension index at which to expand the
      shape of `input`. Given an input of D dimensions, `axis` must be in range
      `[-(D+1), D]` (inclusive).
    name: Optional string. The name of the output `Tensor`.

  Returns:
    A tensor with the same data as `input`, with an additional dimension
    inserted at the index specified by `axis`.

  Raises:
    TypeError: If `axis` is not specified (it is a required argument).
    InvalidArgumentError: If `axis` is out of range `[-(D+1), D]`.
  """
  expanded = gen_array_ops.expand_dims(input, axis, name)
  return expanded
# pylint: enable=redefined-builtin,protected-access
# Aliases for some automatically-generated names.
# pylint: disable=protected-access
@deprecation.deprecated("2016-11-30",
                        "This op will be removed after the deprecation date. "
                        "Please switch to tf.setdiff1d().")
def listdiff(x, y, out_idx=None, name=None):
  # Deprecated alias that forwards its arguments, unchanged and in order, to
  # the generated `list_diff` op.
  #
  # Deliberately has no docstring of its own: the published docstring is
  # assembled just below from the generated op's documentation plus whatever
  # `listdiff.__doc__` holds after decoration (presumably the deprecation
  # notice added by the decorator above).
  return gen_array_ops.list_diff(x, y, out_idx, name)


# Either docstring can be None (for instance when docstrings are stripped by
# running Python with `-OO`), and `None + str` would raise a TypeError while
# this module is being imported. Fall back to "" for a missing piece instead.
listdiff.__doc__ = (
    (gen_array_ops.list_diff.__doc__ or "") + "\n" + (listdiff.__doc__ or ""))
# pylint: enable=protected-access
# pylint: disable=undefined-variable
@deprecation.deprecated("2018-11-30",
                        "This op will be removed after the deprecation date. "
                        "Please switch to tf.sets.difference().")
@tf_export(v1=["setdiff1d"])
@dispatch.add_dispatch_support
def setdiff1d(x, y, index_dtype=dtypes.int32, name=None):
  """Computes the difference between two lists of numbers or strings.

  Given a list x and a list y, this operation returns a list out that
  represents all values that are in x but not in y. The returned list
  out is sorted in the same order that the numbers appear in x
  (duplicates are preserved). This operation also returns a list idx
  that represents the position of each out element in x.

  In other words:

  ```python
  out[i] = x[idx[i]] for i in [0, 1, ..., len(out) - 1]
  ```

  Example usage:

  >>> x = [1, 2, 3, 4, 5, 6]
  >>> y = [1, 3, 5]
  >>> setdiff1d(x,y)
  ListDiff(out=<tf.Tensor: id=2, shape=(3,), dtype=int32,
  numpy=array([2, 4, 6], dtype=int32)>, idx=<tf.Tensor: id=3,
  shape=(3,), dtype=int32, numpy=array([1, 3, 5], dtype=int32)>)

  Args:
    x: A Tensor. 1-D. Values to keep.
    y: A Tensor. Must have the same type as x. 1-D. Values to remove.
    index_dtype: An optional tf.DType from: tf.int32, tf.int64. Defaults to
      tf.int32. Forwarded to the generated `list_diff` op as its third
      positional argument.
    name: A name for the operation (optional).

  Returns:
    A tuple of Tensor objects (out, idx).
    out: A Tensor. Has the same type as x.
    idx: A Tensor of type index_dtype.
  """
  difference = gen_array_ops.list_diff(x, y, index_dtype, name)
  return difference


# NOTE: this replaces the docstring written above with the generated op's
# docstring, so the text above is what readers of the source see, while
# `setdiff1d.__doc__` at runtime is `list_diff`'s documentation.
setdiff1d.__doc__ = gen_array_ops.list_diff.__doc__
@tf_export("broadcast_dynamic_shape")
@dispatch.add_dispatch_support
def broadcast_dynamic_shape(shape_x, shape_y):
  """Computes the shape of a broadcast given symbolic shapes.

  When `shape_x` and `shape_y` are Tensors representing shapes (i.e. the
  result of calling tf.shape on another Tensor), this computes a Tensor which
  is the shape of the result of a broadcasting op applied to tensors of shapes
  `shape_x` and `shape_y`.

  This is useful when validating the result of a broadcasting operation when
  the tensors do not have statically known shapes.

  Example:

  >>> shape_x = (1, 2, 3)
  >>> shape_y = (5, 1, 3)
  >>> tf.broadcast_dynamic_shape(shape_x, shape_y)
  <tf.Tensor: shape=(3,), dtype=int32, numpy=array([5, 2, 3], ...>

  Args:
    shape_x: A rank 1 integer `Tensor`, representing the shape of x.
    shape_y: A rank 1 integer `Tensor`, representing the shape of y.

  Returns:
    A rank 1 integer `Tensor` representing the broadcasted shape.

  Raises:
    InvalidArgumentError: If the two shapes are incompatible for
    broadcasting.
  """
  broadcast_shape = gen_array_ops.broadcast_args(shape_x, shape_y)
  return broadcast_shape
@tf_export("broadcast_static_shape")
@dispatch.add_dispatch_support
def broadcast_static_shape(shape_x, shape_y):
  """Computes the shape of a broadcast given known shapes.

  When `shape_x` and `shape_y` are fully known `TensorShape`s, this computes a
  `TensorShape` which is the shape of the result of a broadcasting op applied
  to tensors of shapes `shape_x` and `shape_y`.

  For example, if shape_x is `TensorShape([1, 2, 3])` and shape_y is
  `TensorShape([5, 1, 3])`, the result is a TensorShape whose value is
  `TensorShape([5, 2, 3])`.

  This is useful when validating the result of a broadcasting operation when
  the tensors have statically known shapes.

  Example:

  >>> shape_x = tf.TensorShape([1, 2, 3])
  >>> shape_y = tf.TensorShape([5, 1 ,3])
  >>> tf.broadcast_static_shape(shape_x, shape_y)
  TensorShape([5, 2, 3])

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    A `TensorShape` representing the broadcasted shape.

  Raises:
    ValueError: If the two shapes can not be broadcasted.
  """
  broadcast_shape = common_shapes.broadcast_shape(shape_x, shape_y)
  return broadcast_shape
@tf_export("shape", v1=[])
@dispatch.add_dispatch_support
def shape_v2(input, out_type=dtypes.int32, name=None):
  # pylint: disable=redefined-builtin
  """Returns the shape of a tensor.

  See also `tf.size`, `tf.rank`.

  `tf.shape` returns a 1-D integer tensor representing the shape of `input`.

  For example:

  >>> t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]])
  >>> tf.shape(t)
  <tf.Tensor: shape=(3,), dtype=int32, numpy=array([2, 2, 3], dtype=int32)>

  Note: When using symbolic tensors, such as when using the Keras API,
  tf.shape() will return the shape of the symbolic tensor.

  >>> a = tf.keras.layers.Input((None, 10))
  >>> tf.shape(a)
  <tf.Tensor ... shape=(3,) dtype=int32>

  In these cases, using `tf.Tensor.shape` will return more informative
  results.

  >>> a.shape
  TensorShape([None, None, 10])

  (The first `None` represents the as yet unknown batch size.)

  `tf.shape` and `Tensor.shape` should be identical in eager mode. Within
  `tf.function` or within a `compat.v1` context, not all dimensions may be
  known until execution time. Hence when defining custom layers and models
  for graph mode, prefer the dynamic `tf.shape(x)` over the static `x.shape`.

  Args:
    input: A `Tensor` or `SparseTensor`.
    out_type: (Optional) The specified output type of the operation (`int32`
      or `int64`). Defaults to `tf.int32`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `out_type`.
  """
  # The v1 endpoint takes `name` before `out_type`, hence the swapped order.
  shape_tensor = shape(input, name, out_type)
  return shape_tensor
@tf_export(v1=["shape"])
@dispatch.add_dispatch_support
def shape(input, name=None, out_type=dtypes.int32):
  # pylint: disable=redefined-builtin
  """Returns the shape of a tensor.

  The result is a 1-D integer tensor representing the shape of `input`.

  For example:

  ```python
  t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]])
  tf.shape(t)  # [2, 2, 3]
  ```

  Args:
    input: A `Tensor` or `SparseTensor`.
    name: A name for the operation (optional).
    out_type: (Optional) The specified output type of the operation (`int32`
      or `int64`). Defaults to `tf.int32`.

  Returns:
    A `Tensor` of type `out_type`.
  """
  # Public entry point: always asks the internal helper to optimize.
  shape_tensor = shape_internal(
      input, name, optimize=True, out_type=out_type)
  return shape_tensor
def shape_internal(input, name=None, optimize=True, out_type=dtypes.int32):
  # pylint: disable=redefined-builtin
  """Returns the shape of a tensor.

  Sparse inputs yield their `dense_shape` cast to `out_type`. When not
  executing eagerly, the input is converted to a tensor and, if `optimize` is
  true and its static shape is fully defined, that shape is returned as a
  constant. Every other case goes through the generated `shape` op.

  Args:
    input: A `Tensor` or `SparseTensor`.
    name: A name for the operation (optional).
    optimize: if true, encode the shape as a constant when possible.
    out_type: (Optional) The specified output type of the operation (`int32`
      or `int64`). Defaults to tf.int32.

  Returns:
    A `Tensor` of type `out_type`.
  """
  sparse_types = (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)
  with ops.name_scope(name, "Shape", [input]) as scope_name:
    if isinstance(input, sparse_types):
      return gen_math_ops.cast(input.dense_shape, out_type)

    if not context.executing_eagerly():
      input = ops.convert_to_tensor(input)
      static_shape = input.get_shape()
      if optimize and static_shape.is_fully_defined():
        # The whole shape is known statically, so no shape op is needed.
        return constant(static_shape.as_list(), out_type, name=scope_name)

    return gen_array_ops.shape(input, name=scope_name, out_type=out_type)
@tf_export("shape_n")
@dispatch.add_dispatch_support
def shape_n(input, out_type=dtypes.int32, name=None):
  # pylint: disable=redefined-builtin
  """Returns shape of tensors.

  Args:
    input: A list of at least 1 `Tensor` object with the same type.
    out_type: (Optional) The specified output type of the operation (`int32`
      or `int64`). Defaults to `tf.int32`.
    name: A name for the operation (optional).

  Returns:
    A list with the same length as `input` of `Tensor` objects with
      type `out_type`.
  """
  shapes = gen_array_ops.shape_n(input, out_type=out_type, name=name)
  return shapes
@tf_export("size", v1=[])
@dispatch.add_dispatch_support
def size_v2(input, out_type=dtypes.int32, name=None):
  # pylint: disable=redefined-builtin
  """Returns the size of a tensor.

  See also `tf.shape`.

  Returns a 0-D `Tensor` of type `out_type` (`tf.int32` by default)
  representing the number of elements in `input`.

  For example:

  >>> t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]])
  >>> tf.size(t)
  <tf.Tensor: shape=(), dtype=int32, numpy=12>

  Args:
    input: A `Tensor` or `SparseTensor`.
    out_type: (Optional) The specified non-quantized numeric output type of
      the operation. Defaults to `tf.int32`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `out_type`. Defaults to `tf.int32`.

  @compatibility(numpy)
  Equivalent to np.size()
  @end_compatibility
  """
  # The v1 endpoint takes `name` before `out_type`, hence the swapped order.
  element_count = size(input, name, out_type)
  return element_count
@tf_export(v1=["size"])
@dispatch.add_dispatch_support
def size(input, name=None, out_type=dtypes.int32):
  # pylint: disable=redefined-builtin
  """Returns the size of a tensor.

  Returns a 0-D `Tensor` of type `out_type` (`tf.int32` by default)
  representing the number of elements in `input`.

  For example:

  ```python
  t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]])
  tf.size(t)  # 12
  ```

  Args:
    input: A `Tensor` or `SparseTensor`.
    name: A name for the operation (optional).
    out_type: (Optional) The specified non-quantized numeric output type of
      the operation. Defaults to `tf.int32`.

  Returns:
    A `Tensor` of type `out_type`. Defaults to `tf.int32`.

  @compatibility(numpy)
  Equivalent to np.size()
  @end_compatibility
  """
  # Public entry point: always asks the internal helper to optimize.
  element_count = size_internal(
      input, name, optimize=True, out_type=out_type)
  return element_count
def size_internal(input, name=None, optimize=True, out_type=dtypes.int32):
  # pylint: disable=redefined-builtin,protected-access
  """Returns the size of a tensor.

  When executing eagerly on a non-sparse input that has no `graph` attribute,
  the element count is computed with NumPy from the input's shape tuple.
  Otherwise sparse inputs multiply out their `dense_shape`, and dense inputs
  use a constant when `optimize` allows it, or the generated `size` op.

  Args:
    input: A `Tensor` or `SparseTensor`.
    name: A name for the operation (optional).
    optimize: if true, encode the size as a constant when possible.
    out_type: (Optional) The specified non-quantized numeric output type of
      the operation. Defaults to `tf.int32`.

  Returns:
    A `Tensor` of type `out_type`. Defaults to `tf.int32`.
  """
  sparse_types = (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)
  is_sparse = isinstance(input, sparse_types)

  if (context.executing_eagerly() and not hasattr(input, "graph") and
      not is_sparse):
    input = ops.convert_to_tensor(input)
    element_count = np.prod(
        input._shape_tuple(),  # pylint: disable=protected-access
        dtype=out_type.as_numpy_dtype)
    return ops.convert_to_tensor(element_count, dtype=out_type)

  with ops.name_scope(name, "Size", [input]) as scope_name:
    if is_sparse:
      dense_shape = gen_math_ops.cast(input.dense_shape, out_type)
      return gen_math_ops.prod(dense_shape, 0, name=scope_name)

    input = ops.convert_to_tensor(input)
    static_shape = input.get_shape()
    if optimize:
      if static_shape.is_fully_defined():
        return constant(
            static_shape.num_elements(), out_type, name=scope_name)
      # A zero-sized dimension makes the tensor empty even if other
      # dimensions are still unknown.
      if static_shape.dims and any(dim == 0 for dim in static_shape.dims):
        return constant(0, out_type, name=scope_name)
    return gen_array_ops.size(input, name=scope_name, out_type=out_type)
@tf_export("rank")
@dispatch.add_dispatch_support
def rank(input, name=None):
  # pylint: disable=redefined-builtin
  """Returns the rank of a tensor.

  See also `tf.shape`.

  Returns a 0-D `int32` `Tensor` representing the rank of `input`.

  For example:

  ```python
  # shape of tensor 't' is [2, 2, 3]
  t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]])
  tf.rank(t)  # 3
  ```

  **Note**: The rank of a tensor is not the same as the rank of a matrix. The
  rank of a tensor is the number of indices required to uniquely select each
  element of the tensor. Rank is also known as "order", "degree", or "ndims."

  Args:
    input: A `Tensor` or `SparseTensor`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `int32`.

  @compatibility(numpy)
  Equivalent to np.ndim
  @end_compatibility
  """
  # Public entry point: always asks the internal helper to optimize.
  rank_tensor = rank_internal(input, name, optimize=True)
  return rank_tensor
def rank_internal(input, name=None, optimize=True):
  # pylint: disable=redefined-builtin
  """Returns the rank of a tensor.

  For sparse inputs the rank is the size of `dense_shape`. For anything else
  the input is converted to a tensor; its statically known number of
  dimensions is returned as a constant when `optimize` is true and that number
  is available, and the generated `rank` op is used otherwise.

  Args:
    input: A `Tensor` or `SparseTensor`.
    name: A name for the operation (optional).
    optimize: if true, encode the rank as a constant when possible.

  Returns:
    A `Tensor` of type `int32`.
  """
  sparse_types = (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)
  with ops.name_scope(name, "Rank", [input]) as scope_name:
    if isinstance(input, sparse_types):
      return gen_array_ops.size(input.dense_shape, name=scope_name)

    input = ops.convert_to_tensor(input)
    static_shape = input.get_shape()
    if not optimize or static_shape.ndims is None:
      return gen_array_ops.rank(input, name=scope_name)
    return constant(static_shape.ndims, dtypes.int32, name=scope_name)
# Base text of the TypeError raised by `_check_index` for an unsupported
# index; `_check_index` appends the offending value to it.
_SLICE_TYPE_ERROR = (
    "Only integers, slices (`:`), ellipsis (`...`), "
    "tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid "
    "indices")

# Dtypes `_check_index` accepts for index objects that expose a `dtype`.
_SUPPORTED_SLICE_DTYPES = (dtypes.int32, dtypes.int32_ref, dtypes.int64,
                           dtypes.int64_ref)
def _check_index(idx):
  """Check if a given value is a valid index into a tensor.

  Python integers and `tensor_shape.Dimension` objects are always accepted.
  Anything else must expose a `dtype` listed in `_SUPPORTED_SLICE_DTYPES` and
  must not have a shape of length 1.

  Args:
    idx: The candidate index.

  Raises:
    TypeError: If `idx` is not an acceptable index.
  """
  if isinstance(idx, (numbers.Integral, tensor_shape.Dimension)):
    return

  # Optimistic check. Assumptions:
  # * any object with a dtype is supported
  # * any object with a dtype has a sizeable shape attribute.
  index_dtype = getattr(idx, "dtype", None)
  if index_dtype is None:
    rejected = True
  elif dtypes.as_dtype(index_dtype) not in _SUPPORTED_SLICE_DTYPES:
    rejected = True
  else:
    # The shape is tested for truthiness before `len` is taken, exactly as
    # the combined condition did; only a length-1 shape is rejected here.
    rejected = bool(idx.shape and len(idx.shape) == 1)

  if rejected:
    # TODO(slebedev): IndexError seems more appropriate here, but it
    # will break `_slice_helper` contract.
    raise TypeError(_SLICE_TYPE_ERROR + ", got {!r}".format(idx))
def _is_undefined_dimension(d):
  """Returns True iff `d` is a `tensor_shape.Dimension` whose value is None."""
  if not isinstance(d, tensor_shape.Dimension):
    return False
  return d.value is None
@tf_export("__operators__.getitem", v1=[])
@dispatch.add_dispatch_support
def _slice_helper(tensor, slice_spec, var=None):
"""Overload for Tensor.__getitem__.
This operation extracts the specified region from the tensor.
The notation is similar to NumPy with the restriction that
currently only support basic indexing. That means that
using a non-scalar tensor as input is not currently allowed.
Some useful examples:
```python
# Strip leading and trailing 2 elements
foo = tf.constant([1,2,3,4,5,6])
print(foo[2:-2].eval()) # => [3,4]
# Skip every other row and reverse the order of the columns
foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
print(foo[::2,::-1].eval()) # => [[3,2,1], [9,8,7]]
# Use scalar tensors as indices on both dimensions
print(foo[tf.constant(0), tf.constant(2)].eval()) # => 3
# Insert another dimension
foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
print(foo[tf.newaxis, :, :].eval()) # => [[[1,2,3], [4,5,6], [7,8,9]]]
print(foo[:, tf.newaxis, :].eval()) # => [[[1,2,3]], [[4,5,6]], [[7,8,9]]]
print(foo[:, :, tf.newaxis].eval()) # => [[[1],[2],[3]], [[4],[5],[6]],
[[7],[8],[9]]]
# Ellipses (3 equivalent operations)
foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
print(foo[tf.newaxis, :, :].eval()) # => [[[1,2,3], [4,5,6], [7,8,9]]]
print(foo[tf.newaxis, ...].eval()) # => [[[1,2,3], [4,5,6], [7,8,9]]]
print(foo[tf.newaxis].eval()) # => [[[1,2,3], [4,5,6], [7,8,9]]]
# Masks
foo = tf.constant([[1,2,3], [4,5,6], [7,8,9]])
print(foo[foo > 2].eval()) # => [3, 4, 5, 6, 7, 8, 9]
```
Notes:
- `tf.newaxis` is `None` as in NumPy.
- An implicit ellipsis is placed at the end of the `slice_spec`
- NumPy advanced indexing is currently not supported.
Purpose in the API:
This method is exposed in TensorFlow's API so that library developers
can register dispatching for `Tensor.__getitem__` to allow it to handle
custom composite tensors & other custom objects.
The API symbol is not intended to be called by users directly and does
appear in TensorFlow's generated documentation.
Args:
tensor: An ops.Tensor object.
slice_spec: The arguments to Tensor.__getitem__.
var: In the case of variable slice assignment, the Variable object to slice
(i.e. tensor is the read-only view of this variable).
Returns:
The appropriate slice of "tensor", based on "slice_spec".
Raises:
ValueError: If a slice range is negative size.
TypeError: If the slice indices aren't int, slice, ellipsis,
tf.newaxis or scalar int32/int64 tensors.
"""
if isinstance(slice_spec, bool) or \
(isinstance(slice_spec, ops.Tensor) and slice_spec.dtype == dtypes.bool) or \
(isinstance(slice_spec, np.ndarray) and slice_spec.dtype == bool):
return boolean_mask(tensor=tensor, mask=slice_spec)
if not isinstance(slice_spec, (list, tuple)):
slice_spec = [slice_spec]
begin, end, strides = [], [], []
index = 0
new_axis_mask, shrink_axis_mask = 0, 0
begin_mask, end_mask = 0, 0
ellipsis_mask = 0
for s in slice_spec:
if isinstance(s, _BaseSlice):
if s.start is not None and not _is_undefined_dimension(s.start):
_check_index(s.start)
begin.append(s.start)
else:
begin.append(0)
begin_mask |= (1 << index)
if s.stop is not None and not _is_undefined_dimension(s.stop):
_check_index(s.stop)
end.append(s.stop)
else:
end.append(0)
end_mask |= (1 << index)
if s.step is not None and not _is_undefined_dimension(s.step):
_check_index(s.step)
strides.append(s.step)
else:
strides.append(1)
elif s is Ellipsis:
begin.append(0)
end.append(0)
strides.append(1)
ellipsis_mask |= (1 << index)
elif s is newaxis:
begin.append(0)
end.append(0)
strides.append(1)
new_axis_mask |= (1 << index)