95 changes: 94 additions & 1 deletion extensions/src/doodlebot/Doodlebot.ts
@@ -8,6 +8,7 @@ import { base64ToInt32Array, deferred, makeWebsocket, Max32Int } from "./utils";
import { LineDetector } from "./LineDetection";
import { calculateArcTime } from "./TimeHelper";
import type { BLEDeviceWithUartService } from "./ble";
import { saveAudioBufferToWav } from "./SoundUtils";

export type MotorStepRequest = {
/** Number of steps for the stepper (+ == forward, - == reverse) */
@@ -287,7 +288,7 @@ export default class Doodlebot {
  private updateSensor<T extends SensorKey>(type: T, value: SensorData[T]) {
    this.onSensor.emit(type, value);
    if (type == "distance" && value as number >= 63) {
-     this.sensorData[type] = 100;
+     (this.sensorData[type] as number) = 100;
    } else {
      this.sensorData[type] = value;
    }
@@ -1119,6 +1120,62 @@ export default class Doodlebot {
};
}

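  /**
   * POST a recorded audio clip to the Doodlebot chat server, then forward the
   * synthesized speech from the response body to the robot via sendAudioData.
   * The text of the chat reply arrives in the "text-response" header.
   */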
  async sendAudioFileToChatEndpoint(file: File, endpoint: string, voice_id: string, pitch_value: number) {
    const url = `https://doodlebot.media.mit.edu/${endpoint}?voice=${voice_id}&pitch=${pitch_value}`;
    const formData = new FormData();
    formData.append("audio_file", file);

    try {
      const response = await fetch(url, {
        method: "POST",
        body: formData,
      });

      if (!response.ok) {
        const errorText = await response.text();
        console.error("Error response:", errorText);
        throw new Error(`HTTP error! status: ${response.status}`);
      }

      const textResponse = response.headers.get("text-response");
      console.log("Text Response:", textResponse);

      // Forward the synthesized audio to the robot as raw bytes
      const blob = await response.blob();
      const arrayBuffer = await blob.arrayBuffer();
      this.sendAudioData(new Uint8Array(arrayBuffer));
    } catch (error) {
      console.error("Error sending audio file:", error);
    }
  }


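  /** Encode a raw audio buffer as a WAV file and send it to the chat endpoint. */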
  async processAndSendAudio(buffer: AudioBuffer, endpoint: string, voice_id: string, pitch_value: number) {
    try {
      const wavBlob = await saveAudioBufferToWav(buffer);
      const wavFile = new File([wavBlob], "output.wav", { type: "audio/wav" });
      await this.sendAudioFileToChatEndpoint(wavFile, endpoint, voice_id, pitch_value);
    } catch (error) {
      console.error("Error processing and sending audio:", error);
    }
  }

  recordAudio(numSeconds = 1) {
    if (!this.setupAudioStream()) return;

@@ -1237,4 +1294,40 @@ export default class Doodlebot {
    this.pending[type] = undefined;
    return value;
  }

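  /**
   * Upload a sound or image to the robot's web server. The blobURL argument
   * packs the filename and the object URL into one string, joined by the
   * "---name---" separator.
   */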
  async uploadFile(type: string, blobURL: string) {
    const tld = await this.topLevelDomain.promise;
    const uploadEndpoint = type === "sound"
      ? `https://${tld}/api/v1/upload/sounds_upload`
      : `https://${tld}/api/v1/upload/img_upload`;

    try {
      const [filename, objectURL] = blobURL.split("---name---");
      const blobResponse = await fetch(objectURL);
      if (!blobResponse.ok) {
        throw new Error(`Failed to fetch Blob from URL: ${blobURL}`);
      }
      const blob = await blobResponse.blob();
      // Convert Blob to File so the server receives the original filename
      const file = new File([blob], filename, { type: blob.type });
      const formData = new FormData();
      formData.append("file", file);

      const uploadResponse = await fetch(uploadEndpoint, {
        method: "POST",
        body: formData,
      });
      if (!uploadResponse.ok) {
        throw new Error(`Failed to upload file: ${uploadResponse.statusText}`);
      }
    } catch (error) {
      console.error("Error uploading file:", error);
    }
  }
}
212 changes: 212 additions & 0 deletions extensions/src/doodlebot/ModelUtils.ts
@@ -0,0 +1,212 @@
import tmPose from '@teachablemachine/pose';
import tmImage from '@teachablemachine/image';
import * as speechCommands from '@tensorflow-models/speech-commands';



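/**
 * Thin wrapper around Teachable Machine models (image, pose, and audio) that
 * loads a model from a URL or model ID, runs predictions on camera frames or
 * microphone input, and tracks the top class and per-class confidences.
 */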
export default class TeachableMachine {
  latestAudioResults: any;
  predictionState: Record<string, any>;
  modelConfidences: Record<string, number> = {};
  maxConfidence: number = null;
  teachableImageModel: string;
  isPredicting: number = 0;
  ModelType = {
    POSE: 'pose',
    IMAGE: 'image',
    AUDIO: 'audio',
  };

  constructor() {
    this.predictionState = {};
  }

  useModel = async (url: string): Promise<{ type: "success" | "error" | "warning", msg: string }> => {
    try {
      const modelUrl = this.modelArgumentToURL(url);
      console.log('Loading model from URL:', modelUrl);

      // Initialize prediction state for this model
      this.predictionState[modelUrl] = {};

      // Load and initialize the model
      const { model, type } = await this.initModel(modelUrl);

      this.predictionState[modelUrl].modelType = type;
      this.predictionState[modelUrl].model = model;

      // Update the current model reference
      this.teachableImageModel = modelUrl;

      return {
        type: "success",
        msg: "Model loaded successfully"
      };
    } catch (e) {
      console.error('Error loading model:', e);
      this.teachableImageModel = null;
      return {
        type: "error",
        msg: "Failed to load model"
      };
    }
  }


  modelArgumentToURL = (modelArg: string) => {
    // Convert a user-provided model URL or bare model ID to the hosted model path,
    // e.g. "https://teachablemachine.withgoogle.com/models/<MODEL_ID>/" becomes
    // "https://storage.googleapis.com/tm-model/<MODEL_ID>/"
    const endpointProvidedFromInterface = "https://teachablemachine.withgoogle.com/models/";
    const redirectEndpoint = "https://storage.googleapis.com/tm-model/";

    return modelArg.startsWith(endpointProvidedFromInterface)
      ? modelArg.replace(endpointProvidedFromInterface, redirectEndpoint)
      : redirectEndpoint + modelArg + "/";
  }

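  /**
   * Load a model and infer its type from its metadata: speech-commands models
   * carry a "tfjsSpeechCommandsVersion" field, pose models report the
   * "@teachablemachine/pose" package name, and anything else is treated as an
   * image model.
   */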
  initModel = async (modelUrl: string) => {
    const avoidCache = `?x=${Date.now()}`;
    const modelURL = modelUrl + "model.json" + avoidCache;
    const metadataURL = modelUrl + "metadata.json" + avoidCache;

    // First try loading as an image model
    try {
      const customMobileNet = await tmImage.load(modelURL, metadataURL);

      // Check if it's actually an audio model
      if ((customMobileNet as any)._metadata.hasOwnProperty('tfjsSpeechCommandsVersion')) {
        const recognizer = await speechCommands.create("BROWSER_FFT", undefined, modelURL, metadataURL);
        await recognizer.ensureModelLoaded();

        // Setup audio listening
        await recognizer.listen(async result => {
          this.latestAudioResults = result;
        }, {
          includeSpectrogram: true,
          probabilityThreshold: 0.75,
          invokeCallbackOnNoiseAndUnknown: true,
          overlapFactor: 0.50
        });

        return { model: recognizer, type: this.ModelType.AUDIO };
      }
      // Check if it's a pose model
      else if ((customMobileNet as any)._metadata.packageName === "@teachablemachine/pose") {
        const customPoseNet = await tmPose.load(modelURL, metadataURL);
        return { model: customPoseNet, type: this.ModelType.POSE };
      }
      // Otherwise it's an image model
      else {
        return { model: customMobileNet, type: this.ModelType.IMAGE };
      }
    } catch (e) {
      console.error("Failed to load model:", e);
      throw e;
    }
  }

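  /** Run one prediction pass, dispatching on the type of the model loaded for this URL. */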
  getPredictionFromModel = async (modelUrl: string, frame: ImageBitmap) => {
    const { model, modelType } = this.predictionState[modelUrl];
    switch (modelType) {
      case this.ModelType.IMAGE:
        if (!frame) return null;
        return await model.predict(frame);
      case this.ModelType.POSE: {
        if (!frame) return null;
        // Pose models classify in two stages: estimate the pose, then classify it
        const { posenetOutput } = await model.estimatePose(frame);
        return await model.predict(posenetOutput);
      }
      case this.ModelType.AUDIO:
        // Audio models predict continuously via listen(); report the latest scores
        if (this.latestAudioResults) {
          return model.wordLabels().map((label, i) => ({
            className: label,
            probability: this.latestAudioResults.scores[i]
          }));
        }
        return null;
    }
  }

  private getPredictionStateOrStartPredicting(modelUrl: string) {
    if (!modelUrl || !this.predictionState || !this.predictionState[modelUrl]) {
      console.warn('No prediction state available for model:', modelUrl);
      return null;
    }
    return this.predictionState[modelUrl];
  }

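  /** True when the given class name matches the model's current top prediction. */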
  model_match(state) {
    const modelUrl = this.teachableImageModel;
    const className = state;

    const predictionState = this.getPredictionStateOrStartPredicting(modelUrl);
    if (!predictionState) {
      return false;
    }

    const currentMaxClass = predictionState.topClass;
    return (currentMaxClass === String(className));
  }

  getModelClasses(): string[] {
    if (
      !this.teachableImageModel ||
      !this.predictionState ||
      !this.predictionState[this.teachableImageModel] ||
      !this.predictionState[this.teachableImageModel].hasOwnProperty('model')
    ) {
      return ["Select a class"];
    }

    if (this.predictionState[this.teachableImageModel].modelType === this.ModelType.AUDIO) {
      return this.predictionState[this.teachableImageModel].model.wordLabels();
    }

    return this.predictionState[this.teachableImageModel].model.getClassLabels();
  }

  getModelPrediction() {
    const modelUrl = this.teachableImageModel;
    const predictionState: { topClass: string } = this.getPredictionStateOrStartPredicting(modelUrl);
    if (!predictionState) {
      console.error("No prediction state found");
      return '';
    }
    return predictionState.topClass;
  }

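  /**
   * Run the active model on the given frame and cache its top class;
   * isPredicting counts prediction passes currently in flight.
   */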
  async predictAllBlocks(frame: ImageBitmap) {
    for (const modelUrl in this.predictionState) {
      if (!this.predictionState[modelUrl].model) {
        console.log('No model found for:', modelUrl);
        continue;
      }
      if (this.teachableImageModel !== modelUrl) {
        console.log('Model URL mismatch:', modelUrl);
        continue;
      }
      ++this.isPredicting;
      const prediction = await this.predictModel(modelUrl, frame);
      this.predictionState[modelUrl].topClass = prediction;
      --this.isPredicting;
    }
  }

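  /** Record per-class confidences for one pass and return the highest-scoring class name. */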
  private async predictModel(modelUrl: string, frame: ImageBitmap) {
    const predictions = await this.getPredictionFromModel(modelUrl, frame);
    if (!predictions) {
      return;
    }
    let maxProbability = 0;
    let maxClassName = "";
    for (let i = 0; i < predictions.length; i++) {
      // toFixed returns a string, so convert back to a number before comparing
      const probability = Number(predictions[i].probability.toFixed(2));
      const className = predictions[i].className;
      this.modelConfidences[className] = probability;
      if (probability > maxProbability) {
        maxClassName = className;
        maxProbability = probability;
      }
    }
    this.maxConfidence = maxProbability;
    return maxClassName;
  }
}
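For reference, a minimal sketch of how this class could be exercised; demo and getCameraFrame are illustrative placeholders, not part of this PR:

// Hypothetical usage -- getCameraFrame stands in for any ImageBitmap source
async function demo(getCameraFrame: () => Promise<ImageBitmap>) {
  const tm = new TeachableMachine();
  const result = await tm.useModel("https://teachablemachine.withgoogle.com/models/<MODEL_ID>/");
  if (result.type !== "success") return;

  const frame = await getCameraFrame();
  await tm.predictAllBlocks(frame);       // run the active model and cache its top class
  console.log(tm.getModelPrediction());   // name of the highest-confidence class
  console.log(tm.maxConfidence);          // its confidence, rounded to two decimals
  console.log(tm.model_match("class 1")); // true if "class 1" is currently on top
}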