@@ -254,9 +254,6 @@ def _get_popular_models(self) -> list:
254254 def _manage_auto_loaded_models (self ):
255255 """Manage auto-loaded models based on popularity from JSON cache, falling back to DEFAULT_MODELS"""
256256 popular_models = self ._get_popular_models ()
257-
258- # If no popular models tracked yet, use DEFAULT_MODELS as fallback
259- models_to_load = DEFAULT_MODELS [: self .MAX_AUTO_MODELS ]
260257 # if not popular_models:
261258 # models_to_load = DEFAULT_MODELS[: self.MAX_AUTO_MODELS]
262259 # else:
@@ -266,6 +263,9 @@ def _manage_auto_loaded_models(self):
266263 # (f"Loading popular models: {models_to_load}", "blue", logging.INFO),
267264 # )
268265
266+ # If no popular models tracked yet, use DEFAULT_MODELS as fallback
267+ models_to_load = DEFAULT_MODELS [: self .MAX_AUTO_MODELS ]
268+
269269 # Load models up to the limit
270270 for model_name in models_to_load :
271271 if (
@@ -279,30 +279,9 @@ def _manage_auto_loaded_models(self):
279279 self .models_initializing .add (model_name )
280280 self ._initialize_hosted_job (model_name )
281281
282- # Continue initialization for models that are in progress
283- for model_name in list (self .models_initializing ):
284- if model_name in models_to_load : # Still wanted
285- # Try second initialization call
286- self ._initialize_hosted_job (model_name )
287- # Check if initialization is complete
288- if model_name in self .models and isinstance (
289- self .models [model_name ], str
290- ):
291- # Model is fully initialized (module_id is now a string)
292- self .models_initializing .discard (model_name )
293- self .send_request (
294- "debug_print" ,
295- (
296- f"Completed auto-loading model: { model_name } " ,
297- "green" ,
298- logging .INFO ,
299- ),
300- )
301- else :
302- # Model no longer wanted, cancel initialization
303- self .models_initializing .discard (model_name )
304- if model_name in self .models :
305- self ._remove_hosted_job (model_name )
282+ # Try to finalize models that are initializing
283+ if self .models_initializing :
284+ self ._try_finalize_initializing_models ()
306285
307286 # Remove models not in the current priority list
308287 currently_loaded = [
@@ -314,15 +293,17 @@ def _manage_auto_loaded_models(self):
314293 model_name , days = 1
315294 ) # Check last day
316295 if recent_requests < 5 : # Low recent activity
317- self .send_request (
318- "debug_print" ,
319- (
320- f"Removing unpopular model: { model_name } " ,
321- "yellow" ,
322- logging .INFO ,
323- ),
324- )
325- self ._remove_hosted_job (model_name )
296+ is_active = self .send_request ("check_job" , (model_name ,))
297+ if not is_active :
298+ self .send_request (
299+ "debug_print" ,
300+ (
301+ f"Removing unpopular model: { model_name } " ,
302+ "yellow" ,
303+ logging .INFO ,
304+ ),
305+ )
306+ self ._remove_hosted_job (model_name )
326307
327308 def inspect_model (self , model_name : str , job_data : dict = None ):
328309 """Inspect a model to determine network requirements and store distribution in JSON cache"""
@@ -391,8 +372,29 @@ def check_node(self):
391372 job_data = self .send_request ("get_jobs" , None )
392373
393374 if isinstance (job_data , dict ):
394- # Offload model inspection to a background thread to avoid blocking
395- self .inspect_model (job_data .get ("model_name" ), job_data )
375+ model_name = job_data .get ("model_name" )
376+
377+ if job_data .get ("api" ):
378+ payment = job_data .get ("payment" , 0 )
379+ time_limit = job_data .get ("time" , 1800 )
380+
381+ # Initialize if not already done
382+ if (
383+ model_name not in self .models
384+ and model_name not in self .models_initializing
385+ ):
386+ self .models_initializing .add (model_name )
387+ self ._initialize_hosted_job (
388+ model_name , payment = payment , time_limit = time_limit
389+ )
390+
391+ # Try to finalize if already initializing
392+ if model_name in self .models_initializing :
393+ self ._finalize_hosted_job (model_name )
394+
395+ else :
396+ # If request via user node, begin the model reqs inspection for the job request
397+ self .inspect_model (model_name , job_data )
396398
397399 # Check for inference generate calls
398400 for model_name , module_id in self .models .items ():
@@ -517,48 +519,57 @@ def _handle_generate_request(self, request: GenerationRequest):
517519 # Decode generated tokens
518520 generated_text = tokenizer .decode (outputs [0 ], skip_special_tokens = True )
519521
520- # Extract only the assistant's response from the generated text
521- clean_response = extract_assistant_response (generated_text , request .hf_name )
522- request .output = clean_response
522+ # Many models echo the prompt, so remove it
523+ if generated_text .startswith (formatted_prompt ):
524+ request .output = generated_text [len (formatted_prompt ) :].strip ()
525+ else :
526+ request .output = generated_text
523527
524528 # Return the clean response
525529 self .send_request ("update_api_request" , (request ,))
526530
527- def _initialize_hosted_job (self , model_name : str ):
528- """Method that can be invoked twice, once to begin setup of the job, and a second
529- time to finalize the job init."""
530- args = self .send_request ("check_module" , None )
def _try_finalize_initializing_models(self):
    """Run one finalization attempt for every model still marked as initializing.

    Each name in ``self.models_initializing`` is handed to
    ``self._finalize_hosted_job``; a truthy result is reported through a
    "debug_print" request.
    """
    # Iterate over a snapshot: _finalize_hosted_job discards names from the
    # live set, and a set must not change size while it is being iterated.
    pending = tuple(self.models_initializing)
    for candidate in pending:
        finalized = self._finalize_hosted_job(candidate)
        if not finalized:
            continue
        message = (
            f"Successfully finalized model: {candidate}",
            "green",
            logging.INFO,
        )
        self.send_request("debug_print", message)
531543
532- # Check if the model loading is complete across workers and ready to go (second call)
533- if model_name in self .models and args :
534- if isinstance (args , tuple ):
535- (
536- file_name ,
537- module_id ,
538- distribution ,
539- module_name ,
540- optimizer_name ,
541- training ,
542- ) = args
543- self .modules [module_id ] = self .models .pop (model_name )
544- self .models [model_name ] = module_id
545- self .modules [module_id ].distribute_model (distribution )
546- self .tokenizers [model_name ] = AutoTokenizer .from_pretrained (model_name )
547-
548- # If not, check if we can spin up the model (first call)
549- else :
550- # Small init sleep time
544+ def _initialize_hosted_job (
545+ self , model_name : str , payment : int = 0 , time_limit : int = None
546+ ):
547+ """Initialize a hosted job by creating the distributed model and submitting inspection request."""
548+ try :
549+ # Check if already initialized
551550 if model_name in self .models :
552- time .sleep (20 )
551+ self .send_request (
552+ "debug_print" ,
553+ (
554+ f"Model { model_name } already initializing, skipping duplicate init" ,
555+ "yellow" ,
556+ logging .DEBUG ,
557+ ),
558+ )
559+ return
553560
561+ # Create distributed model instance
554562 distributed_model = DistributedModel (model_name , node = self .node )
555563 self .models [model_name ] = distributed_model
564+
565+ # Prepare job data for inspection
556566 job_data = {
557567 "author" : None ,
558568 "active" : True ,
559569 "hosted" : True ,
560570 "training" : False ,
561- "payment" : 0 ,
571+ "payment" : payment ,
572+ "time" : time_limit ,
562573 "capacity" : 0 ,
563574 "n_pipelines" : 1 ,
564575 "dp_factor" : 1 ,
@@ -567,8 +578,77 @@ def _initialize_hosted_job(self, model_name: str):
567578 "model_name" : model_name ,
568579 "seed_validators" : [],
569580 }
581+
582+ # Inspect model to determine network requirements
570583 self .inspect_model (model_name , job_data )
571584
585+ self .send_request (
586+ "debug_print" ,
587+ (f"Initialized hosted job for { model_name } " , "green" , logging .INFO ),
588+ )
589+
590+ except Exception as e :
591+ logging .error (f"Error initializing hosted job for { model_name } : { str (e )} " )
592+ self .models_initializing .discard (model_name )
593+ if model_name in self .models :
594+ del self .models [model_name ]
595+
def _finalize_hosted_job(self, model_name: str):
    """Finalize a hosted job by setting up the distributed model with workers.

    Issues a "check_module" request. When it yields a tuple and ``model_name``
    is registered in ``self.models``, the stored model object is re-registered
    in ``self.modules`` under the reported module id, ``self.models[model_name]``
    is replaced by that id, the model is distributed, and a tokenizer is loaded
    into ``self.tokenizers``.

    Args:
        model_name: Key of the pending model in ``self.models``; also passed to
            ``AutoTokenizer.from_pretrained``.

    Returns:
        True when the model was finalized. False when no module info tuple was
        returned, when ``model_name`` is not in ``self.models``, or when any
        step raised. On an exception every registry entry written during this
        call is rolled back. Exceptions are logged, never propagated.
    """
    # Remembers which self.modules key this call created, so the error path
    # can undo it. Stays None until the registration has actually happened.
    registered_module_id = None
    try:
        # Check if we have module info ready
        args = self.send_request("check_module", None)

        if not args or not isinstance(args, tuple):
            # Module not ready yet
            return False

        # Check if model is in initialization state
        if model_name not in self.models:
            return False

        # Unpack module information; only the id and the distribution are
        # used here. A tuple of any other length raises and is handled below.
        (
            _file_name,
            module_id,
            distribution,
            _module_name,
            _optimizer_name,
            _training,
        ) = args

        # Move from initialization to active state
        distributed_model = self.models.pop(model_name)
        self.modules[module_id] = distributed_model
        registered_module_id = module_id
        self.models[model_name] = module_id

        # Distribute the model across workers
        distributed_model.distribute_model(distribution)

        # Load tokenizer
        self.tokenizers[model_name] = AutoTokenizer.from_pretrained(model_name)

        # Remove from initializing set
        self.models_initializing.discard(model_name)

        self.send_request(
            "debug_print",
            (
                f"Finalized hosted job for {model_name} with module_id {module_id}",
                "green",
                logging.INFO,
            ),
        )

        return True

    except Exception as e:
        logging.error(f"Error finalizing hosted job for {model_name}: {e}")
        # Roll back everything this call may have written so that a failed
        # finalization does not leave an orphaned module or tokenizer behind.
        # NOTE(review): assumes self.modules / self.tokenizers are dicts, as
        # self.models is -- confirm against the class initializer.
        self.models_initializing.discard(model_name)
        self.models.pop(model_name, None)
        if registered_module_id is not None:
            self.modules.pop(registered_module_id, None)
        self.tokenizers.pop(model_name, None)
        return False
651+
572652 def _remove_hosted_job (self , model_name : str ):
573653 """Remove a hosted job and clean up all associated resources"""
574654 try :
0 commit comments