Skip to content

Commit bbb7b66

Browse files
authored
Merge pull request #3650 from flairNLP/GH-3633-model-license
GH-3633: add option to add license information to Flair models
2 parents dd04993 + 8b91deb commit bbb7b66

File tree

2 files changed

+61
-18
lines changed

2 files changed

+61
-18
lines changed

flair/models/sequence_tagger_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def __init__(
112112
self.predict_spans = self._determine_if_span_prediction_problem(self.label_dictionary)
113113

114114
self.tagset_size = len(self.label_dictionary)
115-
log.info(f"SequenceTagger predicts: {self.label_dictionary}")
115+
log.info(f"- Predicts {len(self.label_dictionary)} classes: {self.label_dictionary.get_items()[:20]}")
116116

117117
# ----- Embeddings -----
118118
self.embeddings = embeddings

flair/nn/model.py

Lines changed: 60 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,24 @@ def save(self, model_file: Union[str, Path], checkpoint: bool = False) -> None:
140140
# save model
141141
torch.save(model_state, str(model_file), pickle_protocol=4)
142142

143+
@property
144+
def license_info(self) -> str:
145+
"""Get the license information for this model."""
146+
if self.model_card is None:
147+
return "No license information available"
148+
return self.model_card.get("license_info", "No license information available")
149+
150+
@license_info.setter
151+
def license_info(self, value: Optional[str]):
152+
"""Set the license information for this model."""
153+
if self.model_card is None:
154+
self.model_card = {}
155+
if value is None:
156+
# Remove license info if it exists
157+
self.model_card.pop("license_info", None)
158+
else:
159+
self.model_card["license_info"] = value
160+
143161
@classmethod
144162
def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "Model":
145163
"""Loads a Flair model from the given file or state dictionary.
@@ -211,10 +229,21 @@ def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "Model":
211229
if "__cls__" in state:
212230
state.pop("__cls__")
213231

232+
log.info("--------------------------------------------------")
233+
log.info(f"- Loading {cls.__name__}")
234+
214235
model = cls._init_model_with_state_dict(state)
215236

216-
if "model_card" in state:
217-
model.model_card = state["model_card"]
237+
# Print license information
238+
log.info("--------------------------------------------------")
239+
model_card = state.get("model_card", None)
240+
if model_card is not None:
241+
model.model_card = model_card
242+
license_info = model_card.get("license_info", "No license information available")
243+
log.info(f"- Model license: {license_info}")
244+
else:
245+
log.info("- Model license: No license information available")
246+
log.info("--------------------------------------------------")
218247

219248
model.eval()
220249
model.to(flair.device)
@@ -229,25 +258,39 @@ def print_model_card(self):
229258
230259
Only available for models trained with with Flair >= 0.9.1.
231260
"""
232-
if hasattr(self, "model_card"):
261+
model_card = getattr(self, "model_card", None) # Returns None if attribute doesn't exist or is None
262+
263+
if model_card is not None:
233264
param_out = "\n------------------------------------\n"
234265
param_out += "--------- Flair Model Card ---------\n"
235266
param_out += "------------------------------------\n"
236-
param_out += "- this Flair model was trained with:\n"
237-
param_out += f"-- Flair version {self.model_card['flair_version']}\n"
238-
param_out += f"-- PyTorch version {self.model_card['pytorch_version']}\n"
239-
if "transformers_version" in self.model_card:
240-
param_out += f"-- Transformers version {self.model_card['transformers_version']}\n"
241-
param_out += "------------------------------------\n"
242267

243-
param_out += "------- Training Parameters: -------\n"
244-
param_out += "------------------------------------\n"
245-
training_params = "\n".join(
246-
f'-- {param} = {self.model_card["training_parameters"][param]}'
247-
for param in self.model_card["training_parameters"]
248-
)
249-
param_out += training_params + "\n"
250-
param_out += "------------------------------------\n"
268+
# Only print version information if it exists
269+
if any(key in model_card for key in ["flair_version", "pytorch_version", "transformers_version"]):
270+
param_out += "- this Flair model was trained with:\n"
271+
if "flair_version" in model_card:
272+
param_out += f"-- Flair version {model_card['flair_version']}\n"
273+
if "pytorch_version" in model_card:
274+
param_out += f"-- PyTorch version {model_card['pytorch_version']}\n"
275+
if "transformers_version" in model_card:
276+
param_out += f"-- Transformers version {model_card['transformers_version']}\n"
277+
param_out += "------------------------------------\n"
278+
279+
# Print license info if it exists
280+
if "license_info" in model_card:
281+
param_out += f"-- License: {model_card['license_info']}\n"
282+
param_out += "------------------------------------\n"
283+
284+
# Print training parameters if they exist
285+
if "training_parameters" in model_card:
286+
param_out += "------- Training Parameters: -------\n"
287+
param_out += "------------------------------------\n"
288+
training_params = "\n".join(
289+
f'-- {param} = {model_card["training_parameters"][param]}'
290+
for param in model_card["training_parameters"]
291+
)
292+
param_out += training_params + "\n"
293+
param_out += "------------------------------------\n"
251294

252295
log.info(param_out)
253296
else:

0 commit comments

Comments
 (0)