|
1 | 1 | from slips_files.common.imports import * |
| 2 | +import hmac |
| 3 | +import hashlib |
2 | 4 | from sklearn.linear_model import SGDClassifier |
3 | 5 | from sklearn.preprocessing import StandardScaler |
4 | 6 | import pickle |
@@ -332,18 +334,35 @@ def store_model(self): |
332 | 334 | with open('./modules/flowmldetection/scaler.bin', 'wb') as g: |
333 | 335 | data = pickle.dumps(self.scaler) |
334 | 336 | g.write(data) |
335 | | - |
| 337 | + |
class SecureUnpickler(pickle.Unpickler):
    """Unpickler that only deserializes payloads carrying a valid HMAC tag.

    Wire format: ``<64 ASCII hex chars of HMAC-SHA256(key, payload)><payload>``
    where ``payload`` is the raw pickle byte string.  Use :meth:`dumps` to
    produce that format and :meth:`loads` to verify and decode it.  Bytes
    written with plain ``pickle.dumps`` (no tag prefix) fail verification.

    SECURITY NOTE: the HMAC only proves that the payload was produced by
    someone holding ``key``.  Unpickling can still execute arbitrary code, so
    the key must stay secret (do not hard-code it in source) and this class
    must not be used on data from a party that may know the key.
    """

    # Length of the hex-encoded SHA-256 digest that prefixes every payload.
    TAG_LENGTH = 64

    @staticmethod
    def _tag(payload: bytes, key: bytes) -> bytes:
        """Return the hex-encoded HMAC-SHA256 tag of *payload* under *key*."""
        return hmac.new(key, payload, hashlib.sha256).hexdigest().encode()

    @classmethod
    def dumps(cls, obj, key: bytes) -> bytes:
        """Pickle *obj* and prefix it with its HMAC tag.

        :param obj: any picklable object
        :param key: secret key shared with the reader
        :return: ``tag + pickle.dumps(obj)``, suitable for :meth:`loads`
        """
        payload = pickle.dumps(obj)
        return cls._tag(payload, key) + payload

    @classmethod
    def loads(cls, data: bytes, key: bytes):
        """Verify the HMAC tag of *data* and return the unpickled object.

        :param data: tag-prefixed pickle bytes (see the class docstring)
        :param key: secret key the data was signed with
        :return: the deserialized object
        :raises ValueError: if the tag is missing, truncated or does not match
        """
        received_tag, payload = data[:cls.TAG_LENGTH], data[cls.TAG_LENGTH:]
        # compare_digest is constant-time, so the comparison does not leak
        # how many leading characters of a forged tag were correct.
        if not hmac.compare_digest(cls._tag(payload, key), received_tag):
            raise ValueError("Data integrity check failed")
        # The payload is authenticated; only now is it handed to pickle.
        return pickle.loads(payload)

    def load(self):
        """Disabled: stream loading would bypass the integrity check."""
        raise NotImplementedError("Use SecureUnpickler.loads() instead")
| 352 | + |
336 | 353 | def read_model(self): |
337 | 354 | """ |
338 | 355 | Read the trained model from disk |
339 | 356 | """ |
340 | 357 | try: |
341 | 358 | self.print('Reading the trained model from disk.', 0, 2) |
342 | 359 | with open('./modules/flowmldetection/model.bin', 'rb') as f: |
343 | | - self.clf = pickle.load(f) |
| 360 | + data = f.read() |
| 361 | + self.clf = SecureUnpickler.loads(data, b'my_secret_key') |
344 | 362 | self.print('Reading the trained scaler from disk.', 0, 2) |
345 | 363 | with open('./modules/flowmldetection/scaler.bin', 'rb') as g: |
346 | | - self.scaler = pickle.load(g) |
| 364 | + data = g.read() |
| 365 | + self.scaler = SecureUnpickler.loads(data, b'my_secret_key') |
347 | 366 | except FileNotFoundError: |
348 | 367 | # If there is no model, create one empty |
349 | 368 | self.print('There was no model. Creating a new empty model.', 0, 2) |
|
0 commit comments