@@ -521,27 +521,57 @@ def __repr__(self):
521
521
def dump_qconfig (func ):
522
522
def wrapper (self , * args , ** kwargs ):
523
523
result = func (self , * args , ** kwargs )
524
- create_and_dump_qconfigs (
525
- self .qpc_path ,
526
- self .onnx_path ,
527
- self .get_model_config ,
528
- [cls .__name__ for cls in self ._pytorch_transforms ],
529
- [cls .__name__ for cls in self ._onnx_transforms ],
530
- kwargs .get ("specializations" ),
531
- kwargs .get ("mdp_ts_num_devices" , 1 ),
532
- kwargs .get ("num_speculative_tokens" ),
533
- ** {
534
- k : v
535
- for k , v in kwargs .items ()
536
- if k
537
- not in ["specializations" , "mdp_ts_num_devices" , "num_speculative_tokens" , "custom_io" , "onnx_path" ]
538
- },
539
- )
524
+ try :
525
+ create_and_dump_qconfigs (
526
+ self .qpc_path ,
527
+ self .onnx_path ,
528
+ self .get_model_config ,
529
+ [cls .__name__ for cls in self ._pytorch_transforms ],
530
+ [cls .__name__ for cls in self ._onnx_transforms ],
531
+ kwargs .get ("specializations" ),
532
+ kwargs .get ("mdp_ts_num_devices" , 1 ),
533
+ kwargs .get ("num_speculative_tokens" ),
534
+ ** {
535
+ k : v
536
+ for k , v in kwargs .items ()
537
+ if k
538
+ not in ["specializations" , "mdp_ts_num_devices" , "num_speculative_tokens" , "custom_io" , "onnx_path" ]
539
+ },
540
+ )
541
+ except Exception as e :
542
+ print (f"An unexpected error occurred while dumping the qconfig: { e } " )
540
543
return result
541
544
542
545
return wrapper
543
546
544
547
548
+ def get_qaic_sdk_version (qaic_sdk_xml_path : str ) -> Optional [str ]:
549
+ """
550
+ Extracts the QAIC SDK version from the given SDK XML file.
551
+
552
+ Args:
553
+ qaic_sdk_xml_path (str): Path to the SDK XML file.
554
+ Returns:
555
+ The SDK version as a string if found, otherwise None.
556
+ """
557
+ qaic_sdk_version = None
558
+
559
+ # Check and extract version from the given SDK XML file
560
+ if os .path .exists (qaic_sdk_xml_path ):
561
+ try :
562
+ tree = ET .parse (qaic_sdk_xml_path )
563
+ root = tree .getroot ()
564
+ base_version_element = root .find (".//base_version" )
565
+ if base_version_element is not None :
566
+ qaic_sdk_version = base_version_element .text
567
+ except ET .ParseError as e :
568
+ print (f"Error parsing XML file { qaic_sdk_xml_path } : { e } " )
569
+ except Exception as e :
570
+ print (f"An unexpected error occurred while processing { qaic_sdk_xml_path } : { e } " )
571
+
572
+ return qaic_sdk_version
573
+
574
+
545
575
def create_and_dump_qconfigs (
546
576
qpc_path ,
547
577
onnx_path ,
@@ -558,29 +588,12 @@ def create_and_dump_qconfigs(
558
588
Such as huggingface configs, QEff transforms, QAIC sdk version, QNN sdk, compilation dir, qpc dir and
559
589
many other compilation options.
560
590
"""
561
- qnn_config = compiler_options ["qnn_config" ] if "qnn_config" in compiler_options else None
562
- enable_qnn = True if "qnn_config" in compiler_options else None
563
-
591
+ enable_qnn = compiler_options .get ("enable_qnn" , False )
592
+ qnn_config_path = compiler_options .get ("qnn_config" , None )
564
593
qconfig_file_path = os .path .join (os .path .dirname (qpc_path ), "qconfig.json" )
565
594
onnx_path = str (onnx_path )
566
595
specializations_file_path = str (os .path .join (os .path .dirname (qpc_path ), "specializations.json" ))
567
596
compile_dir = str (os .path .dirname (qpc_path ))
568
- qnn_config_path = (
569
- (qnn_config if qnn_config is not None else "QEfficient/compile/qnn_config.json" ) if enable_qnn else None
570
- )
571
-
572
- # Extract QAIC SDK Apps Version from SDK XML file
573
- tree = ET .parse (Constants .SDK_APPS_XML )
574
- root = tree .getroot ()
575
- qaic_version = root .find (".//base_version" ).text
576
-
577
- # Extract QNN SDK details from YAML file if the environment variable is set
578
- qnn_sdk_details = None
579
- qnn_sdk_path = os .getenv (QnnConstants .QNN_SDK_PATH_ENV_VAR_NAME )
580
- if enable_qnn and qnn_sdk_path :
581
- qnn_sdk_yaml_path = os .path .join (qnn_sdk_path , QnnConstants .QNN_SDK_YAML )
582
- with open (qnn_sdk_yaml_path , "r" ) as file :
583
- qnn_sdk_details = yaml .safe_load (file )
584
597
585
598
# Ensure all objects in the configs dictionary are JSON serializable
586
599
def make_serializable (obj ):
@@ -602,29 +615,38 @@ def make_serializable(obj):
602
615
"onnx_transforms" : make_serializable (onnx_transforms ),
603
616
"onnx_path" : onnx_path ,
604
617
},
618
+ "compiler_config" : {
619
+ "enable_qnn" : enable_qnn ,
620
+ "compile_dir" : compile_dir ,
621
+ "specializations_file_path" : specializations_file_path ,
622
+ "specializations" : make_serializable (specializations ),
623
+ "mdp_ts_num_devices" : mdp_ts_num_devices ,
624
+ "num_speculative_tokens" : num_speculative_tokens ,
625
+ ** compiler_options ,
626
+ },
627
+ "aic_sdk_config" : {
628
+ "qaic_apps_version" : get_qaic_sdk_version (Constants .SDK_APPS_XML ),
629
+ "qaic_platform_version" : get_qaic_sdk_version (Constants .SDK_PLATFORM_XML ),
630
+ },
605
631
},
606
632
}
607
633
608
- aic_compiler_config = {
609
- "apps_sdk_version" : qaic_version ,
610
- "compile_dir" : compile_dir ,
611
- "specializations_file_path" : specializations_file_path ,
612
- "specializations" : make_serializable (specializations ),
613
- "mdp_ts_num_devices" : mdp_ts_num_devices ,
614
- "num_speculative_tokens" : num_speculative_tokens ,
615
- ** compiler_options ,
616
- }
617
- qnn_config = {
618
- "enable_qnn" : enable_qnn ,
619
- "qnn_config_path" : qnn_config_path ,
620
- }
621
- # Put AIC or qnn details.
622
634
if enable_qnn :
635
+ qnn_sdk_path = os .getenv (QnnConstants .QNN_SDK_PATH_ENV_VAR_NAME )
636
+ if not qnn_sdk_path :
637
+ raise EnvironmentError (
638
+ f"QNN_SDK_PATH { qnn_sdk_path } is not set. Please set { QnnConstants .QNN_SDK_PATH_ENV_VAR_NAME } "
639
+ )
640
+ qnn_sdk_yaml_path = os .path .join (qnn_sdk_path , QnnConstants .QNN_SDK_YAML )
641
+ qnn_sdk_details = load_yaml (
642
+ qnn_sdk_yaml_path
643
+ ) # Extract QNN SDK details from YAML file if the environment variable is set
644
+ qnn_config = {
645
+ "qnn_config_path" : qnn_config_path ,
646
+ }
623
647
qconfigs ["qpc_config" ]["qnn_config" ] = qnn_config
624
648
if qnn_sdk_details :
625
649
qconfigs ["qpc_config" ]["qnn_config" ].update (qnn_sdk_details )
626
- else :
627
- qconfigs ["qpc_config" ]["aic_compiler_config" ] = aic_compiler_config
628
650
629
651
create_json (qconfig_file_path , qconfigs )
630
652
0 commit comments