Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion client/src/components/DemoInterface.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ interface AggregatedResult {
}

const DemoInterface: React.FC<DemoInterfaceProps> = ({ models, tokenInfo, apiBasePath }) => {
const curatedFileList = getCuratedFileList(tokenInfo.tapisHost);
const curatedFileList = useMemo(
() => getCuratedFileList(tokenInfo.tapisHost),
[tokenInfo.tapisHost]
);

// only supporting clip models on first pass
const clipModels = useMemo(
Expand Down
38 changes: 16 additions & 22 deletions client/src/hooks/useToken.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ const getHostFromIss = (iss: string): string => {
}
};

const fetchTokenFromPortal = async (): Promise<string | null> => {
const fetchTokenFromPortalUsingCookie = async (): Promise<string | null> => {
if (!isInIframe()) {
return null;
}
Expand Down Expand Up @@ -95,7 +95,7 @@ const validateTokenAndGetHost = async (token: string, fallbackHost: string): Pro
isValid: true,
};
} catch (error) {
console.error('Token validation error:', error);
console.error('Token missing or validation error:', error);
return {
token: '',
tapisHost: fallbackHost,
Expand All @@ -106,30 +106,14 @@ const validateTokenAndGetHost = async (token: string, fallbackHost: string): Pro

const getToken = async (fallbackHost: string): Promise<TokenInfo> => {
// First, try to get token from parent portal (if in iframe)
const tokenFromCorePortal = await fetchTokenFromPortal();
const tokenFromCorePortal = await fetchTokenFromPortalUsingCookie();

if (tokenFromCorePortal) {
// Validate the portal token
const result = await validateTokenAndGetHost(tokenFromCorePortal, fallbackHost);
if (result.isValid) {
// Store in sessionStorage for subsequent requests
sessionStorage.setItem('access_token', tokenFromCorePortal);

// Extract expiry from JWT
try {
const decoded = jwtDecode<TapisJwtPayload>(tokenFromCorePortal);
sessionStorage.setItem('expires_at', (decoded.exp * 1000).toString());
} catch {
console.warn('Failed to decode JWT for expiry');
// Fallback to 1 hour
sessionStorage.setItem('expires_at', (Date.now() + 3600000).toString());
}

return result;
}
return validateTokenAndGetHost(tokenFromCorePortal, fallbackHost);
}

// Fall back to sessionStorage (for direct access or cached token)
// Use sessionStorage (direct access, not in iframe scenario)
const token = sessionStorage.getItem('access_token');
const expiresAt = sessionStorage.getItem('expires_at');

Expand All @@ -144,13 +128,23 @@ const getToken = async (fallbackHost: string): Promise<TokenInfo> => {
return validateTokenAndGetHost(token, fallbackHost);
};

/**
* Manages the user's Tapis auth token via React Query.
*
* staleTime (5 min) ensures the backend middleware gets regular
* opportunities to refresh tokens before expiry (backend has 10
* min threshold).
*
* Retry disabled since auth failures require login, not retries.
*/
export const useToken = () => {
const config = useConfig();

return useQuery({
queryKey: ['token'],
queryFn: () => getToken(config.host),
retry: false,
refetchOnWindowFocus: false,
refetchOnWindowFocus: true,
staleTime: 5 * 60 * 1000 /* 5min */,
});
};
8 changes: 4 additions & 4 deletions imageinf/inference/clip_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,10 @@ def _precompute_text_features(self):
with torch.no_grad():
ti = self.processor(text=all_texts, return_tensors="pt", padding=True)
ti = {k: v.to(self.device) for k, v in ti.items()}
emb = self.model.get_text_features(
text_out = self.model.text_model(
input_ids=ti["input_ids"], attention_mask=ti["attention_mask"]
)
emb = self.model.text_projection(text_out.pooler_output)
emb = F.normalize(emb, dim=-1)
self.text_pairs = emb.reshape(len(self.labels), 2, -1)

Expand All @@ -99,9 +100,8 @@ def classify_image(
inputs = {k: v.to(self.device) for k, v in inputs.items()}

with torch.no_grad():
img_feat = self.model.get_image_features(
pixel_values=inputs["pixel_values"]
)
vision_out = self.model.vision_model(pixel_values=inputs["pixel_values"])
img_feat = self.model.visual_projection(vision_out.pooler_output)
img_feat = F.normalize(img_feat, dim=-1)

sims2 = torch.einsum("bd,lcd->blc", img_feat, self.text_pairs)
Expand Down