
Commit 5aa9c00

Merge pull request #83 from smartnodes-lab/testnet-v2
Added expected model output for HF models
2 parents f3522fd + 4d67c73

File tree

5 files changed: +45, -22 lines changed

tensorlink/ml/module.py

Lines changed: 5 additions & 1 deletion

@@ -734,7 +734,11 @@ def _initialize_distribution(self):
             module["training"] = self.training

         else:
-            distribution = {"model_name": self.model}
+            distribution = {
+                "model_name": self.model,
+                "training": self.training,
+                "optimizer": optimizer_type,
+            }

         # Request job from network
         distributed_config = self.node.send_request(
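For context, the HF-model branch of the payload now carries the training flag and the optimizer class alongside the model name. A minimal sketch of what that payload might look like; the model name and optimizer choice below are illustrative assumptions, not values shown in the diff:

import torch

# Hypothetical: the diff does not show where optimizer_type is assigned.
optimizer_type = torch.optim.AdamW

distribution = {
    "model_name": "bert-base-uncased",  # illustrative HF model id
    "training": True,                   # mirrors self.training
    "optimizer": optimizer_type,        # the optimizer *class*, serialized downstream
}

Passing the class rather than an instance lets the user node serialize it as a dotted import path when it builds the job request (see tensorlink/nodes/user.py below).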

tensorlink/ml/utils.py

Lines changed: 34 additions & 17 deletions

@@ -873,25 +873,42 @@ def combine_micro_batches(micro_batches):
 def replace_output_with_custom_grad(combined_output, custom_grad_output):
     """
     Replace the main output tensor (logits, last_hidden_state, etc.) in the combined_output
-    with the custom_grad_output, preserving the original structure.
+    with the custom_grad_output, preserving structure and returning a ModelOutput when possible.
     """
-    if hasattr(combined_output, "logits"):
-        return combined_output.__class__(
-            **{**combined_output, "logits": custom_grad_output}
-        )
-    elif hasattr(combined_output, "last_hidden_state"):
-        return combined_output.__class__(
-            **{**combined_output, "last_hidden_state": custom_grad_output}
-        )
-    elif isinstance(combined_output, torch.Tensor):
+    # If the combined output is already a tensor
+    if isinstance(combined_output, torch.Tensor):
         return custom_grad_output
-    else:
-        # For custom ModelOutput-like structures, replace the first tensor found
-        for key, value in combined_output.items():
-            if isinstance(value, torch.Tensor):
-                combined_output[key] = custom_grad_output
-                break
-        return combined_output
+
+    # Handle ModelOutput subclasses (SequenceClassifierOutput, etc.)
+    if isinstance(combined_output, ModelOutput):
+        data = combined_output.to_dict()
+        if "logits" in data:
+            data["logits"] = custom_grad_output
+        elif "last_hidden_state" in data:
+            data["last_hidden_state"] = custom_grad_output
+        else:
+            for k, v in data.items():
+                if isinstance(v, torch.Tensor):
+                    data[k] = custom_grad_output
+                    break
+        return combined_output.__class__(**data)
+
+    # Handle dict outputs
+    if isinstance(combined_output, dict):
+        new_output = dict(combined_output)
+        if "logits" in new_output:
+            new_output["logits"] = custom_grad_output
+        elif "last_hidden_state" in new_output:
+            new_output["last_hidden_state"] = custom_grad_output
+        else:
+            for k, v in new_output.items():
+                if isinstance(v, torch.Tensor):
+                    new_output[k] = custom_grad_output
+                    break
+        # Wrap dict in a generic ModelOutput for consistency
+        return ModelOutput(**new_output)
+
+    raise TypeError(f"Unsupported output type: {type(combined_output)}")


 def split_into_micro_batches(combined_output, n_micro_batch):
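A hedged usage sketch of the new dispatch order: tensors pass straight through, ModelOutput instances are rebuilt as the same class, and plain dicts are wrapped in a generic ModelOutput. The keys and shapes below are illustrative, and this assumes transformers' ModelOutput behaves as a dict subclass:

import torch
from transformers.utils import ModelOutput

grad_out = torch.randn(2, 4, requires_grad=True)

# Tensor branch: the custom-grad tensor replaces the output wholesale.
assert replace_output_with_custom_grad(torch.randn(2, 4), grad_out) is grad_out

# Dict branch: "logits" is swapped in and the result is wrapped in ModelOutput.
combined = {"logits": torch.randn(2, 4), "hidden_states": None}
out = replace_output_with_custom_grad(combined, grad_out)
assert out["logits"] is grad_out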

tensorlink/nodes/user.py

Lines changed: 3 additions & 1 deletion

@@ -293,7 +293,7 @@ def request_job(self, n_pipelines, dp_factor, distribution, training):
         # self.debug_print("request_job: Job requested on Smart Contract!")
         # validator_ids = self.contract.functions.getJobValidators(job_id).call()
         validator_ids = [random.choice(self.validators)]
-        if len(distribution) != 1:
+        if not distribution.get("model_name"):
             # The case where we have a custom model with distributed config
             distribution = {
                 k: v for k, v in distribution.items() if v["type"] == "offloaded"
@@ -322,6 +322,7 @@ def request_job(self, n_pipelines, dp_factor, distribution, training):
             }
         else:
             # The case where we have a huggingface model name for inference
+            optimizer_type = distribution.get("optimizer")
             job_request = {
                 "author": self.rsa_key_hash,
                 "active": True,
@@ -334,6 +335,7 @@ def request_job(self, n_pipelines, dp_factor, distribution, training):
                 "distribution": {},
                 "n_workers": 0,
                 "model_name": distribution.get("model_name"),
+                "optimizer": f"{optimizer_type.__module__}.{optimizer_type.__name__}",
                 "seed_validators": validator_ids,
             }
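The job request now serializes the optimizer class as a dotted import path (e.g. "torch.optim.adamw.AdamW"), which keeps the request JSON-encodable. A minimal sketch of how a receiving node could resolve that string back into a class; resolve_optimizer is a hypothetical helper, not part of this diff:

import importlib
import torch

def resolve_optimizer(path: str):
    # Split e.g. "torch.optim.adamw.AdamW" into module path and class name.
    module_name, _, class_name = path.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)

serialized = f"{torch.optim.AdamW.__module__}.{torch.optim.AdamW.__name__}"
assert resolve_optimizer(serialized) is torch.optim.AdamW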

tensorlink/nodes/validator.py

Lines changed: 2 additions & 2 deletions

@@ -355,7 +355,7 @@ def create_hf_job(self, job_info: dict, requesters_ip: str = None):

         # Huggingface model info checks
         (vram, ram) = estimate_hf_model_memory(
-            job_info.get("model_name"), training=False
+            job_info.get("model_name"), training=job_info.get("training", False)
         )

         if job_info.get("payment", 0) == 0:
@@ -367,6 +367,7 @@
         job_data["ram"] = ram
         job_data["vram"] = vram
         job_data["time"] = _time
+
         # Hand off model dissection and worker assignment to DistributedValidator process
         request_value = "HF-JOB-REQ" + json.dumps(job_data)
         self._store_request(self.rsa_key_hash, request_value)
@@ -527,7 +528,6 @@ def check_job_availability(self, job_data: dict):
                 tag="Validator",
             )
             return False
-
         return assigned_workers

     def create_base_job(self, job_data: dict):
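For intuition on why the training flag is now forwarded to estimate_hf_model_memory: training inflates memory requirements because gradients and optimizer state live alongside the weights. A rough, illustrative estimate under stated assumptions, not tensorlink's actual formula:

def rough_vram_gib(n_params: int, training: bool, bytes_per_param: int = 2) -> float:
    # Weights only for inference; gradients plus Adam-style optimizer state
    # roughly quadruple the footprint when training. Multipliers are assumptions.
    multiplier = 4 if training else 1
    return n_params * bytes_per_param * multiplier / 1024**3

rough_vram_gib(7_000_000_000, training=False)  # ~13 GiB
rough_vram_gib(7_000_000_000, training=True)   # ~52 GiB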

tensorlink/p2p/torch_node.py

Lines changed: 1 addition & 1 deletion

@@ -157,7 +157,7 @@ def _handle_optimizer_response(self, data: bytes, node: Connection):
             node.ghosts += 1
             return False
         else:
-            module_id, response_type = json.dumps(data[18:]).encode()
+            module_id, response_type = json.loads(data[18:])

         if response_type == "loaded":
             self.debug_print(
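This one-line fix replaces serialization with parsing: the bytes after the message tag are already JSON, so json.loads yields the (module_id, response_type) pair, while the old json.dumps(...).encode() tried to serialize raw bytes and could never be unpacked into two names. A small sketch, assuming an 18-byte tag such as "OPTIMIZER-RESPONSE" (the actual tag string is not shown in the diff):

import json

data = b"OPTIMIZER-RESPONSE" + json.dumps(["module_abc", "loaded"]).encode()

module_id, response_type = json.loads(data[18:])  # parse the JSON pair after the tag
assert (module_id, response_type) == ("module_abc", "loaded")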
