Skip to content

Commit dedca60

Browse files
authored
feat: validate transforms on compile (#151)
- Adds compile-time checks for transforms during encoderfile build stage
- DRY mean pooling code a little

Scope creep:
- remove old config docs generation
- new makefile commands
1 parent ef6c587 commit dedca60

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

46 files changed

+1749
-517
lines changed

.github/workflows/ci.yml

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -119,7 +119,7 @@ jobs:
119119
run: cargo binstall cargo-llvm-cov --force --no-confirm
120120

121121
- name: Generate coverage
122-
run: cargo llvm-cov --workspace --all-features --lcov --output-path lcov.info
122+
run: make coverage
123123

124124
- name: Upload to codecov
125125
uses: codecov/codecov-action@v2
@@ -163,7 +163,10 @@ jobs:
163163
with:
164164
toolchain: nightly
165165

166+
- name: Install test dependencies
167+
run: make setup
168+
166169
- run: rustup component add --toolchain nightly-x86_64-unknown-linux-gnu clippy
167170

168171
- name: Run Clippy
169-
run: cargo clippy --all-features --all-targets -- -D warnings
172+
run: make lint

.vscode/settings.json

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
{
2-
"rust-analyzer.cargo.features": ["bench"],
2+
"rust-analyzer.cargo.features": ["dev-utils"],
33
"rust-analyzer.procMacro.enable": true,
44
"rust-analyzer.cargo.buildScripts.enable": true,
55
"rust-analyzer.linkedProjects": [

Cargo.lock

Lines changed: 3 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Makefile

Lines changed: 16 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,22 @@ format:
1212
@echo "Formatting rust..."
1313
@cargo fmt
1414

15+
.PHONY: lint
16+
lint:
17+
cargo clippy \
18+
--all-features \
19+
--all-targets \
20+
-- \
21+
-D warnings
22+
23+
.PHONY: coverage
24+
coverage:
25+
cargo llvm-cov \
26+
--workspace \
27+
--all-features \
28+
--lcov \
29+
--output-path lcov.info
30+
1531
.PHONY: licenses
1632
licenses:
1733
@echo "Generating licenses..."
@@ -59,5 +75,3 @@ generate-docs:
5975
# generate JSON schema for encoderfile config
6076
@cargo run \
6177
--bin generate-encoderfile-config-schema
62-
# generate CLI docs for encoderfile build
63-
@cargo run --bin generate-encoderfile-cli-docs --features="_internal"

docs/reference/cli.md

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -71,7 +71,10 @@ encoderfile:
7171
transform:
7272
path: ./transforms/normalize.lua
7373
# OR inline transform:
74-
# transform: "return normalize(output)"
74+
# transform: "function Postprocess(logits) return logits:lp_normalize(2.0, 2.0) end"
75+
76+
# Whether to validate transform with a dry-run (optional, defaults to true)
77+
validate_transform: true
7578

7679
# Whether to build the binary (optional, defaults to true)
7780
build: true

encoderfile-core/benches/benchmark_transforms.rs

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
use encoderfile_core::transforms::Postprocessor;
12
use ndarray::{Array2, Array3};
23
use rand::Rng;
34

@@ -21,9 +22,9 @@ fn get_random_3d(x: usize, y: usize, z: usize) -> Array3<f32> {
2122

2223
#[divan::bench(args = [(16, 16, 16), (32, 128, 384), (32, 256, 768)])]
2324
fn bench_embedding_l2_normalization(bencher: divan::Bencher, (x, y, z): (usize, usize, usize)) {
24-
let engine = encoderfile_core::transforms::Transform::new(include_str!(
25+
let engine = encoderfile_core::transforms::EmbeddingTransform::new(Some(include_str!(
2526
"../../transforms/embedding/l2_normalize_embeddings.lua"
26-
))
27+
)))
2728
.unwrap();
2829

2930
let test_tensor = get_random_3d(x, y, z);
@@ -35,8 +36,8 @@ fn bench_embedding_l2_normalization(bencher: divan::Bencher, (x, y, z): (usize,
3536

3637
#[divan::bench(args = [(16, 2), (32, 8), (128, 32)])]
3738
fn bench_seq_cls_softmax(bencher: divan::Bencher, (x, y): (usize, usize)) {
38-
let engine = encoderfile_core::transforms::Transform::new(include_str!(
39-
"../../transforms/sequence_classification/softmax_logits.lua"
39+
let engine = encoderfile_core::transforms::SequenceClassificationTransform::new(Some(
40+
include_str!("../../transforms/sequence_classification/softmax_logits.lua"),
4041
))
4142
.unwrap();
4243

@@ -49,8 +50,8 @@ fn bench_seq_cls_softmax(bencher: divan::Bencher, (x, y): (usize, usize)) {
4950

5051
#[divan::bench(args = [(16, 16, 2), (32, 128, 8), (128, 256, 32)])]
5152
fn bench_tok_cls_softmax(bencher: divan::Bencher, (x, y, z): (usize, usize, usize)) {
52-
let engine = encoderfile_core::transforms::Transform::new(include_str!(
53-
"../../transforms/token_classification/softmax_logits.lua"
53+
let engine = encoderfile_core::transforms::TokenClassificationTransform::new(Some(
54+
include_str!("../../transforms/token_classification/softmax_logits.lua"),
5455
))
5556
.unwrap();
5657

encoderfile-core/src/cli.rs

Lines changed: 33 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,18 +3,50 @@ use crate::{
33
EmbeddingRequest, ModelType, SentenceEmbeddingRequest, SequenceClassificationRequest,
44
TokenClassificationRequest,
55
},
6-
runtime::AppState,
6+
runtime::{AppState, get_model, get_model_config, get_model_type, get_tokenizer},
77
server::{run_grpc, run_http, run_mcp},
88
services::{embedding, sentence_embedding, sequence_classification, token_classification},
99
};
1010
use anyhow::Result;
11+
use clap::Parser;
1112
use clap_derive::{Parser, Subcommand, ValueEnum};
1213
use opentelemetry::trace::TracerProvider as _;
1314
use opentelemetry_otlp::{Protocol, WithExportConfig};
1415
use opentelemetry_sdk::trace::SdkTracerProvider;
1516
use std::{fmt::Display, io::Write};
1617
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
1718

19+
pub async fn cli_entrypoint(
20+
model_bytes: &[u8],
21+
config_str: &str,
22+
tokenizer_json: &str,
23+
model_type: &str,
24+
model_id: &str,
25+
transform_str: Option<&str>,
26+
) -> Result<()> {
27+
let cli = Cli::parse();
28+
29+
let session = get_model(model_bytes);
30+
let config = get_model_config(config_str);
31+
let tokenizer = get_tokenizer(tokenizer_json, &config);
32+
let model_type = get_model_type(model_type);
33+
let transform_str = transform_str.map(|t| t.to_string());
34+
let model_id = model_id.to_string();
35+
36+
let state = AppState {
37+
session,
38+
config,
39+
tokenizer,
40+
model_type,
41+
model_id,
42+
transform_str,
43+
};
44+
45+
cli.command.execute(state).await?;
46+
47+
Ok(())
48+
}
49+
1850
macro_rules! generate_cli_route {
1951
($req:ident, $fn:path, $format:ident, $out_dir:expr, $state:expr) => {{
2052
let result = $fn($req, &$state)?;

encoderfile-core/src/common/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
mod embedding;
2+
mod model_config;
23
mod model_metadata;
34
mod model_type;
45
mod sentence_embedding;
@@ -7,6 +8,7 @@ mod token;
78
mod token_classification;
89

910
pub use embedding::*;
11+
pub use model_config::*;
1012
pub use model_metadata::*;
1113
pub use model_type::*;
1214
pub use sentence_embedding::*;
Lines changed: 87 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,87 @@
1+
use serde::{Deserialize, Serialize};
2+
use std::collections::HashMap;
3+
4+
#[derive(Debug, Serialize, Deserialize)]
5+
pub struct ModelConfig {
6+
pub model_type: String,
7+
pub pad_token_id: u32,
8+
pub num_labels: Option<usize>,
9+
pub id2label: Option<HashMap<u32, String>>,
10+
pub label2id: Option<HashMap<String, u32>>,
11+
}
12+
13+
impl ModelConfig {
14+
pub fn id2label(&self, id: u32) -> Option<&str> {
15+
self.id2label.as_ref()?.get(&id).map(|s| s.as_str())
16+
}
17+
18+
pub fn label2id(&self, label: &str) -> Option<u32> {
19+
self.label2id.as_ref()?.get(label).copied()
20+
}
21+
22+
pub fn num_labels(&self) -> Option<usize> {
23+
if self.num_labels.is_some() {
24+
return self.num_labels;
25+
}
26+
27+
if let Some(id2label) = &self.id2label {
28+
return Some(id2label.len());
29+
}
30+
31+
if let Some(label2id) = &self.label2id {
32+
return Some(label2id.len());
33+
}
34+
35+
None
36+
}
37+
}
38+
39+
#[cfg(test)]
40+
mod tests {
41+
use super::*;
42+
43+
#[test]
44+
fn test_num_labels() {
45+
let test_labels: Vec<(String, u32)> = vec![("a", 1), ("b", 2), ("c", 3)]
46+
.into_iter()
47+
.map(|(i, j)| (i.to_string(), j))
48+
.collect();
49+
50+
let label2id: HashMap<String, u32> = test_labels.clone().into_iter().collect();
51+
let id2label: HashMap<u32, String> = test_labels
52+
.clone()
53+
.into_iter()
54+
.map(|(i, j)| (j, i))
55+
.collect();
56+
57+
let config = ModelConfig {
58+
model_type: "MyModel".to_string(),
59+
pad_token_id: 0,
60+
num_labels: Some(3),
61+
id2label: Some(id2label.clone()),
62+
label2id: Some(label2id.clone()),
63+
};
64+
65+
assert_eq!(config.num_labels(), Some(3));
66+
67+
let config = ModelConfig {
68+
model_type: "MyModel".to_string(),
69+
pad_token_id: 0,
70+
num_labels: None,
71+
id2label: Some(id2label.clone()),
72+
label2id: Some(label2id.clone()),
73+
};
74+
75+
assert_eq!(config.num_labels(), Some(3));
76+
77+
let config = ModelConfig {
78+
model_type: "MyModel".to_string(),
79+
pad_token_id: 0,
80+
num_labels: None,
81+
id2label: None,
82+
label2id: Some(label2id.clone()),
83+
};
84+
85+
assert_eq!(config.num_labels(), Some(3));
86+
}
87+
}
Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,9 @@
11
#[derive(Debug, Clone, serde::Serialize, utoipa::ToSchema)]
22
#[serde(rename_all = "snake_case")]
3+
#[repr(u8)]
34
pub enum ModelType {
4-
Embedding,
5-
SequenceClassification,
6-
TokenClassification,
7-
SentenceEmbedding,
5+
Embedding = 1,
6+
SequenceClassification = 2,
7+
TokenClassification = 3,
8+
SentenceEmbedding = 4,
89
}

0 commit comments

Comments (0)