@@ -71,6 +71,20 @@ struct Shape {
7171 }
7272};
7373
74+ inline std::ostream& operator <<(std::ostream& os, const Shape& shape)
75+ {
76+ int size = shape.rank ;
77+ os << " Shape: [" ;
78+ for (int i=0 ;i<size-1 ;i++){
79+ os << shape.data [i] << " ," ;
80+ }
81+ if ( size != 0 ) {
82+ os << shape.data [size-1 ];
83+ }
84+ os << " ]" ;
85+ return os;
86+ }
87+
7488/* *
7589 * @brief Returns the number of elements in a tensor with the given shape,
7690 * which is equal to the product of the dimensions.
@@ -210,30 +224,30 @@ enum NumType {
210224/* *
211225 * @brief Returns the number of bytes of a number type.
212226 */
213- inline size_t sizeBytes (const NumType &type) {
227+ inline size_t sizeBytes (const NumType &type, int numElements = 1 ) {
214228 switch (type) {
215229 case kf16:
216- return sizeof (half);
230+ return sizeof (half) * numElements ;
217231 case kf32:
218- return sizeof (float );
232+ return sizeof (float ) * numElements ;
219233 case kf64:
220- return sizeof (double );
234+ return sizeof (double ) * numElements ;
221235 case ki8:
222- return sizeof (int8_t );
236+ return sizeof (uint32_t ) * ((numElements + 3 ) / 4 );
223237 case ki16:
224- return sizeof (int16_t );
238+ return sizeof (uint32_t ) * ((numElements + 1 ) / 2 );
225239 case ki32:
226- return sizeof (int32_t );
240+ return sizeof (int32_t ) * numElements ;
227241 case ki64:
228- return sizeof (int64_t );
242+ return sizeof (int64_t ) * numElements ;
229243 case ku8:
230- return sizeof (uint8_t );
244+ return sizeof (uint32_t ) * ((numElements + 3 ) / 4 );
231245 case ku16:
232- return sizeof (uint16_t );
246+ return sizeof (uint32_t ) * ((numElements + 1 ) / 2 );
233247 case ku32:
234- return sizeof (uint32_t );
248+ return sizeof (uint32_t ) * numElements ;
235249 case ku64:
236- return sizeof (uint64_t );
250+ return sizeof (uint64_t ) * numElements ;
237251 default :
238252 LOG (kDefLog , kError , " Invalid NumType in size calculation." );
239253 return 0 ;
@@ -697,7 +711,7 @@ inline Tensor createTensor(TensorPool &pool, WGPUDevice &device,
697711 WGPUBufferUsage_CopySrc) {
698712 LOG (kDefLog , kTrace , " Creating tensor" );
699713 size_t numElements = size (shape);
700- size_t size = sizeBytes (dtype) * numElements;
714+ size_t size = sizeBytes (dtype, numElements) ;
701715 WGPUBufferDescriptor bufferDesc = {
702716 .label = {.data = nullptr , .length = 0 },
703717 .usage = usage,
@@ -828,7 +842,10 @@ inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
828842 // unpacking
829843 packed[idx] |= (static_cast <uint8_t >(data[i]) << shift);
830844 }
831- return createTensor (ctx, shape, ki32, packed.data ());
845+ Tensor tensor = createTensor (ctx, shape, ki8);
846+ wgpuQueueWriteBuffer (ctx.queue , tensor.data .buffer , 0 , packed.data (),
847+ tensor.data .size );
848+ return tensor;
832849}
833850
834851// Overload for int16_t: pack two 16‑bit ints into one 32‑bit integer
@@ -843,7 +860,10 @@ inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
843860 size_t shift = (i % 2 ) * 16 ;
844861 packed[idx] |= (static_cast <uint16_t >(data[i]) << shift);
845862 }
846- return createTensor (ctx, shape, ki32, packed.data ());
863+ Tensor tensor = createTensor (ctx, shape, ki16);
864+ wgpuQueueWriteBuffer (ctx.queue , tensor.data .buffer , 0 , packed.data (),
865+ tensor.data .size );
866+ return tensor;
847867}
848868
849869// Overload for int64_t: pack each 64‑bit int into two 32‑bit integers
@@ -857,7 +877,10 @@ inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
857877 packed[2 * i] = static_cast <int32_t >(val & 0xFFFFFFFF );
858878 packed[2 * i + 1 ] = static_cast <int32_t >((val >> 32 ) & 0xFFFFFFFF );
859879 }
860- return createTensor (ctx, shape, ki32, packed.data ());
880+ Tensor tensor = createTensor (ctx, shape, ki64);
881+ wgpuQueueWriteBuffer (ctx.queue , tensor.data .buffer , 0 , packed.data (),
882+ tensor.data .size );
883+ return tensor;
861884}
862885
863886inline Tensor createTensor (Context &ctx, const Shape &shape, NumType dtype,
@@ -885,7 +908,10 @@ inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
885908 size_t shift = (i % 4 ) * 8 ;
886909 packed[idx] |= (static_cast <uint32_t >(data[i]) << shift);
887910 }
888- return createTensor (ctx, shape, ku32, packed.data ());
911+ Tensor tensor = createTensor (ctx, shape, ku8);
912+ wgpuQueueWriteBuffer (ctx.queue , tensor.data .buffer , 0 , packed.data (),
913+ tensor.data .size );
914+ return tensor;
889915}
890916
891917// Overload for uint16_t: pack two 16‑bit integers into one 32‑bit unsigned
@@ -901,7 +927,10 @@ inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
901927 size_t shift = (i % 2 ) * 16 ;
902928 packed[idx] |= (static_cast <uint32_t >(data[i]) << shift);
903929 }
904- return createTensor (ctx, shape, ku32, packed.data ());
930+ Tensor tensor = createTensor (ctx, shape, ku16);
931+ wgpuQueueWriteBuffer (ctx.queue , tensor.data .buffer , 0 , packed.data (),
932+ tensor.data .size );
933+ return tensor;
905934}
906935
907936// Overload for uint64_t: pack each 64‑bit integer into two 32‑bit unsigned
@@ -916,7 +945,10 @@ inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
916945 packed[2 * i] = static_cast <uint32_t >(val & 0xFFFFFFFF );
917946 packed[2 * i + 1 ] = static_cast <uint32_t >(val >> 32 );
918947 }
919- return createTensor (ctx, shape, ku32, packed.data ());
948+ Tensor tensor = createTensor (ctx, shape, ku64);
949+ wgpuQueueWriteBuffer (ctx.queue , tensor.data .buffer , 0 , packed.data (),
950+ tensor.data .size );
951+ return tensor;
920952}
921953
922954/* *
@@ -1987,7 +2019,7 @@ inline void toCPU(Context &ctx, WGPUBuffer buffer, NumType dtype, void *output,
19872019 case kf32:
19882020 case ku32:
19892021 case ki32: {
1990- size_t byteSize = numElements * sizeBytes (dtype);
2022+ size_t byteSize = sizeBytes (dtype, numElements );
19912023 toCPU (ctx, buffer, output, byteSize, sourceOffset);
19922024 break ;
19932025 }
0 commit comments