-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathtensor.go
More file actions
119 lines (101 loc) · 2.62 KB
/
Copy pathtensor.go
File metadata and controls
119 lines (101 loc) · 2.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
package tflite
/*
#include <stdio.h>
#include <tensorflow/lite/c/c_api.h>
#cgo LDFLAGS: -ltensorflowlite_c
#cgo linux LDFLAGS: -lm -ldl -lrt
*/
import "C"
import (
"fmt"
"reflect"
"unsafe"
)
// TensorType identifies the element type stored in a tensor. Values are
// converted directly from the C API's TfLiteTensorType result (see
// Tensor.Type), so each constant must equal its TfLiteType enumerator.
type TensorType int

// Tensor element types, pinned explicitly to the matching TfLiteType
// enumerator named in the trailing comment.
const (
	TfLiteNoType    TensorType = 0 // kTfLiteNoType
	TfLiteFloat32   TensorType = 1 // kTfLiteFloat32
	TfLiteInt32     TensorType = 2 // kTfLiteInt32
	TfLiteUInt8     TensorType = 3 // kTfLiteUInt8
	TfLiteInt64     TensorType = 4 // kTfLiteInt64
	TfLiteString    TensorType = 5 // kTfLiteString
	TfLiteBool      TensorType = 6 // kTfLiteBool
	TfLiteInt16     TensorType = 7 // kTfLiteInt16
	TfLiteComplex64 TensorType = 8 // kTfLiteComplex64
	TfLiteInt8      TensorType = 9 // kTfLiteInt8
)
// Tensor represents TensorFlow Lite Tensor.
//
// It is a thin wrapper around a C TfLiteTensor handle; its methods call the
// TF Lite C API through that handle.
type Tensor struct {
	// tensor is the underlying C tensor handle.
	// NOTE(review): ownership presumably stays with the interpreter that
	// produced this handle (nothing on Tensor frees it) — confirm.
	tensor *C.TfLiteTensor
}
// QuantizationParams holds a tensor's quantization parameters, as returned by
// Tensor.QuantizationParams.
//
// NOTE(review): per the TF Lite C API these presumably describe the affine
// mapping real = Scale * (quantized - ZeroPoint) — confirm against the C docs.
type QuantizationParams struct {
	// Scale is converted from the C struct's scale field.
	Scale float64
	// ZeroPoint is converted from the C struct's zero_point field.
	ZeroPoint int
}
// Type reports the element type of the tensor's data as a TensorType.
func (t *Tensor) Type() TensorType {
	raw := C.TfLiteTensorType(t.tensor)
	return TensorType(raw)
}
// NumDims reports the rank (number of dimensions) of the tensor as given by
// the C API.
func (t *Tensor) NumDims() int {
	rank := C.TfLiteTensorNumDims(t.tensor)
	return int(rank)
}
// Dim returns the size of the tensor along dimension index.
//
// index is passed to the C API unchecked; callers should keep it within
// [0, NumDims()).
func (t *Tensor) Dim(index int) int {
	axis := C.int32_t(index)
	extent := C.TfLiteTensorDim(t.tensor, axis)
	return int(extent)
}
// Shape returns the shape of the tensor as a fresh slice holding
// Dim(0), Dim(1), ..., Dim(NumDims()-1).
//
// An empty slice is returned for a rank-0 tensor, or when the C API reports
// a negative rank.
func (t *Tensor) Shape() []int {
	// Query the rank once rather than on every iteration: each call crosses
	// the cgo boundary.
	n := t.NumDims()
	if n < 0 {
		// Recent TF Lite releases document a -1 return when the tensor's
		// dimensions are unset; make would panic on a negative length.
		n = 0
	}
	shape := make([]int, n)
	for i := range shape {
		shape[i] = t.Dim(i)
	}
	return shape
}
// ByteSize reports the size, in bytes, of the tensor's data buffer.
func (t *Tensor) ByteSize() uint {
	size := C.TfLiteTensorByteSize(t.tensor)
	return uint(size)
}
// Name returns the tensor's name, copied into a Go string.
func (t *Tensor) Name() string {
	cname := C.TfLiteTensorName(t.tensor)
	return C.GoString(cname)
}
// SetFloat32 copies v into the tensor's data buffer.
//
// The tensor must be of type TfLiteFloat32 and must already have a data
// buffer. At most ByteSize()/4 values are written:
//   - if v is shorter, only the first len(v) elements change;
//   - if v is longer, the excess is ignored.
//
// An error is returned when t (or its handle) is nil, when the tensor is not
// float32, or when it has no data buffer.
func (t *Tensor) SetFloat32(v []float32) error {
	if t == nil || t.tensor == nil {
		return fmt.Errorf("type error: nil tensor")
	}
	if typ := t.Type(); typ != TfLiteFloat32 {
		return fmt.Errorf("type error: tensor type is %d, want float32", typ)
	}
	ptr := C.TfLiteTensorData(t.tensor)
	if ptr == nil {
		return fmt.Errorf("type error: tensor has no data buffer")
	}
	n := t.ByteSize() / 4
	// View the C buffer as a Go slice without copying. The three-index form
	// caps the view at n elements so it can never reach past the tensor's memory.
	dst := (*[1<<29 - 1]float32)(ptr)[:n:n]
	copy(dst, v)
	return nil
}
// GetFloat32 returns the tensor's data as a []float32 view of the tensor's
// own C buffer. No copy is made, so writes through the slice modify the
// tensor. NOTE(review): the view is presumably valid only while the
// interpreter keeps this buffer allocated — confirm.
//
// It returns nil in these cases:
//   - t (or its handle) is nil;
//   - the tensor is not TfLiteFloat32, whose bytes would otherwise be
//     reinterpreted as float32 garbage;
//   - the tensor has no data buffer, which the original code would panic on
//     when dereferencing.
func (t *Tensor) GetFloat32() []float32 {
	if t == nil || t.tensor == nil || t.Type() != TfLiteFloat32 {
		return nil
	}
	ptr := C.TfLiteTensorData(t.tensor)
	if ptr == nil {
		return nil
	}
	n := t.ByteSize() / 4
	// Cap the slice at its length: with the array's huge capacity, an
	// append by the caller would write past the end of the C buffer.
	return (*[1<<29 - 1]float32)(ptr)[:n:n]
}
// FromBuffer copies ByteSize() bytes from b into the tensor.
//
// b must be one of:
//   - a slice whose element type is plain data (booleans, numbers, and
//     arrays or structs made only of those);
//   - a non-nil pointer to such a value.
//
// If b is nil, of an unsupported type, or backed by fewer than ByteSize()
// bytes, no copy is attempted and Status(kTfLiteError) is returned.
// Previously such input panicked or made C read past the end of b.
func (t *Tensor) FromBuffer(b interface{}) Status {
	size := t.ByteSize()
	ptr, ok := bufferAddr(b, uintptr(size))
	if !ok {
		return Status(C.kTfLiteError)
	}
	return Status(C.TfLiteTensorCopyFromBuffer(t.tensor, ptr, C.size_t(size)))
}

