Skip to content

Commit 982ce32

Browse files
feat(router): explicit warning if revision is not set (#608)
1 parent b732720 commit 982ce32

File tree

2 files changed

+31
-17
lines changed

2 files changed

+31
-17
lines changed

launcher/src/main.rs

Lines changed: 6 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -760,16 +760,6 @@ fn spawn_shards(
760760
status_sender: mpsc::Sender<ShardStatus>,
761761
running: Arc<AtomicBool>,
762762
) -> Result<(), LauncherError> {
763-
if args.trust_remote_code {
764-
tracing::warn!(
765-
"`trust_remote_code` is set. Trusting that model `{}` do not contain malicious code.",
766-
args.model_id
767-
);
768-
if args.revision.is_none() {
769-
tracing::warn!("Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure no malicious code has been contributed in a newer revision.");
770-
}
771-
}
772-
773763
// Start shard processes
774764
for rank in 0..num_shard {
775765
let model_id = args.model_id.clone();
@@ -1025,6 +1015,12 @@ fn main() -> Result<(), LauncherError> {
10251015
"`validation_workers` must be > 0".to_string(),
10261016
));
10271017
}
1018+
if args.trust_remote_code {
1019+
tracing::warn!(
1020+
"`trust_remote_code` is set. Trusting that model `{}` do not contain malicious code.",
1021+
args.model_id
1022+
);
1023+
}
10281024

10291025
let num_shard = find_num_shards(args.sharded, args.num_shard)?;
10301026
if num_shard > 1 {

router/src/main.rs

Lines changed: 25 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -49,8 +49,8 @@ struct Args {
4949
master_shard_uds_path: String,
5050
#[clap(default_value = "bigscience/bloom", long, env)]
5151
tokenizer_name: String,
52-
#[clap(default_value = "main", long, env)]
53-
revision: String,
52+
#[clap(long, env)]
53+
revision: Option<String>,
5454
#[clap(default_value = "2", long, env)]
5555
validation_workers: usize,
5656
#[clap(long, env)]
@@ -147,7 +147,7 @@ fn main() -> Result<(), RouterError> {
147147
// Download and instantiate tokenizer
148148
// We need to download it outside of the Tokio runtime
149149
let params = FromPretrainedParameters {
150-
revision: revision.clone(),
150+
revision: revision.clone().unwrap_or("main".to_string()),
151151
auth_token: authorization_token.clone(),
152152
..Default::default()
153153
};
@@ -175,7 +175,7 @@ fn main() -> Result<(), RouterError> {
175175
sha: None,
176176
pipeline_tag: None,
177177
},
178-
false => get_model_info(&tokenizer_name, &revision, authorization_token)
178+
false => get_model_info(&tokenizer_name, revision, authorization_token)
179179
.await
180180
.unwrap_or_else(|| {
181181
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
@@ -316,9 +316,18 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
316316
/// get model info from the Huggingface Hub
317317
pub async fn get_model_info(
318318
model_id: &str,
319-
revision: &str,
319+
revision: Option<String>,
320320
token: Option<String>,
321321
) -> Option<HubModelInfo> {
322+
let revision = match revision {
323+
None => {
324+
tracing::warn!("`--revision` is not set");
325+
tracing::warn!("We strongly advise to set it to a known supported commit.");
326+
"main".to_string()
327+
}
328+
Some(revision) => revision,
329+
};
330+
322331
let client = reqwest::Client::new();
323332
// Poor man's urlencode
324333
let revision = revision.replace('/', "%2F");
@@ -331,9 +340,18 @@ pub async fn get_model_info(
331340
let response = builder.send().await.ok()?;
332341

333342
if response.status().is_success() {
334-
return serde_json::from_str(&response.text().await.ok()?).ok();
343+
let hub_model_info: HubModelInfo =
344+
serde_json::from_str(&response.text().await.ok()?).ok()?;
345+
if let Some(sha) = &hub_model_info.sha {
346+
tracing::info!(
347+
"Serving revision {sha} of model {}",
348+
hub_model_info.model_id
349+
);
350+
}
351+
Some(hub_model_info)
352+
} else {
353+
None
335354
}
336-
None
337355
}
338356

339357
#[derive(Debug, Error)]

0 commit comments

Comments
 (0)