Skip to content

Commit 0e08dea

Browse files
Xharktensorflower-gardener
authored andcommitted
Remove optimizer to compare model size and compare the zipped model size to pass the test. And enable test because now it passed.
PiperOrigin-RevId: 476310649
1 parent f8e874f commit 0e08dea

File tree

1 file changed

+18
-25
lines changed

1 file changed

+18
-25
lines changed

tensorflow_model_optimization/python/core/common/keras/compression/algorithms/weight_clustering_test.py

+18-25
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import os
1818
import tempfile
19-
import unittest
19+
import zipfile
2020

2121
import tensorflow as tf
2222

@@ -63,33 +63,27 @@ def _train_model(model):
6363

6464
def _save_as_saved_model(model):
6565
saved_model_dir = tempfile.mkdtemp()
66-
model.save(saved_model_dir)
66+
model.save(saved_model_dir, include_optimizer=False)
6767
return saved_model_dir
6868

6969

70-
def _get_directory_size_in_bytes(directory):
71-
total = 0
72-
try:
73-
for entry in os.scandir(directory):
74-
if entry.is_file():
75-
# if it's a file, use stat() function
76-
total += entry.stat().st_size
77-
elif entry.is_dir():
78-
# if it's a directory, recursively call this function
79-
total += _get_directory_size_in_bytes(entry.path)
80-
except NotADirectoryError:
81-
# if `directory` isn't a directory, get the file size then
82-
return os.path.getsize(directory)
83-
except PermissionError:
84-
# if for whatever reason we can't open the folder, return 0
85-
return 0
86-
return total
70+
def _get_zipped_directory_size(directory):
71+
"""Measures the compressed size of a directory."""
72+
with tempfile.TemporaryFile(suffix='.zip') as zipped_file:
73+
for root, _, files in os.walk(directory):
74+
for file in files:
75+
with zipfile.ZipFile(
76+
zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
77+
f.write(os.path.join(root, file),
78+
os.path.relpath(os.path.join(root, file),
79+
os.path.join(directory, '..')))
80+
81+
zipped_file.seek(0, 2)
82+
return os.fstat(zipped_file.fileno()).st_size
8783

8884

8985
class FunctionalTest(tf.test.TestCase):
9086

91-
# TODO(b/246652360): Re-Enable the test once it is fixed.
92-
@unittest.skip('Test needs to be fixed')
9387
def testWeightClustering_TrainingE2E(self):
9488
number_of_clusters = 8
9589
model = _build_model()
@@ -118,12 +112,11 @@ def testWeightClustering_TrainingE2E(self):
118112
# Accuracy test.
119113
self.assertGreater(results[1], 0.85) # 0.8708
120114

121-
original_size = _get_directory_size_in_bytes(original_saved_model_dir)
122-
compressed_size = _get_directory_size_in_bytes(saved_model_dir)
115+
original_size = _get_zipped_directory_size(original_saved_model_dir)
116+
compressed_size = _get_zipped_directory_size(saved_model_dir)
123117

124118
# Compressed model size test.
125-
# TODO(tfmot): gzip compression can reduce file size much better.
126-
self.assertLess(compressed_size, original_size / 1.3)
119+
self.assertLess(compressed_size, original_size / 4.0)
127120

128121
def testWeightClustering_SingleLayer(self):
129122
number_of_clusters = 8

0 commit comments

Comments
 (0)