Skip to content

Add depth estimation widget #953

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import AudioToAudioWidget from "./widgets/AudioToAudioWidget/AudioToAudioWidget.svelte";
import AutomaticSpeechRecognitionWidget from "./widgets/AutomaticSpeechRecognitionWidget/AutomaticSpeechRecognitionWidget.svelte";
import ConversationalWidget from "./widgets/ConversationalWidget/ConversationalWidget.svelte";
import DephthEstimationWidget from "./widgets/DephthEstimationWidget/DephthEstimationWidget.svelte";
import FeatureExtractionWidget from "./widgets/FeatureExtractionWidget/FeatureExtractionWidget.svelte";
import FillMaskWidget from "./widgets/FillMaskWidget/FillMaskWidget.svelte";
import ImageClassificationWidget from "./widgets/ImageClassificationWidget/ImageClassificationWidget.svelte";
Expand Down Expand Up @@ -51,6 +52,7 @@
"audio-classification": AudioClassificationWidget,
"automatic-speech-recognition": AutomaticSpeechRecognitionWidget,
conversational: ConversationalWidget,
"depth-estimation": DephthEstimationWidget,
"feature-extraction": FeatureExtractionWidget,
"fill-mask": FillMaskWidget,
"image-classification": ImageClassificationWidget,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
<script lang="ts">
import { afterUpdate } from "svelte";

export let classNames = "";
export let imgSrc = "";
export let depthMap: ImageData | null = null;

let containerEl: HTMLElement;
let canvas: HTMLCanvasElement;
let imgEl: HTMLImageElement;
let width = 0;
let height = 0;

function draw() {
width = containerEl.clientWidth;
height = containerEl.clientHeight;
const ctx = canvas?.getContext("2d");

if (ctx && imgEl && depthMap) {
ctx.drawImage(imgEl, 0, 0, width, height);
const imageData = ctx.getImageData(0, 0, width, height);

for (let i = 0; i < imageData.data.length; i += 4) {
const depth = depthMap.data[i];
imageData.data[i] = depth;
imageData.data[i + 1] = depth;
imageData.data[i + 2] = depth;
}

ctx.putImageData(imageData, 0, 0);
}
}

afterUpdate(draw);
</script>

<svelte:window on:resize={draw} />

<div class="relative top-0 left-0 inline-flex {classNames}" bind:this={containerEl}>
<div class="flex max-w-sm justify-center">
<img alt="" class="relative top-0 left-0 object-contain" src={imgSrc} bind:this={imgEl} />
</div>
{#if depthMap}
<canvas
class="absolute top-0 left-0"
{width}
{height}
bind:this={canvas}
/>
{/if}
</div>
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
<script lang="ts">
import type { WidgetProps, ExampleRunOpts, InferenceRunOpts } from "../../shared/types.js";
import type { WidgetExampleAssetInput } from "@huggingface/tasks";

import { onMount } from "svelte";

import WidgetFileInput from "../../shared/WidgetFileInput/WidgetFileInput.svelte";
import WidgetDropzone from "../../shared/WidgetDropzone/WidgetDropzone.svelte";
import WidgetWrapper from "../../shared/WidgetWrapper/WidgetWrapper.svelte";
import { callInferenceApi, getBlobFromUrl } from "../../shared/helpers.js";
import { isAssetInput } from "../../shared/inputValidation.js";
import { widgetStates } from "../../stores.js";

import Canvas from "./Canvas.svelte";

export let apiToken: WidgetProps["apiToken"];
export let apiUrl: WidgetProps["apiUrl"];
export let callApiOnMount: WidgetProps["callApiOnMount"];
export let model: WidgetProps["model"];
export let noTitle: WidgetProps["noTitle"];
export let includeCredentials: WidgetProps["includeCredentials"];

$: isDisabled = $widgetStates?.[model.id]?.isDisabled;

let computeTime = "";
let error: string = "";
let isLoading = false;
let imgSrc = "";
let modelLoading = {
isLoading: false,
estimatedTime: 0,
};
let output: ImageData | null = null;
let outputJson: string;
let warning: string = "";

function onSelectFile(file: File | Blob) {
imgSrc = URL.createObjectURL(file);
getOutput(file);
}

async function getOutput(
file: File | Blob,
{ withModelLoading = false, isOnLoadCall = false, exampleOutput = undefined }: InferenceRunOpts = {}
) {
if (!file) {
return;
}

// Reset values
computeTime = "";
error = "";
warning = "";
output = null;
outputJson = "";

const requestBody = { file };

isLoading = true;

const res = await callInferenceApi(
apiUrl,
model.id,
requestBody,
apiToken,
parseOutput,
withModelLoading,
includeCredentials,
isOnLoadCall
);

isLoading = false;
modelLoading = { isLoading: false, estimatedTime: 0 };

if (res.status === "success") {
computeTime = res.computeTime;
output = res.output;
outputJson = res.outputJson;
} else if (res.status === "loading-model") {
modelLoading = {
isLoading: true,
estimatedTime: res.estimatedTime,
};
getOutput(file, { withModelLoading: true });
} else if (res.status === "error" && !isOnLoadCall) {
error = res.error;
}
}

function parseOutput(body: unknown): ImageData {
if (body instanceof ImageData) {
return body;
}
throw new TypeError("Invalid output: output must be of type ImageData");
}

async function applyWidgetExample(sample: WidgetExampleAssetInput, opts: ExampleRunOpts = {}) {
imgSrc = sample.src;
if (opts.isPreview) {
output = null;
outputJson = "";
return;
}
const blob = await getBlobFromUrl(imgSrc);
const exampleOutput = sample.output;
getOutput(blob, { ...opts.inferenceOpts, exampleOutput });
}

onMount(() => {
if (callApiOnMount) {
getOutput(new Blob(), { isOnLoadCall: true });
}
});
</script>

<WidgetWrapper {apiUrl} {includeCredentials} {model} let:WidgetInfo let:WidgetHeader let:WidgetFooter>
<WidgetHeader
{noTitle}
{model}
{isLoading}
{isDisabled}
{callApiOnMount}
{applyWidgetExample}
validateExample={isAssetInput}
/>

<WidgetDropzone
classNames="hidden md:block"
{isLoading}
{isDisabled}
{imgSrc}
on:run={(e) => onSelectFile(e.detail)}
on:error={(e) => (error = e.detail)}
>
{#if imgSrc}
<Canvas {imgSrc} depthMap={output} />
{/if}
</WidgetDropzone>
<!-- Better UX for mobile/table through CSS breakpoints -->
{#if imgSrc}
<Canvas classNames="mr-2 md:hidden" {imgSrc} depthMap={output} />
{/if}
<WidgetFileInput
accept="image/*"
classNames="mr-2 md:hidden"
{isLoading}
{isDisabled}
label="Browse for image"
on:run={(e) => onSelectFile(e.detail)}
/>
{#if warning}
<div class="alert alert-warning mt-2">{warning}</div>
{/if}

<WidgetInfo {model} {computeTime} {error} {modelLoading} />

<WidgetFooter {model} {isDisabled} {outputJson} />
</WidgetWrapper>
Loading