|
16 | 16 |
|
17 | 17 | import os
|
18 | 18 | import tempfile
|
19 |
| -import unittest |
| 19 | +import zipfile |
20 | 20 |
|
21 | 21 | import tensorflow as tf
|
22 | 22 |
|
@@ -63,33 +63,27 @@ def _train_model(model):
|
63 | 63 |
|
64 | 64 | def _save_as_saved_model(model):
|
65 | 65 | saved_model_dir = tempfile.mkdtemp()
|
66 |
| - model.save(saved_model_dir) |
| 66 | + model.save(saved_model_dir, include_optimizer=False) |
67 | 67 | return saved_model_dir
|
68 | 68 |
|
69 | 69 |
|
70 |
| -def _get_directory_size_in_bytes(directory): |
71 |
| - total = 0 |
72 |
| - try: |
73 |
| - for entry in os.scandir(directory): |
74 |
| - if entry.is_file(): |
75 |
| - # if it's a file, use stat() function |
76 |
| - total += entry.stat().st_size |
77 |
| - elif entry.is_dir(): |
78 |
| - # if it's a directory, recursively call this function |
79 |
| - total += _get_directory_size_in_bytes(entry.path) |
80 |
| - except NotADirectoryError: |
81 |
| - # if `directory` isn't a directory, get the file size then |
82 |
| - return os.path.getsize(directory) |
83 |
| - except PermissionError: |
84 |
| - # if for whatever reason we can't open the folder, return 0 |
85 |
| - return 0 |
86 |
| - return total |
| 70 | +def _get_zipped_directory_size(directory): |
| 71 | + """Measures the compressed size of a directory.""" |
| 72 | + with tempfile.TemporaryFile(suffix='.zip') as zipped_file: |
| 73 | + for root, _, files in os.walk(directory): |
| 74 | + for file in files: |
| 75 | + with zipfile.ZipFile( |
| 76 | + zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f: |
| 77 | + f.write(os.path.join(root, file), |
| 78 | + os.path.relpath(os.path.join(root, file), |
| 79 | + os.path.join(directory, '..'))) |
| 80 | + |
| 81 | + zipped_file.seek(0, 2) |
| 82 | + return os.fstat(zipped_file.fileno()).st_size |
87 | 83 |
|
88 | 84 |
|
89 | 85 | class FunctionalTest(tf.test.TestCase):
|
90 | 86 |
|
91 |
| - # TODO(b/246652360): Re-Enable the test once it is fixed. |
92 |
| - @unittest.skip('Test needs to be fixed') |
93 | 87 | def testWeightClustering_TrainingE2E(self):
|
94 | 88 | number_of_clusters = 8
|
95 | 89 | model = _build_model()
|
@@ -118,12 +112,11 @@ def testWeightClustering_TrainingE2E(self):
|
118 | 112 | # Accuracy test.
|
119 | 113 | self.assertGreater(results[1], 0.85) # 0.8708
|
120 | 114 |
|
121 |
| - original_size = _get_directory_size_in_bytes(original_saved_model_dir) |
122 |
| - compressed_size = _get_directory_size_in_bytes(saved_model_dir) |
| 115 | + original_size = _get_zipped_directory_size(original_saved_model_dir) |
| 116 | + compressed_size = _get_zipped_directory_size(saved_model_dir) |
123 | 117 |
|
124 | 118 | # Compressed model size test.
|
125 |
| - # TODO(tfmot): gzip compression can reduce file size much better. |
126 |
| - self.assertLess(compressed_size, original_size / 1.3) |
| 119 | + self.assertLess(compressed_size, original_size / 4.0) |
127 | 120 |
|
128 | 121 | def testWeightClustering_SingleLayer(self):
|
129 | 122 | number_of_clusters = 8
|
|
0 commit comments