9191
9292
def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    trust_remote_code: bool,
) -> Model:
    """Instantiate the appropriate ``Model`` subclass for *model_id*.

    Dispatch order:
      1. Hard-coded overrides keyed on the model id ("facebook/galactica",
         "bigcode/" prefix).
      2. The ``model_type`` reported by the model's ``AutoConfig``.
      3. Generic AutoModel fallbacks, including ``auto_map`` remote-code
         models when ``trust_remote_code`` is set.

    Args:
        model_id: Hub repo id or local path of the model to load.
        revision: Optional git revision (branch/tag/sha) of the weights.
        sharded: Shard the model across all available devices.
        quantize: Optional quantization scheme, forwarded to the model class.
        trust_remote_code: Allow executing modeling code shipped inside the
            model repository (required for ``auto_map`` models).

    Raises:
        NotImplementedError: Sharding was requested for a model that needs
            flash attention, but flash attention is unavailable.
        ValueError: Sharding was requested for an AutoModel fallback, or the
            model type is not supported at all.
    """
    if "facebook/galactica" in model_id:
        galactica_cls = GalacticaSharded if sharded else Galactica
        return galactica_cls(
            model_id,
            revision,
            quantize=quantize,
            trust_remote_code=trust_remote_code,
        )

    if model_id.startswith("bigcode/"):
        if sharded:
            # Sharded Santacoder exists only as a flash-attention model.
            if not FLASH_ATTENTION:
                raise NotImplementedError(
                    FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
                )
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
        return santacoder_cls(
            model_id,
            revision,
            quantize=quantize,
            trust_remote_code=trust_remote_code,
        )

    config = AutoConfig.from_pretrained(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config.model_type

    if model_type == "gpt_bigcode":
        # Same architecture as the "bigcode/" prefix branch above.
        if sharded:
            if not FLASH_ATTENTION:
                raise NotImplementedError(
                    FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder")
                )
            return FlashSantacoderSharded(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
        return santacoder_cls(
            model_id,
            revision,
            quantize=quantize,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "bloom":
        bloom_cls = BLOOMSharded if sharded else BLOOM
        return bloom_cls(
            model_id,
            revision,
            quantize=quantize,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "gpt_neox":
        if sharded:
            neox_cls = FlashNeoXSharded if FLASH_ATTENTION else GPTNeoxSharded
        else:
            neox_cls = FlashNeoX if FLASH_ATTENTION else CausalLM
        return neox_cls(
            model_id,
            revision,
            quantize=quantize,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "llama":
        if sharded:
            # Sharded Llama exists only as a flash-attention model.
            if not FLASH_ATTENTION:
                raise NotImplementedError(
                    FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama")
                )
            return FlashLlamaSharded(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
        return llama_cls(
            model_id,
            revision,
            quantize=quantize,
            trust_remote_code=trust_remote_code,
        )

    # NOTE: was `config.model_type` in the original — same value, made
    # consistent with the other branches.
    if model_type == "opt":
        opt_cls = OPTSharded if sharded else OPT
        return opt_cls(
            model_id,
            revision,
            quantize=quantize,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "t5":
        t5_cls = T5Sharded if sharded else Seq2SeqLM
        return t5_cls(
            model_id,
            revision,
            quantize=quantize,
            trust_remote_code=trust_remote_code,
        )

    if sharded:
        raise ValueError("sharded is not supported for AutoModel")

    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )
    if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return Seq2SeqLM(
            model_id, revision, quantize=quantize, trust_remote_code=trust_remote_code
        )

    # Last resort: configs that declare custom modeling classes via
    # `auto_map`; only honored when the caller opted into remote code.
    auto_map = getattr(config, "auto_map", None)
    if trust_remote_code and auto_map is not None:
        # BUG FIX: the original tested `in auto_map.keys` (missing call
        # parentheses) for the Seq2SeqLM case, which raises TypeError at
        # runtime. Plain dict membership is equivalent and idiomatic.
        if "AutoModelForCausalLM" in auto_map:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )
        if "AutoModelForSeq2SeqLM" in auto_map:
            return Seq2SeqLM(
                model_id,
                revision,
                quantize=quantize,
                trust_remote_code=trust_remote_code,
            )

    raise ValueError(f"Unsupported model type {model_type}")
0 commit comments