Skip to content

Commit 0febcae

Browse files
committed
Fixed API request model and subsequent model hosting in validator ML process.
1 parent 4d67c73 commit 0febcae

File tree

3 files changed

+147
-67
lines changed

3 files changed

+147
-67
lines changed

tensorlink/api/node.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ def request_model(job_request: JobRequest, request: Request):
142142
# Trigger the loading process
143143
job_data = {
144144
"author": self.smart_node.rsa_key_hash,
145+
"api": True,
145146
"active": True,
146147
"hosted": True,
147148
"training": False,

tensorlink/ml/validator.py

Lines changed: 146 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -254,9 +254,6 @@ def _get_popular_models(self) -> list:
254254
def _manage_auto_loaded_models(self):
255255
"""Manage auto-loaded models based on popularity from JSON cache, falling back to DEFAULT_MODELS"""
256256
popular_models = self._get_popular_models()
257-
258-
# If no popular models tracked yet, use DEFAULT_MODELS as fallback
259-
models_to_load = DEFAULT_MODELS[: self.MAX_AUTO_MODELS]
260257
# if not popular_models:
261258
# models_to_load = DEFAULT_MODELS[: self.MAX_AUTO_MODELS]
262259
# else:
@@ -266,6 +263,9 @@ def _manage_auto_loaded_models(self):
266263
# (f"Loading popular models: {models_to_load}", "blue", logging.INFO),
267264
# )
268265

266+
# If no popular models tracked yet, use DEFAULT_MODELS as fallback
267+
models_to_load = DEFAULT_MODELS[: self.MAX_AUTO_MODELS]
268+
269269
# Load models up to the limit
270270
for model_name in models_to_load:
271271
if (
@@ -279,30 +279,9 @@ def _manage_auto_loaded_models(self):
279279
self.models_initializing.add(model_name)
280280
self._initialize_hosted_job(model_name)
281281

282-
# Continue initialization for models that are in progress
283-
for model_name in list(self.models_initializing):
284-
if model_name in models_to_load: # Still wanted
285-
# Try second initialization call
286-
self._initialize_hosted_job(model_name)
287-
# Check if initialization is complete
288-
if model_name in self.models and isinstance(
289-
self.models[model_name], str
290-
):
291-
# Model is fully initialized (module_id is now a string)
292-
self.models_initializing.discard(model_name)
293-
self.send_request(
294-
"debug_print",
295-
(
296-
f"Completed auto-loading model: {model_name}",
297-
"green",
298-
logging.INFO,
299-
),
300-
)
301-
else:
302-
# Model no longer wanted, cancel initialization
303-
self.models_initializing.discard(model_name)
304-
if model_name in self.models:
305-
self._remove_hosted_job(model_name)
282+
# Try to finalize models that are initializing
283+
if self.models_initializing:
284+
self._try_finalize_initializing_models()
306285

307286
# Remove models not in the current priority list
308287
currently_loaded = [
@@ -314,15 +293,17 @@ def _manage_auto_loaded_models(self):
314293
model_name, days=1
315294
) # Check last day
316295
if recent_requests < 5: # Low recent activity
317-
self.send_request(
318-
"debug_print",
319-
(
320-
f"Removing unpopular model: {model_name}",
321-
"yellow",
322-
logging.INFO,
323-
),
324-
)
325-
self._remove_hosted_job(model_name)
296+
is_active = self.send_request("check_job", (model_name,))
297+
if not is_active:
298+
self.send_request(
299+
"debug_print",
300+
(
301+
f"Removing unpopular model: {model_name}",
302+
"yellow",
303+
logging.INFO,
304+
),
305+
)
306+
self._remove_hosted_job(model_name)
326307

327308
def inspect_model(self, model_name: str, job_data: dict = None):
328309
"""Inspect a model to determine network requirements and store distribution in JSON cache"""
@@ -391,8 +372,29 @@ def check_node(self):
391372
job_data = self.send_request("get_jobs", None)
392373

393374
if isinstance(job_data, dict):
394-
# Offload model inspection to a background thread to avoid blocking
395-
self.inspect_model(job_data.get("model_name"), job_data)
375+
model_name = job_data.get("model_name")
376+
377+
if job_data.get("api"):
378+
payment = job_data.get("payment", 0)
379+
time_limit = job_data.get("time", 1800)
380+
381+
# Initialize if not already done
382+
if (
383+
model_name not in self.models
384+
and model_name not in self.models_initializing
385+
):
386+
self.models_initializing.add(model_name)
387+
self._initialize_hosted_job(
388+
model_name, payment=payment, time_limit=time_limit
389+
)
390+
391+
# Try to finalize if already initializing
392+
if model_name in self.models_initializing:
393+
self._finalize_hosted_job(model_name)
394+
395+
else:
396+
# If request via user node, begin the model reqs inspection for the job request
397+
self.inspect_model(model_name, job_data)
396398

397399
# Check for inference generate calls
398400
for model_name, module_id in self.models.items():
@@ -517,48 +519,57 @@ def _handle_generate_request(self, request: GenerationRequest):
517519
# Decode generated tokens
518520
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
519521

520-
# Extract only the assistant's response from the generated text
521-
clean_response = extract_assistant_response(generated_text, request.hf_name)
522-
request.output = clean_response
522+
# Many models echo the prompt, so remove it
523+
if generated_text.startswith(formatted_prompt):
524+
request.output = generated_text[len(formatted_prompt) :].strip()
525+
else:
526+
request.output = generated_text
523527

524528
# Return the clean response
525529
self.send_request("update_api_request", (request,))
526530

527-
def _initialize_hosted_job(self, model_name: str):
528-
"""Method that can be invoked twice, once to begin setup of the job, and a second
529-
time to finalize the job init."""
530-
args = self.send_request("check_module", None)
531+
def _try_finalize_initializing_models(self):
532+
"""Attempt to finalize all models that are currently initializing."""
533+
for model_name in list(self.models_initializing):
534+
if self._finalize_hosted_job(model_name):
535+
self.send_request(
536+
"debug_print",
537+
(
538+
f"Successfully finalized model: {model_name}",
539+
"green",
540+
logging.INFO,
541+
),
542+
)
531543

532-
# Check if the model loading is complete across workers and ready to go (second call)
533-
if model_name in self.models and args:
534-
if isinstance(args, tuple):
535-
(
536-
file_name,
537-
module_id,
538-
distribution,
539-
module_name,
540-
optimizer_name,
541-
training,
542-
) = args
543-
self.modules[module_id] = self.models.pop(model_name)
544-
self.models[model_name] = module_id
545-
self.modules[module_id].distribute_model(distribution)
546-
self.tokenizers[model_name] = AutoTokenizer.from_pretrained(model_name)
547-
548-
# If not, check if we can spin up the model (first call)
549-
else:
550-
# Small init sleep time
544+
def _initialize_hosted_job(
545+
self, model_name: str, payment: int = 0, time_limit: int = None
546+
):
547+
"""Initialize a hosted job by creating the distributed model and submitting inspection request."""
548+
try:
549+
# Check if already initialized
551550
if model_name in self.models:
552-
time.sleep(20)
551+
self.send_request(
552+
"debug_print",
553+
(
554+
f"Model {model_name} already initializing, skipping duplicate init",
555+
"yellow",
556+
logging.DEBUG,
557+
),
558+
)
559+
return
553560

561+
# Create distributed model instance
554562
distributed_model = DistributedModel(model_name, node=self.node)
555563
self.models[model_name] = distributed_model
564+
565+
# Prepare job data for inspection
556566
job_data = {
557567
"author": None,
558568
"active": True,
559569
"hosted": True,
560570
"training": False,
561-
"payment": 0,
571+
"payment": payment,
572+
"time": time_limit,
562573
"capacity": 0,
563574
"n_pipelines": 1,
564575
"dp_factor": 1,
@@ -567,8 +578,77 @@ def _initialize_hosted_job(self, model_name: str):
567578
"model_name": model_name,
568579
"seed_validators": [],
569580
}
581+
582+
# Inspect model to determine network requirements
570583
self.inspect_model(model_name, job_data)
571584

585+
self.send_request(
586+
"debug_print",
587+
(f"Initialized hosted job for {model_name}", "green", logging.INFO),
588+
)
589+
590+
except Exception as e:
591+
logging.error(f"Error initializing hosted job for {model_name}: {str(e)}")
592+
self.models_initializing.discard(model_name)
593+
if model_name in self.models:
594+
del self.models[model_name]
595+
596+
def _finalize_hosted_job(self, model_name: str):
597+
"""Finalize a hosted job by setting up the distributed model with workers."""
598+
try:
599+
# Check if we have module info ready
600+
args = self.send_request("check_module", None)
601+
602+
if not args or not isinstance(args, tuple):
603+
# Module not ready yet
604+
return False
605+
606+
# Check if model is in initialization state
607+
if model_name not in self.models:
608+
return False
609+
610+
# Unpack module information
611+
(
612+
file_name,
613+
module_id,
614+
distribution,
615+
module_name,
616+
optimizer_name,
617+
training,
618+
) = args
619+
620+
# Move from initialization to active state
621+
distributed_model = self.models.pop(model_name)
622+
self.modules[module_id] = distributed_model
623+
self.models[model_name] = module_id
624+
625+
# Distribute the model across workers
626+
self.modules[module_id].distribute_model(distribution)
627+
628+
# Load tokenizer
629+
self.tokenizers[model_name] = AutoTokenizer.from_pretrained(model_name)
630+
631+
# Remove from initializing set
632+
self.models_initializing.discard(model_name)
633+
634+
self.send_request(
635+
"debug_print",
636+
(
637+
f"Finalized hosted job for {model_name} with module_id {module_id}",
638+
"green",
639+
logging.INFO,
640+
),
641+
)
642+
643+
return True
644+
645+
except Exception as e:
646+
logging.error(f"Error finalizing hosted job for {model_name}: {str(e)}")
647+
self.models_initializing.discard(model_name)
648+
if model_name in self.models:
649+
del self.models[model_name]
650+
return False
651+
572652
def _remove_hosted_job(self, model_name: str):
573653
"""Remove a hosted job and clean up all associated resources"""
574654
try:

tensorlink/nodes/validator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -828,7 +828,6 @@ def distribute_job(self):
828828
"""Distribute job to a few other non-seed validators"""
829829
for validator in self.validators:
830830
pass
831-
pass
832831

833832
# # Query job information from seed validators
834833
# job_responses = [

Commit comments: 0