// ToBuffer copies ByteSize() bytes from the tensor into b.
//
// b has the same requirements as in FromBuffer. If b is nil, of an
// unsupported type, or backed by fewer than ByteSize() bytes, no copy is
// attempted and Status(kTfLiteError) is returned. Previously such input
// panicked or made C write past the end of b.
func (t *Tensor) ToBuffer(b interface{}) Status {
	size := t.ByteSize()
	ptr, ok := bufferAddr(b, uintptr(size))
	if !ok {
		return Status(C.kTfLiteError)
	}
	return Status(C.TfLiteTensorCopyToBuffer(t.tensor, ptr, C.size_t(size)))
}

// bufferAddr returns the address of the memory backing b. It reports false
// when b is not a supported buffer or holds fewer than need bytes.
func bufferAddr(b interface{}, need uintptr) (unsafe.Pointer, bool) {
	v := reflect.ValueOf(b)
	var have uintptr
	switch v.Kind() {
	case reflect.Slice:
		elem := v.Type().Elem()
		if !isPlainBufferType(elem) {
			return nil, false
		}
		have = uintptr(v.Len()) * elem.Size()
	case reflect.Ptr:
		if v.IsNil() || !isPlainBufferType(v.Type().Elem()) {
			return nil, false
		}
		have = v.Type().Elem().Size()
	default:
		// Includes reflect.Invalid (b == nil), on which Value.Pointer panics.
		return nil, false
	}
	if have < need {
		return nil, false
	}
	// The uintptr from Value.Pointer is converted back in the same
	// expression, as the unsafe package rules require.
	return unsafe.Pointer(v.Pointer()), true
}

// isPlainBufferType reports whether values of type typ contain no Go
// pointers. Only such memory may be handed to C under the cgo
// pointer-passing rules.
func isPlainBufferType(typ reflect.Type) bool {
	switch typ.Kind() {
	case reflect.Bool,
		reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
		reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
		return true
	case reflect.Array:
		return isPlainBufferType(typ.Elem())
	case reflect.Struct:
		for i := 0; i < typ.NumField(); i++ {
			if !isPlainBufferType(typ.Field(i).Type) {
				return false
			}
		}
		return true
	}
	return false
}
// QuantizationParams returns the tensor's quantization parameters (scale and
// zero point) as reported by the C API, converted to Go types.
func (t *Tensor) QuantizationParams() QuantizationParams {
	raw := C.TfLiteTensorQuantizationParams(t.tensor)
	var out QuantizationParams
	out.Scale = float64(raw.scale)
	out.ZeroPoint = int(raw.zero_point)
	return out
}