Skip to content

Commit 562229a

Browse files
committed
Update acuity.js (#1423)
1 parent 8ff5b00 commit 562229a

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

source/acuity.js

+14
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,15 @@ acuity.Graph = class {
3333
this.nodes = [];
3434
this.inputs = [];
3535
this.outputs = [];
36+
this.metrics = [];
3637
const values = new Map();
3738
const value = (name) => {
3839
if (!values.has(name)) {
3940
values.set(name, { name, shape: null });
4041
}
4142
return values.get(name);
4243
};
44+
let totalFlops = 0;
4345
for (const [name, layer] of Object.entries(model.Layers)) {
4446
layer.inputs = layer.inputs.map((input) => {
4547
return value(input);
@@ -64,7 +66,19 @@ acuity.Graph = class {
6466
output.shape = shape;
6567
return output;
6668
});
69+
// Add other layer types (e.g., pooling, batch norm, etc.) as needed.
70+
if (layer.type === 'Conv2D') {
71+
const { kernelShape, inputShape, outputShape } = layer;
72+
const [kH, kW] = kernelShape;
73+
const [inC] = inputShape;
74+
const [outC, oH, oW] = outputShape;
75+
totalFlops += kH * kW * inC * oH * oW * outC;
76+
} else if (layer.type === 'Dense') {
77+
const { inputSize, outputSize } = layer;
78+
totalFlops += inputSize * outputSize;
79+
}
6780
}
81+
this.metrics.push(new acuity.Argument('flops', totalFlops));
6882
acuity.Inference.infer(model.Layers);
6983
for (const [name, obj] of values) {
7084
const type = new acuity.TensorType(null, new acuity.TensorShape(obj.shape));

source/view.js

+15-1
Original file line numberDiff line numberDiff line change
@@ -2855,7 +2855,7 @@ view.TextView = class extends view.Control {
28552855
super(context);
28562856
this.element = this.createElement('div', 'sidebar-item-value');
28572857
let className = 'sidebar-item-value-line';
2858-
if (value) {
2858+
if (value !== null && value !== undefined) {
28592859
const list = Array.isArray(value) ? value : [value];
28602860
for (const item of list) {
28612861
const line = this.createElement('div', className);
@@ -3641,6 +3641,13 @@ view.ModelSidebar = class extends view.ObjectSidebar {
36413641
this.addProperty(argument.name, argument.value);
36423642
}
36433643
}
3644+
const metrics = model.metrics;
3645+
if (Array.isArray(metrics) && metrics.length > 0) {
3646+
this.addHeader('Metrics');
3647+
for (const argument of metrics) {
3648+
this.addProperty(argument.name, argument.value);
3649+
}
3650+
}
36443651
if (graph) {
36453652
if (graph.version) {
36463653
this.addProperty('version', graph.version);
@@ -3675,6 +3682,13 @@ view.ModelSidebar = class extends view.ObjectSidebar {
36753682
this.addProperty(argument.name, argument.value);
36763683
}
36773684
}
3685+
const metrics = graph.metrics;
3686+
if (Array.isArray(metrics) && metrics.length > 0) {
3687+
this.addHeader('Metrics');
3688+
for (const argument of metrics) {
3689+
this.addProperty(argument.name, argument.value);
3690+
}
3691+
}
36783692
}
36793693
}
36803694

0 commit comments

Comments
 (0)