Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions extensions/src/doodlebot/Connect.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
// Forwards `functionName` and its arguments to `reactiveInvoke` for this extension.
// NOTE(review): the `extension = extension` self-assignment is presumably there so
// Svelte treats `extension` as updated after the call — confirm against `reactiveInvoke`.
const invoke: ReactiveInvoke<Extension> = (functionName, ...args) =>
reactiveInvoke((extension = extension), functionName, args);

let connected = extension.connected;

let error: string | null = null;

let bleDevice: BLEDeviceWithUartService | null = null;
Expand Down Expand Up @@ -68,7 +66,7 @@
let showAdvanced = false;

onDestroy(() => {
if (!connected)
if (!extension.connected)
try {
extension.setIndicator("disconnected");
} catch (e) {}
Expand All @@ -81,7 +79,7 @@
style:background-color={color.ui.white}
style:color={color.text.primary}
>
{#if connected}
{#if extension.connected}
<h1>You're connected to doodlebot!</h1>
<div>
If you'd like to reconnect, or connect to a different device, you must
Expand Down
118 changes: 117 additions & 1 deletion extensions/src/doodlebot/Doodlebot.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import { base64ToInt32Array, deferred, makeWebsocket, Max32Int } from "./utils";
import { LineDetector } from "./LineDetection";
import { calculateArcTime } from "./TimeHelper";
import type { BLEDeviceWithUartService } from "./ble";
import { saveAudioBufferToWav } from "./SoundUtils";

export type MotorStepRequest = {
/** Number of steps for the stepper (+ == forward, - == reverse) */
Expand Down Expand Up @@ -287,7 +288,7 @@ export default class Doodlebot {
private updateSensor<T extends SensorKey>(type: T, value: SensorData[T]) {
this.onSensor.emit(type, value);
if (type == "distance" && value as number >= 63) {
this.sensorData[type] = 100;
(this.sensorData[type] as number) = 100;
} else {
this.sensorData[type] = value;
}
Expand Down Expand Up @@ -1119,6 +1120,85 @@ export default class Doodlebot {
};
}

/**
 * Posts an audio file to an endpoint on `doodlebot.media.mit.edu` and forwards the
 * bytes of the response body to `sendAudioData`.
 *
 * The file is sent as the `audio_file` multipart field; `voice_id` and `pitch_value`
 * are sent as the `voice` and `pitch` query parameters (URL-encoded).
 *
 * Failures (network errors, non-2xx responses) are logged and swallowed, so the
 * returned promise always resolves.
 *
 * @param file Audio file (or blob) to upload.
 * @param endpoint Path appended to the host, inserted verbatim.
 * @param voice_id Value for the `voice` query parameter.
 * @param pitch_value Value for the `pitch` query parameter.
 */
async sendAudioFileToChatEndpoint(file, endpoint, voice_id, pitch_value) {
  console.log("sending audio file");
  // Encode the query values so characters such as `&`, `#` or spaces cannot corrupt the URL.
  const voice = encodeURIComponent(String(voice_id));
  const pitch = encodeURIComponent(String(pitch_value));
  const url = `https://doodlebot.media.mit.edu/${endpoint}?voice=${voice}&pitch=${pitch}`;

  const formData = new FormData();
  formData.append("audio_file", file);

  try {
    const response = await fetch(url, {
      method: "POST",
      body: formData,
    });

    if (!response.ok) {
      const errorText = await response.text();
      console.log("Error response:", errorText);
      throw new Error(`HTTP error! status: ${response.status}`);
    }

    const textResponse = response.headers.get("text-response");
    console.log("Text Response:", textResponse);

    // Read the body directly. No object URL or Audio element is created here, because
    // object URLs stay allocated until explicitly revoked and nothing played them.
    const blob = await response.blob();
    const uint8array = new Uint8Array(await blob.arrayBuffer());

    // NOTE(review): not awaited, matching the previous behavior; if `sendAudioData`
    // returns a promise, its rejection is not caught by this try/catch — confirm.
    this.sendAudioData(uint8array);
  } catch (error) {
    console.error("Error sending audio file:", error);
  }
}


/**
 * Encodes an audio buffer as WAV via `saveAudioBufferToWav`, wraps the result in a
 * file named "output.wav", and hands it to `sendAudioFileToChatEndpoint`.
 *
 * Any error raised along the way is logged and swallowed.
 *
 * @param buffer Audio buffer passed to `saveAudioBufferToWav`.
 * @param endpoint Forwarded unchanged to `sendAudioFileToChatEndpoint`.
 * @param voice_id Forwarded unchanged to `sendAudioFileToChatEndpoint`.
 * @param pitch_value Forwarded unchanged to `sendAudioFileToChatEndpoint`.
 */
async processAndSendAudio(buffer, endpoint, voice_id, pitch_value) {
  try {
    const encoded = await saveAudioBufferToWav(buffer);
    console.log(encoded);
    await this.sendAudioFileToChatEndpoint(
      new File([encoded], "output.wav", { type: "audio/wav" }),
      endpoint,
      voice_id,
      pitch_value,
    );
  } catch (failure) {
    console.error("Error processing and sending audio:", failure);
  }
}


// private async handleChatInteraction(seconds: number, endpoint: string) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

drop?

// console.log(`recording audio for ${seconds} seconds`);
// // Display "listening" while recording
// await this.doodlebot?.display("clear");
// await this.doodlebot?.displayText("listening");
// console.log("recording audio?")
// const { context, buffer } = await this.doodlebot?.recordAudio(seconds);
// console.log("finished recording audio");

// // Display "thinking" while processing and waiting for response
// await this.doodlebot?.display("clear");
// await this.doodlebot?.displayText("thinking");

// // Before sending audio to be played
// await this.processAndSendAudio(buffer, endpoint, seconds);

// // Wait a moment before clearing the display
// await new Promise(resolve => setTimeout(resolve, 2000));
// await this.doodlebot?.display("h");
// }


recordAudio(numSeconds = 1) {
if (!this.setupAudioStream()) return;

Expand Down Expand Up @@ -1237,4 +1317,40 @@ export default class Doodlebot {
this.pending[type] = undefined;
return value;
}

/**
 * Fetches the contents behind a URL and uploads them as a file to the upload API
 * hosted at the awaited top-level domain.
 *
 * Fetch/upload failures and malformed input are logged and swallowed; only a
 * rejection of `topLevelDomain.promise` propagates to the caller.
 *
 * @param type `"sound"` selects the `sounds_upload` route; any other value selects
 *   `img_upload`.
 * @param blobURL String of the form `<file name>---name---<URL to fetch>`.
 */
async uploadFile(type: string, blobURL: string) {
  const tld = await this.topLevelDomain.promise;
  const route = type === "sound" ? "sounds_upload" : "img_upload";
  const uploadEndpoint = "https://" + tld + "/api/v1/upload/" + route;

  try {
    const [fileName, sourceUrl] = blobURL.split("---name---");
    // Without the separator there is no URL part; previously this fell through to
    // `fetch(undefined)`, which requests a bogus relative URL.
    if (!sourceUrl) {
      throw new Error(`Expected "<file name>---name---<url>" but received: ${blobURL}`);
    }

    const downloadResponse = await fetch(sourceUrl);
    if (!downloadResponse.ok) {
      throw new Error(`Failed to fetch Blob from URL: ${blobURL}`);
    }
    const blob = await downloadResponse.blob();

    // Wrap the blob in a File so the upload carries the requested file name.
    const file = new File([blob], fileName, { type: blob.type });
    const formData = new FormData();
    formData.append("file", file);

    const uploadResponse = await fetch(uploadEndpoint, {
      method: "POST",
      body: formData,
    });
    if (!uploadResponse.ok) {
      throw new Error(`Failed to upload file: ${uploadResponse.statusText}`);
    }
  } catch (error) {
    console.error("Error:", error);
  }
}
}
212 changes: 212 additions & 0 deletions extensions/src/doodlebot/ModelUtils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
import tmPose from '@teachablemachine/pose';
import tmImage from '@teachablemachine/image';
import * as speechCommands from '@tensorflow-models/speech-commands';



/**
 * Loads a Teachable Machine model (image, pose or audio) from its hosted URL and
 * tracks the most likely class for the currently selected model.
 *
 * State is kept per normalized model URL in `predictionState`; `teachableImageModel`
 * holds the URL of the model that predictions are currently run against.
 */
export default class TeachableMachine {
  /** Most recent result delivered to the speech-commands `listen` callback. */
  latestAudioResults: any;
  /** Per-model state keyed by normalized model URL: `{ model, modelType, topClass }`. */
  predictionState: Record<string, any>;
  /** Latest probability per class name, stored as a two-decimal string (`toFixed(2)`). */
  modelConfidences: Record<string, string> = {};
  /** Highest two-decimal-rounded probability of the latest prediction; null before any prediction. */
  maxConfidence: number | null = null;
  /** Normalized URL of the model currently in use; null after a failed load. */
  teachableImageModel?: string | null;
  /** Number of predictions currently in flight inside `predictAllBlocks`. */
  isPredicting = 0;
  ModelType = {
    POSE: 'pose',
    IMAGE: 'image',
    AUDIO: 'audio',
  };

  constructor() {
    this.predictionState = {};
  }

  /**
   * Loads the model identified by `url` and makes it the current model.
   *
   * @param url Either a full `teachablemachine.withgoogle.com/models/...` link or a bare model id.
   * @returns A status object; load failures are reported as `type: "error"` rather than thrown.
   */
  useModel = async (url: string): Promise<{ type: "success" | "error" | "warning"; msg: string }> => {
    let modelUrl: string | undefined;
    try {
      modelUrl = this.modelArgumentToURL(url);
      console.log('Loading model from URL:', modelUrl);

      // A recognizer left listening by an earlier audio model would keep the microphone
      // open and keep overwriting `latestAudioResults` with scores for the wrong labels.
      await this.stopAudioRecognizers();

      // Reset the entry so a previously loaded copy is not used while loading.
      this.predictionState[modelUrl] = {};

      const { model, type } = await this.initModel(modelUrl);
      this.predictionState[modelUrl] = { model, modelType: type };

      this.teachableImageModel = modelUrl;

      return {
        type: "success",
        msg: "Model loaded successfully"
      };
    } catch (e) {
      console.error('Error loading model:', e);
      // Drop the placeholder entry so a failed load does not linger as a model-less state.
      if (modelUrl !== undefined) {
        delete this.predictionState[modelUrl];
      }
      this.teachableImageModel = null;
      return {
        type: "error",
        msg: "Failed to load model"
      };
    }
  }

  /**
   * Converts a user-provided model link or id into the storage URL the model files
   * are served from. The result always ends with exactly one `/`, because
   * `initModel` appends file names directly to it.
   */
  modelArgumentToURL = (modelArg: string): string => {
    const endpointProvidedFromInterface = "https://teachablemachine.withgoogle.com/models/";
    const redirectEndpoint = "https://storage.googleapis.com/tm-model/";

    let modelPath = modelArg.trim();
    for (const prefix of [endpointProvidedFromInterface, redirectEndpoint]) {
      if (modelPath.startsWith(prefix)) {
        modelPath = modelPath.slice(prefix.length);
        break;
      }
    }
    // Normalize trailing slashes: "abc" and "abc/" must both yield ".../abc/".
    return redirectEndpoint + modelPath.replace(/\/+$/, "") + "/";
  }

  /**
   * Downloads the model at `modelUrl` and detects its kind from the metadata.
   * For audio models this also starts the recognizer listening.
   *
   * @throws Re-throws any load error after logging it.
   */
  initModel = async (modelUrl: string) => {
    // Cache-busting query so a re-trained model at the same URL is re-fetched.
    const avoidCache = `?x=${Date.now()}`;
    const modelURL = modelUrl + "model.json" + avoidCache;
    const metadataURL = modelUrl + "metadata.json" + avoidCache;

    try {
      // Load as an image model first; its metadata reveals the real model kind.
      const customMobileNet = await tmImage.load(modelURL, metadataURL);
      const metadata = (customMobileNet as any)._metadata ?? {};

      if (Object.prototype.hasOwnProperty.call(metadata, 'tfjsSpeechCommandsVersion')) {
        const recognizer = await speechCommands.create("BROWSER_FFT", undefined, modelURL, metadataURL);
        await recognizer.ensureModelLoaded();

        // Each recognition result is cached for `getPredictionFromModel` to read.
        await recognizer.listen(async result => {
          this.latestAudioResults = result;
        }, {
          includeSpectrogram: true,
          probabilityThreshold: 0.75,
          invokeCallbackOnNoiseAndUnknown: true,
          overlapFactor: 0.50
        });

        return { model: recognizer, type: this.ModelType.AUDIO };
      }

      if (metadata.packageName === "@teachablemachine/pose") {
        const customPoseNet = await tmPose.load(modelURL, metadataURL);
        return { model: customPoseNet, type: this.ModelType.POSE };
      }

      return { model: customMobileNet, type: this.ModelType.IMAGE };
    } catch (e) {
      console.error("Failed to load model:", e);
      throw e;
    }
  }

  /**
   * Runs the model stored under `modelUrl`.
   *
   * @returns A list of `{ className, probability }` entries, or null when there is
   *   no loaded model for the URL, no frame (image/pose), or no audio result yet.
   */
  getPredictionFromModel = async (modelUrl: string, frame: ImageBitmap) => {
    const state = this.predictionState[modelUrl];
    if (!state || !state.model) {
      return null;
    }
    const { model, modelType } = state;

    switch (modelType) {
      case this.ModelType.IMAGE:
        if (!frame) return null;
        return await model.predict(frame);
      case this.ModelType.POSE: {
        if (!frame) return null;
        const { posenetOutput } = await model.estimatePose(frame);
        return await model.predict(posenetOutput);
      }
      case this.ModelType.AUDIO: {
        const results = this.latestAudioResults;
        if (!results) return null;
        return model.wordLabels().map((label, i) => ({
          className: label,
          probability: results.scores[i]
        }));
      }
      default:
        return null;
    }
  }

  /** Stops every audio recognizer in `predictionState` that is still listening. */
  private async stopAudioRecognizers(): Promise<void> {
    for (const state of Object.values(this.predictionState)) {
      if (!state || state.modelType !== this.ModelType.AUDIO) continue;
      try {
        if (state.model?.isListening?.()) {
          await state.model.stopListening();
        }
      } catch (e) {
        // Best effort: a recognizer that refuses to stop must not block loading a new model.
        console.warn('Failed to stop audio recognizer:', e);
      }
    }
    // Cached scores belong to the recognizer that was just stopped.
    this.latestAudioResults = undefined;
  }

  /** Returns the state stored for `modelUrl`, or null (with a warning) when there is none. */
  private getPredictionStateOrStartPredicting(modelUrl: string) {
    if (!modelUrl || !this.predictionState || !this.predictionState[modelUrl]) {
      console.warn('No prediction state available for model:', modelUrl);
      return null;
    }
    return this.predictionState[modelUrl];
  }

  /** Reports whether the current model's top class equals `state` (compared as a string). */
  model_match(state) {
    const predictionState = this.getPredictionStateOrStartPredicting(this.teachableImageModel);
    if (!predictionState) {
      return false;
    }
    return predictionState.topClass === String(state);
  }

  /** Lists the class labels of the current model, or a single placeholder entry when none is loaded. */
  getModelClasses(): string[] {
    const current = this.teachableImageModel
      ? this.predictionState?.[this.teachableImageModel]
      : undefined;
    if (!current || !Object.prototype.hasOwnProperty.call(current, 'model')) {
      return ["Select a class"];
    }

    return current.modelType === this.ModelType.AUDIO
      ? current.model.wordLabels()
      : current.model.getClassLabels();
  }

  /** Returns the top class last recorded for the current model, or '' when no state exists. */
  getModelPrediction() {
    const modelUrl = this.teachableImageModel;
    const predictionState: { topClass: string } = this.getPredictionStateOrStartPredicting(modelUrl);
    if (!predictionState) {
      console.error("No prediction state found");
      return '';
    }
    return predictionState.topClass;
  }

  /**
   * Runs a prediction for the current model and stores the winning class in its state.
   * Entries for other models, or without a loaded model, are skipped.
   * Errors from the model propagate to the caller.
   */
  async predictAllBlocks(frame: ImageBitmap) {
    for (const modelUrl in this.predictionState) {
      if (!this.predictionState[modelUrl].model) {
        console.log('No model found for:', modelUrl);
        continue;
      }
      if (this.teachableImageModel !== modelUrl) {
        console.log('Model URL mismatch:', modelUrl);
        continue;
      }
      ++this.isPredicting;
      try {
        const prediction = await this.predictModel(modelUrl, frame);
        this.predictionState[modelUrl].topClass = prediction;
      } finally {
        // Always release the counter, even when the model throws.
        --this.isPredicting;
      }
    }
  }

  /**
   * Computes the most probable class for one model and refreshes `modelConfidences`
   * and `maxConfidence`.
   *
   * @returns The winning class name ("" when every probability rounds to 0), or
   *   undefined when no prediction was available.
   */
  private async predictModel(modelUrl: string, frame: ImageBitmap): Promise<string | undefined> {
    const predictions = await this.getPredictionFromModel(modelUrl, frame);
    if (!predictions) {
      return undefined;
    }
    let maxProbability = 0;
    let maxClassName = "";
    for (const { className, probability } of predictions) {
      const formatted = Number(probability).toFixed(2);
      this.modelConfidences[className] = formatted;
      // Compare as numbers; `toFixed` yields a string, which must not leak into `maxConfidence`.
      const rounded = Number(formatted);
      if (rounded > maxProbability) {
        maxClassName = className;
        maxProbability = rounded;
      }
    }
    this.maxConfidence = maxProbability;
    return maxClassName;
  }
}
Loading
Loading