@@ -140,6 +140,24 @@ def save(self, model_file: Union[str, Path], checkpoint: bool = False) -> None:
140
140
# save model
141
141
torch .save (model_state , str (model_file ), pickle_protocol = 4 )
142
142
143
+ @property
144
+ def license_info (self ) -> str :
145
+ """Get the license information for this model."""
146
+ if self .model_card is None :
147
+ return "No license information available"
148
+ return self .model_card .get ("license_info" , "No license information available" )
149
+
150
+ @license_info .setter
151
+ def license_info (self , value : Optional [str ]):
152
+ """Set the license information for this model."""
153
+ if self .model_card is None :
154
+ self .model_card = {}
155
+ if value is None :
156
+ # Remove license info if it exists
157
+ self .model_card .pop ("license_info" , None )
158
+ else :
159
+ self .model_card ["license_info" ] = value
160
+
143
161
@classmethod
144
162
def load (cls , model_path : Union [str , Path , dict [str , Any ]]) -> "Model" :
145
163
"""Loads a Flair model from the given file or state dictionary.
@@ -211,10 +229,21 @@ def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "Model":
211
229
if "__cls__" in state :
212
230
state .pop ("__cls__" )
213
231
232
+ log .info ("--------------------------------------------------" )
233
+ log .info (f"- Loading { cls .__name__ } " )
234
+
214
235
model = cls ._init_model_with_state_dict (state )
215
236
216
- if "model_card" in state :
217
- model .model_card = state ["model_card" ]
237
+ # Print license information
238
+ log .info ("--------------------------------------------------" )
239
+ model_card = state .get ("model_card" , None )
240
+ if model_card is not None :
241
+ model .model_card = model_card
242
+ license_info = model_card .get ("license_info" , "No license information available" )
243
+ log .info (f"- Model license: { license_info } " )
244
+ else :
245
+ log .info ("- Model license: No license information available" )
246
+ log .info ("--------------------------------------------------" )
218
247
219
248
model .eval ()
220
249
model .to (flair .device )
@@ -229,25 +258,39 @@ def print_model_card(self):
229
258
230
259
Only available for models trained with with Flair >= 0.9.1.
231
260
"""
232
- if hasattr (self , "model_card" ):
261
+ model_card = getattr (self , "model_card" , None ) # Returns None if attribute doesn't exist or is None
262
+
263
+ if model_card is not None :
233
264
param_out = "\n ------------------------------------\n "
234
265
param_out += "--------- Flair Model Card ---------\n "
235
266
param_out += "------------------------------------\n "
236
- param_out += "- this Flair model was trained with:\n "
237
- param_out += f"-- Flair version { self .model_card ['flair_version' ]} \n "
238
- param_out += f"-- PyTorch version { self .model_card ['pytorch_version' ]} \n "
239
- if "transformers_version" in self .model_card :
240
- param_out += f"-- Transformers version { self .model_card ['transformers_version' ]} \n "
241
- param_out += "------------------------------------\n "
242
267
243
- param_out += "------- Training Parameters: -------\n "
244
- param_out += "------------------------------------\n "
245
- training_params = "\n " .join (
246
- f'-- { param } = { self .model_card ["training_parameters" ][param ]} '
247
- for param in self .model_card ["training_parameters" ]
248
- )
249
- param_out += training_params + "\n "
250
- param_out += "------------------------------------\n "
268
+ # Only print version information if it exists
269
+ if any (key in model_card for key in ["flair_version" , "pytorch_version" , "transformers_version" ]):
270
+ param_out += "- this Flair model was trained with:\n "
271
+ if "flair_version" in model_card :
272
+ param_out += f"-- Flair version { model_card ['flair_version' ]} \n "
273
+ if "pytorch_version" in model_card :
274
+ param_out += f"-- PyTorch version { model_card ['pytorch_version' ]} \n "
275
+ if "transformers_version" in model_card :
276
+ param_out += f"-- Transformers version { model_card ['transformers_version' ]} \n "
277
+ param_out += "------------------------------------\n "
278
+
279
+ # Print license info if it exists
280
+ if "license_info" in model_card :
281
+ param_out += f"-- License: { model_card ['license_info' ]} \n "
282
+ param_out += "------------------------------------\n "
283
+
284
+ # Print training parameters if they exist
285
+ if "training_parameters" in model_card :
286
+ param_out += "------- Training Parameters: -------\n "
287
+ param_out += "------------------------------------\n "
288
+ training_params = "\n " .join (
289
+ f'-- { param } = { model_card ["training_parameters" ][param ]} '
290
+ for param in model_card ["training_parameters" ]
291
+ )
292
+ param_out += training_params + "\n "
293
+ param_out += "------------------------------------\n "
251
294
252
295
log .info (param_out )
253
296
else :
0 commit comments