Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,31 +46,30 @@ CambApiClient client = CambApiClient.builder()
.build();
```

### Client with Specific MARS Pro Provider (e.g. Baseten)
### Custom Hosting Provider (e.g. Baseten Mars8-Flash)

You can use a custom hosting provider like Baseten for specialized deployments.
You can route TTS through a custom hosting provider like Baseten while keeping the same SDK interface.
`reference_audio` can be a public URL or base64-encoded audio file — Baseten caches it for faster inference.

```java
import resources.texttospeech.requests.CreateStreamTtsRequestPayload;
import resources.texttospeech.types.CreateStreamTtsRequestPayloadLanguage;
import resources.texttospeech.types.CreateStreamTtsRequestPayloadSpeechModel;
import types.OutputFormat;
import types.StreamTtsOutputConfiguration;
import java.io.InputStream;

// Initialize custom hosting provider
// Initialize the Baseten Mars8-Flash custom hosting provider.
// BASETEN_REFERENCE_AUDIO can be a public URL or base64-encoded audio file.
ITtsProvider ttsProvider = new BasetenProvider(
"YOUR_BASETEN_API_KEY",
"YOUR_BASETEN_URL"
System.getenv("BASETEN_API_KEY"),
System.getenv("BASETEN_URL"),
System.getenv("BASETEN_REFERENCE_AUDIO"), // reference voice
"en-us" // reference audio language
);

