diff --git a/scripts/smoke-generation-mocked.mjs b/scripts/smoke-generation-mocked.mjs index f078568..ddf1a75 100644 --- a/scripts/smoke-generation-mocked.mjs +++ b/scripts/smoke-generation-mocked.mjs @@ -42,9 +42,9 @@ assertNoMatch( /dashscope\.aliyuncs\.com|\/dashscope-api\b|Bearer\s+sk-/i, ); assertMatch("image generation must go through the app API", generationClient, /buildApiUrl\("ai\/image"\)/); -assertMatch("video generation must go through the app API", generationClient, /buildApiUrl\("ai\/video"\)/); +assertMatch("video generation must go through the app API", generationClient, /serverRequest<\{ taskId: string \}>\("ai\/video"/); assertMatch("binary uploads must go through the app OSS API", generationClient, /buildApiUrl\("oss\/upload-binary"\)/); -assertMatch("URL uploads must go through the app OSS API", generationClient, /buildApiUrl\("oss\/upload-by-url"\)/); +assertMatch("URL uploads must go through the app OSS API", generationClient, /serverRequest<\{ url: string; signedUrl\?: string; ossKey\?: string \}>\("oss\/upload-by-url"/); assertMatch( "ecommerce video history must durable-copy media before saving", ecommerceVideoService, diff --git a/src/api/aiGenerationClient.ts b/src/api/aiGenerationClient.ts index accb697..5e78c26 100644 --- a/src/api/aiGenerationClient.ts +++ b/src/api/aiGenerationClient.ts @@ -3,6 +3,7 @@ import { buildAuthHeaders, isRecord, readJsonResponse, + serverRequest, throwResponseError, } from "./serverConnection"; import { isOptionalApiRouteMissing } from "./apiErrorUtils"; @@ -243,6 +244,10 @@ function emitImageRouteDebug(label: string, payload: Record): v let taskHistoryRouteMissing = false; +const TASK_SUBMIT_TIMEOUT_MS = 90_000; +const TASK_STATUS_TIMEOUT_MS = 20_000; +const NON_RETRYING_REQUEST = { maxRetries: 0 }; + export const aiGenerationClient = { async createImageTask(input: ImageGenInput): Promise { const requestUrl = buildApiUrl("ai/image"); @@ -256,15 +261,13 @@ export const aiGenerationClient = { projectId: input.projectId, conversationId: input.conversationId, }); - const res = await fetch(requestUrl, { + const payload = await serverRequest("ai/image", { method: "POST", - headers: buildAuthHeaders(), - body: JSON.stringify(input), + body: input, + timeoutMs: TASK_SUBMIT_TIMEOUT_MS, + maxRetries: NON_RETRYING_REQUEST.maxRetries, + fallbackMessage: "Image generation request failed", }); - if (!res.ok) { - await throwResponseError(res, "Image generation request failed"); - } - const payload = await readJsonResponse(res, "Image generation response failed"); if (payload.providerDebug) { emitImageRouteDebug("[ai/image-provider-debug]", payload.providerDebug as Record); } @@ -272,96 +275,83 @@ export const aiGenerationClient = { }, async createVideoTask(input: VideoGenInput): Promise<{ taskId: string }> { - const res = await fetch(buildApiUrl("ai/video"), { + return serverRequest<{ taskId: string }>("ai/video", { method: "POST", - headers: buildAuthHeaders(), - body: JSON.stringify(input), + body: input, + timeoutMs: TASK_SUBMIT_TIMEOUT_MS, + maxRetries: NON_RETRYING_REQUEST.maxRetries, + fallbackMessage: "Video generation request failed", }); - if (!res.ok) { - await throwResponseError(res, "Video generation request failed"); - } - return readJsonResponse<{ taskId: string }>(res, "Video generation response failed"); }, async createVideoSuperResolveTask(input: VideoSuperResolveInput): Promise<{ taskId: string }> { - const res = await fetch(buildApiUrl("ai/video/super-resolve"), { + return serverRequest<{ taskId: string }>("ai/video/super-resolve", { method: "POST", - headers: buildAuthHeaders(), - body: JSON.stringify(input), + body: input, + timeoutMs: TASK_SUBMIT_TIMEOUT_MS, + maxRetries: NON_RETRYING_REQUEST.maxRetries, + fallbackMessage: "Video super-resolution request failed", }); - if (!res.ok) { - await throwResponseError(res, "Video super-resolution request failed"); - } - return readJsonResponse<{ taskId: string }>(res, "Video super-resolution response failed"); }, async createEraseSubtitlesTask(input: EraseSubtitlesInput): Promise<{ taskId: string }> { - const res = await fetch(buildApiUrl("ai/video/erase-subtitles"), { + return serverRequest<{ taskId: string }>("ai/video/erase-subtitles", { method: "POST", - headers: buildAuthHeaders(), - body: JSON.stringify(input), + body: input, + timeoutMs: TASK_SUBMIT_TIMEOUT_MS, + maxRetries: NON_RETRYING_REQUEST.maxRetries, + fallbackMessage: "Subtitle removal request failed", }); - if (!res.ok) { - await throwResponseError(res, "Subtitle removal request failed"); - } - return readJsonResponse<{ taskId: string }>(res, "Subtitle removal response failed"); }, async createVideoEditTask(input: VideoEditInput): Promise<{ taskId: string }> { - const res = await fetch(buildApiUrl("ai/video/edit"), { + return serverRequest<{ taskId: string }>("ai/video/edit", { method: "POST", - headers: buildAuthHeaders(), - body: JSON.stringify({ ...input, model: input.model || "happyhorse-1.0-video-edit" }), + body: { ...input, model: input.model || "happyhorse-1.0-video-edit" }, + timeoutMs: TASK_SUBMIT_TIMEOUT_MS, + maxRetries: NON_RETRYING_REQUEST.maxRetries, + fallbackMessage: "Video edit request failed", }); - if (!res.ok) { - await throwResponseError(res, "Video edit request failed"); - } - return readJsonResponse<{ taskId: string }>(res, "Video edit response failed"); }, async createImageSuperResolveTask(input: ImageSuperResolveInput): Promise<{ taskId: string }> { - const res = await fetch(buildApiUrl("ai/image/super-resolve"), { + return serverRequest<{ taskId: string }>("ai/image/super-resolve", { method: "POST", - headers: buildAuthHeaders(), - body: JSON.stringify(input), + body: input, + timeoutMs: TASK_SUBMIT_TIMEOUT_MS, + maxRetries: NON_RETRYING_REQUEST.maxRetries, + fallbackMessage: "Image super-resolution request failed", }); - if (!res.ok) { - await throwResponseError(res, "Image super-resolution request failed"); - } - return readJsonResponse<{ taskId: string }>(res, "Image super-resolution response failed"); }, async createImageEditTask(input: ImageEditInput): Promise<{ taskId: string }> { - const res = await fetch(buildApiUrl("ai/image/edit"), { + return serverRequest<{ taskId: string }>("ai/image/edit", { method: "POST", - headers: buildAuthHeaders(), - body: JSON.stringify(input), + body: input, + timeoutMs: TASK_SUBMIT_TIMEOUT_MS, + maxRetries: NON_RETRYING_REQUEST.maxRetries, + fallbackMessage: "Image edit request failed", }); - if (!res.ok) { - await throwResponseError(res, "Image edit request failed"); - } - return readJsonResponse<{ taskId: string }>(res, "Image edit response failed"); }, async cancelTask(taskId: string): Promise { - const res = await fetch(buildApiUrl(`ai/tasks/${taskId}/cancel`), { - method: "PATCH", - headers: buildAuthHeaders(), - }); - if (!res.ok && res.status !== 404) { - await throwResponseError(res, "Task cancel failed"); + try { + await serverRequest(`ai/tasks/${taskId}/cancel`, { + method: "PATCH", + maxRetries: NON_RETRYING_REQUEST.maxRetries, + fallbackMessage: "Task cancel failed", + }); + } catch (error) { + if (isOptionalApiRouteMissing(error)) return; + throw error; } }, async getTaskStatus(taskId: string): Promise { - const res = await fetch(buildApiUrl(`ai/tasks/${taskId}`), { - method: "GET", - headers: buildAuthHeaders(), + return serverRequest(`ai/tasks/${taskId}`, { + timeoutMs: TASK_STATUS_TIMEOUT_MS, + fallbackMessage: "Task status request failed", }); - if (!res.ok) { - await throwResponseError(res, "Task status request failed"); - } - return readJsonResponse(res, "Task status response failed"); }, async downloadTaskResult(taskId: string): Promise<{ blob: Blob; filename?: string; contentType?: string }> { @@ -387,49 +377,41 @@ export const aiGenerationClient = { if (params?.status) search.set("status", params.status); if (params?.type) search.set("type", params.type); if (params?.projectId) search.set("projectId", params.projectId); - const res = await fetch(buildApiUrl(`ai/tasks${search.toString() ? `?${search}` : ""}`), { - method: "GET", - headers: buildAuthHeaders(), - }); - if (!res.ok) { - try { - await throwResponseError(res, "Task history request failed"); - } catch (error) { - if (isOptionalApiRouteMissing(error)) { - taskHistoryRouteMissing = true; - return []; - } - throw error; + try { + const payload = await serverRequest(`ai/tasks${search.toString() ? `?${search}` : ""}`, { + fallbackMessage: "Task history request failed", + }); + return extractTaskList(payload).map(toPreviewTask); + } catch (error) { + if (isOptionalApiRouteMissing(error)) { + taskHistoryRouteMissing = true; + return []; } + throw error; } - const payload = await readJsonResponse(res, "Task history response failed"); - return extractTaskList(payload).map(toPreviewTask); }, async bindTaskToConversation(taskId: string, conversationId: number): Promise { - const res = await fetch(buildApiUrl(`ai/tasks/${taskId}/conversation`), { - method: "PATCH", - headers: buildAuthHeaders(), - body: JSON.stringify({ conversationId }), - }); - if (res.status === 404) { - return; - } - if (!res.ok) { - await throwResponseError(res, "Task conversation binding failed"); + try { + await serverRequest(`ai/tasks/${taskId}/conversation`, { + method: "PATCH", + body: { conversationId }, + maxRetries: NON_RETRYING_REQUEST.maxRetries, + fallbackMessage: "Task conversation binding failed", + }); + } catch (error) { + if (isOptionalApiRouteMissing(error)) return; + throw error; } }, async uploadAsset(input: UploadAssetInput): Promise<{ url: string; signedUrl?: string; ossKey?: string }> { - const res = await fetch(buildApiUrl("oss/upload"), { + return serverRequest<{ url: string; signedUrl?: string; ossKey?: string }>("oss/upload", { method: "POST", - headers: buildAuthHeaders(), - body: JSON.stringify(input), + body: input, + maxRetries: NON_RETRYING_REQUEST.maxRetries, + fallbackMessage: "Asset upload failed", }); - if (!res.ok) { - await throwResponseError(res, "Asset upload failed"); - } - return readJsonResponse<{ url: string; ossKey?: string }>(res, "Asset upload response failed"); }, async uploadAssetBinary(blob: Blob, options?: { name?: string; mimeType?: string; scope?: string }): Promise<{ url: string; signedUrl?: string; ossKey?: string }> { @@ -451,15 +433,12 @@ export const aiGenerationClient = { }, async uploadAssetByUrl(input: UploadAssetByUrlInput): Promise<{ url: string; signedUrl?: string; ossKey?: string }> { - const res = await fetch(buildApiUrl("oss/upload-by-url"), { + return serverRequest<{ url: string; signedUrl?: string; ossKey?: string }>("oss/upload-by-url", { method: "POST", - headers: buildAuthHeaders(), - body: JSON.stringify(input), + body: input, + maxRetries: NON_RETRYING_REQUEST.maxRetries, + fallbackMessage: "Asset upload by URL failed", }); - if (!res.ok) { - await throwResponseError(res, "Asset upload by URL failed"); - } - return readJsonResponse<{ url: string; ossKey?: string }>(res, "Asset upload by URL response failed"); }, subscribeTaskStatus( diff --git a/src/api/providerHealthClient.ts b/src/api/providerHealthClient.ts index 11c48b4..6552661 100644 --- a/src/api/providerHealthClient.ts +++ b/src/api/providerHealthClient.ts @@ -1,4 +1,4 @@ -import { buildApiUrl, buildAuthHeaders } from "./serverConnection"; +import { serverRequest } from "./serverConnection"; export interface ProviderHealthEntry { status: string; @@ -32,13 +32,8 @@ export interface ProviderHealthResponse { export const providerHealthClient = { async getStatus(): Promise { - const res = await fetch(buildApiUrl("admin/providers/status"), { - method: "GET", - headers: buildAuthHeaders(), + return serverRequest("admin/providers/status", { + fallbackMessage: "Provider health request failed", }); - if (!res.ok) { - throw new Error(`Provider health request failed (${res.status})`); - } - return res.json() as Promise; }, -}; \ No newline at end of file +}; diff --git a/src/api/scriptEvalClient.ts b/src/api/scriptEvalClient.ts index d593541..3afa296 100644 --- a/src/api/scriptEvalClient.ts +++ b/src/api/scriptEvalClient.ts @@ -1,4 +1,4 @@ -import { buildApiUrl, buildAuthHeaders } from "./serverConnection"; +import { serverRequest } from "./serverConnection"; export interface ScriptEvalResult { totalScore: number; @@ -140,10 +140,13 @@ function normalizeEvidence(value: unknown): Record { } export async function evaluateScript(script: string, signal?: AbortSignal): Promise { - const res = await fetch(buildApiUrl("ai/chat"), { + const payload = await serverRequest<{ + content?: string; + choices?: Array<{ message?: { content?: string } }>; + text?: string; + }>("ai/chat", { method: "POST", - headers: buildAuthHeaders(), - body: JSON.stringify({ + body: { model: MODEL, messages: [ { role: "system", content: EVAL_SYSTEM_PROMPT }, @@ -153,16 +156,13 @@ export async function evaluateScript(script: string, signal?: AbortSignal): Prom stream: false, temperature: 0.3, max_tokens: 4096, - }), + }, signal, + timeoutMs: 180_000, + maxRetries: 0, + fallbackMessage: "评测请求失败", }); - if (!res.ok) { - const errText = await res.text().catch(() => ""); - throw new Error(`评测请求失败 (${res.status}): ${errText.slice(0, 200)}`); - } - - const payload = await res.json(); const content: string = payload?.content ?? payload?.choices?.[0]?.message?.content ?? payload?.text ?? ""; if (!content) throw new Error("模型未返回有效内容"); diff --git a/src/api/serverConnection.ts b/src/api/serverConnection.ts index d1302fa..3c2a499 100644 --- a/src/api/serverConnection.ts +++ b/src/api/serverConnection.ts @@ -22,6 +22,9 @@ export interface ServerRequestOptions { signal?: AbortSignal; /** Per-request timeout in ms. Defaults to DEFAULT_REQUEST_TIMEOUT_MS. Pass 0 to disable. */ timeoutMs?: number; + /** Defaults to 2. Use 0 for non-idempotent task submission endpoints. */ + maxRetries?: number; + fallbackMessage?: string; } export const DEFAULT_REQUEST_TIMEOUT_MS = 30_000; @@ -343,8 +346,10 @@ const MAX_RETRIES = 2; export async function serverRequest(path: string, options?: ServerRequestOptions): Promise { let lastError: unknown; const timeoutMs = options?.timeoutMs ?? DEFAULT_REQUEST_TIMEOUT_MS; + const maxRetries = options?.maxRetries ?? MAX_RETRIES; + const fallbackMessage = options?.fallbackMessage || "Request failed"; - for (let attempt = 0; attempt <= MAX_RETRIES; attempt++) { + for (let attempt = 0; attempt <= maxRetries; attempt++) { const controller = timeoutMs > 0 ? new AbortController() : null; const timeoutId = controller && typeof window !== "undefined" @@ -366,11 +371,11 @@ export async function serverRequest(path: string, options?: ServerRequestOpti credentials: "include", }); - const payload = await readJsonResponse(response, "Request failed"); + const payload = await readJsonResponse(response, fallbackMessage); return (options?.raw ? payload : unwrapApiPayload(payload)) as T; } catch (error) { lastError = error; - if (attempt < MAX_RETRIES && isRetryable(error) && !options?.signal?.aborted) { + if (attempt < maxRetries && isRetryable(error) && !options?.signal?.aborted) { await new Promise((r) => setTimeout(r, getRetryDelay(attempt, error))); continue; } diff --git a/src/features/canvas/CanvasPage.tsx b/src/features/canvas/CanvasPage.tsx index 1d99981..2ecf3b2 100644 --- a/src/features/canvas/CanvasPage.tsx +++ b/src/features/canvas/CanvasPage.tsx @@ -32,6 +32,7 @@ import { useCallback, useEffect, useMemo, useRef, useState, type ChangeEvent, ty import { aiGenerationClient } from "../../api/aiGenerationClient"; import { assetClient, type ServerAssetItem } from "../../api/assetClient"; import { communityClient } from "../../api/communityClient"; +import { modelCapabilitiesClient } from "../../api/modelCapabilitiesClient"; import type { CreatePreviewTaskInput } from "../../api/webGenerationGateway"; import WorkspacePageShell from "../../components/WorkspacePageShell"; import type { @@ -118,7 +119,7 @@ import { defaultVideoModel, image4kCapableModels, imageFocusRatioOptions, - imageModelOptions, + imageModelOptions as fallbackCanvasImageModelOptions, imageRatioOptions, textModelOptions, videoDurationOptions, @@ -354,6 +355,8 @@ function CanvasPage({ const [projectNameEditing, setProjectNameEditing] = useState(false); const [videoNodeMenu, setVideoNodeMenu] = useState<{ left: number; top: number; nodeId: string } | null>(null); const [videoNodes, setVideoNodes] = useState([]); + const [canvasImageModelOptions, setCanvasImageModelOptions] = useState(fallbackCanvasImageModelOptions); + const [canvasVideoModelOptions, setCanvasVideoModelOptions] = useState(canvasEnterpriseVideoModelOptions); const [selectedNode, setSelectedNode] = useState(null); const [selectedNodes, setSelectedNodes] = useState([]); const [selectionContextMenu, setSelectionContextMenu] = useState(null); @@ -458,9 +461,39 @@ function CanvasPage({ callbacksRef: dragCallbacksRef, suppressNextPaneClickRef, }); + + useEffect(() => { + let cancelled = false; + + if (!isAuthenticated) { + setCanvasImageModelOptions(fallbackCanvasImageModelOptions); + setCanvasVideoModelOptions(canvasEnterpriseVideoModelOptions); + return () => { + cancelled = true; + }; + } + + modelCapabilitiesClient + .get() + .then((capabilities) => { + if (cancelled) return; + setCanvasImageModelOptions(capabilities.imageModels.length ? capabilities.imageModels : fallbackCanvasImageModelOptions); + setCanvasVideoModelOptions(capabilities.videoModels.length ? capabilities.videoModels : canvasEnterpriseVideoModelOptions); + }) + .catch(() => { + if (cancelled) return; + setCanvasImageModelOptions(fallbackCanvasImageModelOptions); + setCanvasVideoModelOptions(canvasEnterpriseVideoModelOptions); + }); + + return () => { + cancelled = true; + }; + }, [isAuthenticated]); + const visibleImageModelOptions = useMemo( - () => filterImageModelOptionsForSession(imageModelOptions, session), - [session], + () => filterImageModelOptionsForSession(canvasImageModelOptions, session), + [canvasImageModelOptions, session], ); const fallbackVisibleImageModel = visibleImageModelOptions[0]?.value || defaultImageModel; const resolveVisibleImageModel = useCallback( @@ -5044,7 +5077,7 @@ function CanvasPage({ ariaLabel="选择视频模型" className="canvas-select-chip--model studio-canvas-composer-chip" value={toHappyHorseDisplayModel(videoNode.model || defaultVideoModel)} - options={canvasEnterpriseVideoModelOptions} + options={canvasVideoModelOptions} open={canvasSelectMenu === `${videoNode.id}:video-model`} onToggle={() => setCanvasSelectMenu((current) => diff --git a/src/features/ecommerce/ecommerceVideoService.ts b/src/features/ecommerce/ecommerceVideoService.ts index e06f82e..9f104ad 100644 --- a/src/features/ecommerce/ecommerceVideoService.ts +++ b/src/features/ecommerce/ecommerceVideoService.ts @@ -9,6 +9,7 @@ import { type AdVideoUserConfig, } from "../../api/adVideoPlanClient"; import { aiGenerationClient } from "../../api/aiGenerationClient"; +import { serverRequest } from "../../api/serverConnection"; import { waitForTask } from "../../api/taskSubscription"; import { resolveVideoRequestModel } from "../../utils/resolveVideoModel"; import { normalizeEcommerceImageMime } from "./ecommerceImageValidation"; @@ -430,15 +431,6 @@ export interface VideoHistoryListResponse { offset: number; } -import { getStoredToken } from "../../api/serverConnection"; - -const API_BASE = "/api/ai/ecommerce/video-history"; - -function getAuthHeaders(): Record { - const token = getStoredToken(); - return token ? { Authorization: `Bearer ${token}` } : {}; -} - export async function buildDurableVideoHistoryPayload(payload: SaveVideoHistoryPayload): Promise { const uploadAssetByUrl = payload.uploadAssetByUrl; const scenes = await Promise.all( @@ -486,13 +478,12 @@ export async function buildDurableVideoHistoryPayload(payload: SaveVideoHistoryP export async function saveVideoHistory(payload: SaveVideoHistoryPayload): Promise<{ id: number; createdAt: string }> { const { uploadAssetByUrl: _uploadAssetByUrl, ...historyPayload } = await buildDurableVideoHistoryPayload(payload); - const res = await fetch(API_BASE, { + return serverRequest<{ id: number; createdAt: string }>("ai/ecommerce/video-history", { method: "POST", - headers: { "Content-Type": "application/json", ...getAuthHeaders() }, - body: JSON.stringify(historyPayload), + body: historyPayload, + maxRetries: 0, + fallbackMessage: "Failed to save video history", }); - if (!res.ok) throw new Error("Failed to save video history"); - return res.json(); } function removeTemporaryHistoryUrls(item: VideoHistoryItem): VideoHistoryItem { @@ -511,12 +502,10 @@ export async function fetchVideoHistory( limit = 20, offset = 0, ): Promise { - const res = await fetch( - `${API_BASE}?limit=${limit}&offset=${offset}`, - { headers: getAuthHeaders() }, - ); - if (!res.ok) throw new Error("Failed to fetch video history"); - const history = (await res.json()) as VideoHistoryListResponse; + const search = new URLSearchParams({ limit: String(limit), offset: String(offset) }); + const history = await serverRequest(`ai/ecommerce/video-history?${search}`, { + fallbackMessage: "Failed to fetch video history", + }); return { ...history, items: history.items.map(removeTemporaryHistoryUrls), @@ -524,9 +513,9 @@ export async function fetchVideoHistory( } export async function deleteVideoHistory(id: number): Promise { - const res = await fetch(`${API_BASE}/${id}`, { + await serverRequest(`ai/ecommerce/video-history/${id}`, { method: "DELETE", - headers: getAuthHeaders(), + maxRetries: 0, + fallbackMessage: "Failed to delete video history", }); - if (!res.ok) throw new Error("Failed to delete video history"); } diff --git a/src/features/image-workbench/ImageWorkbenchPage.tsx b/src/features/image-workbench/ImageWorkbenchPage.tsx index eed76d3..a7f084f 100644 --- a/src/features/image-workbench/ImageWorkbenchPage.tsx +++ b/src/features/image-workbench/ImageWorkbenchPage.tsx @@ -152,6 +152,7 @@ function ImageWorkbenchPage({ initialTool = "workbench", onOpenMore, onSelectVie abortRef.current = false; taskIdRef.current = saved.taskId; void waitForTask(saved.taskId, { + kind: "image", onProgress: (e) => { setStatus(`${e.status} / ${e.progress}%`); if (e.status === "completed" && e.resultUrl) { @@ -446,6 +447,7 @@ function ImageWorkbenchPage({ initialTool = "workbench", onOpenMore, onSelectVie const pollTaskUntilDone = useCallback(async (taskId: string): Promise => { return waitForTask(taskId, { + kind: "image", abortRef, onProgress: (e) => setGenerationProgress(e.progress || 0), }); @@ -559,7 +561,7 @@ function ImageWorkbenchPage({ initialTool = "workbench", onOpenMore, onSelectVie referenceUrls: refUrls, }); taskIdRef.current = taskId; - saveToolTaskState("imagewb", { taskId, status: "running", progress: 0 }); + saveToolTaskState("imagewb", { taskId, status: "running", progress: 0 }); const tempUrl = await pollTaskUntilDone(taskId); if (tempUrl) { diff --git a/src/features/workbench/WorkbenchPage.tsx b/src/features/workbench/WorkbenchPage.tsx index f40ff7e..8800fa4 100644 --- a/src/features/workbench/WorkbenchPage.tsx +++ b/src/features/workbench/WorkbenchPage.tsx @@ -369,7 +369,7 @@ function WorkbenchPage({ .get() .then((capabilities) => { if (cancelled) return; - const nextVideoModels = VIDEO_MODEL_OPTIONS; + const nextVideoModels = capabilities.videoModels.length ? capabilities.videoModels : VIDEO_MODEL_OPTIONS; applyImageModels(capabilities.imageModels); setVideoModelOptions(nextVideoModels); diff --git a/src/features/workbench/toolKeepalive.ts b/src/features/workbench/toolKeepalive.ts index 40140a6..904a97f 100644 --- a/src/features/workbench/toolKeepalive.ts +++ b/src/features/workbench/toolKeepalive.ts @@ -3,6 +3,8 @@ * Persists task state to localStorage so in-progress tasks survive page switches. */ +import { waitForTask } from "../../api/taskSubscription"; + const KEEPALIVE_PREFIX = "omniai:tool-task:"; interface ToolTaskKeepalive { @@ -59,38 +61,19 @@ export function clearToolTaskState(key: string): void { try { window.localStorage.removeItem(KEEPALIVE_PREFIX + key); } catch { /* ignore */ } } -const TASK_POLL_INTERVAL = 3000; -const TASK_POLL_TIMEOUT = 30 * 60 * 1000; - export async function pollTaskUntilDone( taskId: string, onProgress?: (progress: number) => void, abortRef?: { current: boolean }, + kind: "image" | "video" = "video", ): Promise { - const startTime = Date.now(); - const { aiGenerationClient } = await import("../../api/aiGenerationClient"); - - while (true) { - if (abortRef?.current) return null; - if (Date.now() - startTime > TASK_POLL_TIMEOUT) return null; - - try { - const task = await aiGenerationClient.getTaskStatus(taskId); - if (!task) return null; - - const progress = Math.min(99, task.progress || 0); - onProgress?.(progress); - - if (task.status === "completed") { - return task.resultUrl || null; - } - if (task.status === "failed" || task.status === "cancelled") { - return null; - } - } catch { - // retry on next poll - } - - await new Promise((r) => setTimeout(r, TASK_POLL_INTERVAL)); + try { + return await waitForTask(taskId, { + kind, + abortRef, + onProgress: (event) => onProgress?.(Math.min(99, Number(event.progress || 0))), + }); + } catch { + return null; } } diff --git a/src/services/backgroundTaskRunner.ts b/src/services/backgroundTaskRunner.ts index 15b235f..bdae615 100644 --- a/src/services/backgroundTaskRunner.ts +++ b/src/services/backgroundTaskRunner.ts @@ -1,20 +1,12 @@ import { useGenerationStore, type GenerationQueueItem } from "../stores/useGenerationStore"; -import { aiGenerationClient } from "../api/aiGenerationClient"; -import { - buildLocalTimeoutMessage, - buildTaskFailureInfo, - getTaskTimeoutPolicy, - isTaskLocallyTimedOut, -} from "../utils/taskLifecycle"; +import { waitForTask, type TaskProgressEvent } from "../api/taskSubscription"; +import { buildTaskFailureInfo } from "../utils/taskLifecycle"; type PollCallback = (item: GenerationQueueItem) => void; -const activePollers = new Map>(); +const activePollers = new Map(); const pollCallbacks = new Set(); -const POLL_INTERVAL = 3000; -const MAX_POLL_ATTEMPTS = 200; // Keep the previous 10-minute guard as a fallback. - export function subscribeToTaskUpdates(callback: PollCallback): () => void { pollCallbacks.add(callback); return () => { pollCallbacks.delete(callback); }; @@ -34,109 +26,109 @@ function getQueueItemModel(item: GenerationQueueItem): string | undefined { return typeof item.params?.model === "string" ? item.params.model : undefined; } -function pollTask(item: GenerationQueueItem, attemptsRef: { current: number }): void { - const key = `poll-${item.id}`; - if (activePollers.has(key)) return; - - const kind = getQueueItemKind(item); - const timeoutPolicy = getTaskTimeoutPolicy({ kind, model: getQueueItemModel(item) }); - let lastProgress = Math.max(0, Number(item.progress || 0)); - let lastProgressAt = Date.now(); - - const interval = setInterval(async () => { - const current = useGenerationStore.getState().queue.find((i) => i.id === item.id); - if (!current || current.status === "completed" || current.status === "failed" || current.status === "cancelled") { - cleanupPoll(key); - return; - } - - attemptsRef.current++; - const timeoutReason = isTaskLocallyTimedOut({ - startedAt: current.createdAt || item.createdAt || Date.now(), - lastProgressAt, - progress: lastProgress, - policy: timeoutPolicy, - }); - if (timeoutReason || attemptsRef.current > MAX_POLL_ATTEMPTS) { - const error = buildLocalTimeoutMessage(kind); - useGenerationStore.getState().updateTask(item.id, { - status: "failed", - error, - }); - notifyCallbacks({ ...item, status: "failed", error }); - cleanupPoll(key); - return; - } - - try { - const status = await aiGenerationClient.getTaskStatus(current.taskId || item.taskId || ""); - const nextProgress = Number(status.progress || 0); - if (nextProgress > lastProgress || status.status === "completed") { - lastProgress = Math.max(lastProgress, nextProgress); - lastProgressAt = Date.now(); - } - - const patch: Partial = { - progress: status.progress, - resultUrl: status.resultUrl || current.resultUrl, - error: status.error || current.error, - }; - - if (status.status === "completed") { - patch.status = "completed"; - useGenerationStore.getState().updateTask(item.id, patch); - notifyCallbacks({ ...item, ...patch, status: "completed" }); - cleanupPoll(key); - } else if (status.status === "failed" || status.status === "cancelled") { - patch.status = "failed"; - patch.error = buildTaskFailureInfo(status.error).message; - useGenerationStore.getState().updateTask(item.id, patch); - notifyCallbacks({ ...item, ...patch, status: "failed" }); - cleanupPoll(key); - } else { - patch.status = "running"; - useGenerationStore.getState().updateTask(item.id, patch); - notifyCallbacks({ ...item, ...patch, status: "running" }); - } - } catch { - // Network errors during polling are retried until the lifecycle guard trips. - } - }, POLL_INTERVAL); - - activePollers.set(key, interval); +function updateTaskAndNotify(id: string, patch: Partial): GenerationQueueItem | null { + const current = useGenerationStore.getState().queue.find((i) => i.id === id); + if (!current) return null; + const next = { ...current, ...patch }; + useGenerationStore.getState().updateTask(id, patch); + notifyCallbacks(next); + return next; } -function cleanupPoll(key: string): void { - const interval = activePollers.get(key); - if (interval) { - clearInterval(interval); - activePollers.delete(key); - } +function isTerminalStatus(status: GenerationQueueItem["status"]): boolean { + return status === "completed" || status === "failed" || status === "cancelled"; +} + +function pollTask(item: GenerationQueueItem): void { + const key = `poll-${item.id}`; + if (activePollers.has(key) || !item.taskId) return; + + const kind = getQueueItemKind(item); + const abortRef = { current: false }; + activePollers.set(key, abortRef); + + const applyProgress = (event: TaskProgressEvent) => { + const current = useGenerationStore.getState().queue.find((i) => i.id === item.id); + if (!current || isTerminalStatus(current.status)) { + abortRef.current = true; + return; + } + + const patch: Partial = { + progress: Number(event.progress || 0), + resultUrl: event.resultUrl || current.resultUrl, + error: event.error || current.error, + }; + + if (event.status === "completed") { + patch.status = "completed"; + patch.progress = 100; + } else if (event.status === "failed" || event.status === "cancelled") { + patch.status = "failed"; + patch.error = buildTaskFailureInfo(event.error).message; + } else { + patch.status = "running"; + } + + updateTaskAndNotify(item.id, patch); + }; + + void waitForTask(item.taskId, { + kind, + model: getQueueItemModel(item), + startedAt: item.createdAt || Date.now(), + abortRef, + onProgress: applyProgress, + }) + .then((resultUrl) => { + if (abortRef.current) return; + const current = useGenerationStore.getState().queue.find((i) => i.id === item.id); + if (!current || isTerminalStatus(current.status)) return; + updateTaskAndNotify(item.id, { + status: "completed", + progress: 100, + resultUrl: resultUrl || current.resultUrl, + }); + }) + .catch((error) => { + if (abortRef.current) return; + const failure = buildTaskFailureInfo(error instanceof Error ? error.message : String(error)); + updateTaskAndNotify(item.id, { + status: "failed", + error: failure.message, + }); + }) + .finally(() => { + cleanupPoll(key, abortRef); + }); +} + +function cleanupPoll(key: string, abortRef: { current: boolean }): void { + if (activePollers.get(key) !== abortRef) return; + activePollers.delete(key); } export function startBackgroundPolling(): void { const tasks = useGenerationStore.getState().getRunningTasks(); - const attemptsMap = new Map(); tasks.forEach((task) => { if (task.taskId) { - if (!attemptsMap.has(task.id)) { - attemptsMap.set(task.id, { current: 0 }); - } - pollTask(task, attemptsMap.get(task.id)!); + pollTask(task); } }); } export function resumeTaskPolling(taskId: string, storeId: string): void { const task = useGenerationStore.getState().queue.find((i) => i.id === storeId); - if (task && task.status !== "completed" && task.status !== "failed") { - pollTask(task, { current: 0 }); + if (task && !isTerminalStatus(task.status)) { + pollTask({ ...task, taskId }); } } export function stopAllPolling(): void { - activePollers.forEach((interval) => clearInterval(interval)); + activePollers.forEach((abortRef) => { + abortRef.current = true; + }); activePollers.clear(); }