// Use the provider to generate speech
InputStream audioStream = ttsProvider.tts(CreateStreamTtsRequestPayload.builder()
.text("Hello from Java via Baseten!")
.text("Hello from Java via Baseten Mars8-Flash!")
.language(CreateStreamTtsRequestPayloadLanguage.EN_US)
.voiceId(1) // Required but ignored by custom hosting provider
.speechModel(CreateStreamTtsRequestPayloadSpeechModel.MARSPRO)
.outputConfiguration(StreamTtsOutputConfiguration.builder().format(OutputFormat.WAV).build())
.voiceId(1) // Required by the SDK's staged builder; ignored by the Baseten provider
.build(), null);
```

Expand Down
76 changes: 52 additions & 24 deletions examples/BasetenExample.java
Original file line number Diff line number Diff line change
@@ -1,46 +1,72 @@
import resources.texttospeech.requests.CreateStreamTtsRequestPayload;
import resources.texttospeech.types.CreateStreamTtsRequestPayloadLanguage;
import resources.texttospeech.types.CreateStreamTtsRequestPayloadSpeechModel;
import types.OutputFormat;
import types.StreamTtsOutputConfiguration;
import java.io.InputStream;
import java.io.FileOutputStream;
import java.io.File;

/**
* Example: Baseten Mars8-Flash custom hosting provider via the Camb.ai Java SDK.
*
* Required environment variables:
* CAMB_API_KEY - Your Camb.ai API key
* BASETEN_API_KEY - Your Baseten API key
* BASETEN_URL - Your Baseten model prediction endpoint URL
* BASETEN_REFERENCE_AUDIO - Reference voice audio: public URL or base64-encoded file
* e.g. https://github.com/Camb-ai/mars6-turbo/raw/refs/heads/master/assets/example.wav
*
* Optional:
* BASETEN_REFERENCE_LANGUAGE - ISO locale of the reference audio (default: en-us)
*
* API reference: https://www.baseten.co/library/mars8-flash/
*/
public class BasetenExample {
public static void main(String[] args) {
// Environment variables for Baseten and Camb AI
String cambApiKey = System.getenv("CAMB_API_KEY");
String basetenApiKey = System.getenv("BASETEN_API_KEY");
String basetenUrl = System.getenv("BASETEN_URL");

if (cambApiKey == null || basetenApiKey == null || basetenUrl == null) {
System.err.println("Error: Missing required environment variables:");
if (cambApiKey == null) System.err.println(" - CAMB_API_KEY");
if (basetenApiKey == null) System.err.println(" - BASETEN_API_KEY");
if (basetenUrl == null) System.err.println(" - BASETEN_URL (e.g. your Baseten model endpoint URL)");
String cambApiKey = System.getenv("CAMB_API_KEY");
String basetenApiKey = System.getenv("BASETEN_API_KEY");
String basetenUrl = System.getenv("BASETEN_URL");
String referenceAudio = System.getenv("BASETEN_REFERENCE_AUDIO");
String referenceLanguage = System.getenv("BASETEN_REFERENCE_LANGUAGE");

// Loud fail for missing required env vars
boolean hasError = false;
if (cambApiKey == null) { System.err.println(" - CAMB_API_KEY"); hasError = true; }
if (basetenApiKey == null) { System.err.println(" - BASETEN_API_KEY"); hasError = true; }
if (basetenUrl == null) { System.err.println(" - BASETEN_URL (your Baseten model prediction endpoint)"); hasError = true; }
if (referenceAudio == null) { System.err.println(" - BASETEN_REFERENCE_AUDIO (public URL or base64-encoded audio file)"); hasError = true; }

if (hasError) {
System.err.println("Error: Missing required environment variables (see above).");
System.exit(1);
}

// Initialize the custom Baseten provider
ITtsProvider basetenProvider = new BasetenProvider(basetenApiKey, basetenUrl);
// Default reference language to en-us if not set
if (referenceLanguage == null || referenceLanguage.isEmpty()) {
referenceLanguage = "en-us";
}

System.out.println("Generating speech via Baseten provider...");
// Initialise the Baseten custom hosting provider
ITtsProvider basetenProvider = new BasetenProvider(
basetenApiKey,
basetenUrl,
referenceAudio,
referenceLanguage
);

System.out.println("Generating speech via Baseten Mars8-Flash custom hosting provider...");

try {
// Build the payload
// Build the TTS request.
// voiceId is required by the Camb SDK payload builder but is ignored
// when routing through a custom hosting provider.
CreateStreamTtsRequestPayload request = CreateStreamTtsRequestPayload.builder()
.text("Hello. This is speech generated using a custom Baseten provider.")
.text("Hello. This is speech generated via a Baseten Mars8-Flash custom hosting provider.")
.language(CreateStreamTtsRequestPayloadLanguage.EN_US)
.voiceId(1) // Ignored by custom hosting provider but required by payload
.speechModel(CreateStreamTtsRequestPayloadSpeechModel.MARSPRO)
.outputConfiguration(StreamTtsOutputConfiguration.builder().format(OutputFormat.WAV).build())
.voiceId(1) // Required by the SDK's staged builder; ignored by the Baseten provider
.build();

// Use the provider
InputStream audioStream = basetenProvider.tts(request, null);

File outputFile = new File("baseten_output.wav");
File outputFile = new File("baseten_output.flac");
try (FileOutputStream outputStream = new FileOutputStream(outputFile)) {
byte[] buffer = new byte[4096];
int bytesRead;
Expand All @@ -49,9 +75,11 @@ public static void main(String[] args) {
}
}

System.out.println("✓ Success! Baseten generated audio saved to " + outputFile.getAbsolutePath());
System.out.println("✓ Success! Audio saved to " + outputFile.getAbsolutePath());
} catch (Exception e) {
System.err.println("Error: " + e.getMessage());
e.printStackTrace();
System.exit(1);
}
}
}
66 changes: 34 additions & 32 deletions examples/BasetenProvider.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@


import resources.texttospeech.requests.CreateStreamTtsRequestPayload;
import core.RequestOptions;
import java.io.InputStream;
Expand All @@ -13,57 +12,59 @@
import java.util.HashMap;
import java.util.Map;

/**
* Baseten TTS Provider using the Mars8-Flash model.
*
* API reference: https://www.baseten.co/library/mars8-flash/
*
* Constructor parameters:
* apiKey - Baseten API key
* url - Baseten model prediction endpoint
* referenceAudio - Reference voice: public URL or base64-encoded audio file
* referenceLanguage - ISO locale of the reference audio (e.g. "en-us")
*/
public class BasetenProvider implements ITtsProvider {
private final String apiKey;
private final String url;
private final String referenceAudio;
private final String referenceLanguage;
private final OkHttpClient httpClient;
private final ObjectMapper objectMapper;

public BasetenProvider(String apiKey, String url) {
public BasetenProvider(String apiKey, String url, String referenceAudio, String referenceLanguage) {
this.apiKey = apiKey;
this.url = url != null ? url : "https://model-5qeryx53.api.baseten.co/environments/production/predict";
this.url = url;
this.referenceAudio = referenceAudio;
this.referenceLanguage = referenceLanguage;
this.httpClient = new OkHttpClient();
this.objectMapper = new ObjectMapper();
}

@Override
public InputStream tts(CreateStreamTtsRequestPayload request, RequestOptions requestOptions) {
// NOTE: In a real scenario, you'd extend CreateStreamTtsRequestPayload or use a custom DTO
// to pass reference_audio/language. For this example, we assume they are passed via a side channel
// or we'd cast a custom subclass.
// Java's strong typing makes this trickier than JS/Python without altering the generated class.
// Assuming we have reference audio/language from somewhere:

String referenceAudio = "DUMMY_BASE64..."; // Placeholder
String referenceLanguage = "en-us";
// Normalise language: SDK enum is a string type, ensure lowercase ISO format.
String language = request.getLanguage().toString().toLowerCase().replace("_", "-");

// Build the Mars8-Flash payload.
// Docs: https://www.baseten.co/library/mars8-flash/
Map<String, Object> payload = new HashMap<>();
payload.put("text", request.getText());
payload.put("stream", true);

// Use format from request if provided, otherwise default to wav
String format = request.getOutputConfiguration()
.flatMap(config -> config.getFormat())
.map(f -> f.toString())
.orElse("wav");
payload.put("output_format", format);

payload.put("language", request.getLanguage().toString());

// Use speech model from request if provided
request.getSpeechModel().ifPresent(model -> {
payload.put("speech_model", model.toString());
});

payload.put("language", language);
payload.put("output_duration", null); // null = model infers optimal duration
payload.put("reference_audio", referenceAudio);
payload.put("audio_ref", referenceAudio);
payload.put("reference_language", referenceLanguage);
payload.put("apply_ner_nlp", false);
payload.put("output_format", "flac"); // flac is the default; wav also supported
payload.put("apply_ner_nlp", false); // disable NER (faster; pass pronunciation_dictionary instead)

// Optional: override output format from request output configuration
request.getOutputConfiguration().ifPresent(config ->
config.getFormat().ifPresent(f -> payload.put("output_format", f.toString().toLowerCase()))
);

try {
String json = objectMapper.writeValueAsString(payload);
RequestBody body = RequestBody.create(json, MediaType.parse("application/json"));

Request req = new Request.Builder()
.url(this.url)
.addHeader("Authorization", "Api-Key " + this.apiKey)
Expand All @@ -72,12 +73,13 @@ public InputStream tts(CreateStreamTtsRequestPayload request, RequestOptions req

Response response = httpClient.newCall(req).execute();
if (!response.isSuccessful()) {
throw new RuntimeException("Baseten Error: " + response.code());
String errorBody = response.body() != null ? response.body().string() : "<no body>";
throw new RuntimeException("Baseten API error " + response.code() + ": " + errorBody);
}
return response.body().byteStream();

} catch (IOException e) {
throw new RuntimeException("Network error", e);
throw new RuntimeException("Network error calling Baseten: " + e.getMessage(), e);
}
}
}
Loading