From 2574dfe3d7912fedbd4894da462f8b32c10c5c6f Mon Sep 17 00:00:00 2001 From: Stringadmin Date: Mon, 8 Jun 2026 20:57:40 +0800 Subject: [PATCH] fix: harden generation task recovery --- src/api/aiGenerationClient.ts | 73 +++++++++++++++- src/api/generationConcurrency.ts | 4 + src/features/canvas/CanvasPage.tsx | 20 +++-- src/features/canvas/useCanvasGeneration.ts | 39 ++++++++- src/features/workbench/WorkbenchPage.tsx | 96 ++++++++++++++++++++-- 5 files changed, 212 insertions(+), 20 deletions(-) diff --git a/src/api/aiGenerationClient.ts b/src/api/aiGenerationClient.ts index 5e78c26..6fdfe29 100644 --- a/src/api/aiGenerationClient.ts +++ b/src/api/aiGenerationClient.ts @@ -4,6 +4,7 @@ import { isRecord, readJsonResponse, serverRequest, + isServerRequestError, throwResponseError, } from "./serverConnection"; import { isOptionalApiRouteMissing } from "./apiErrorUtils"; @@ -247,6 +248,46 @@ let taskHistoryRouteMissing = false; const TASK_SUBMIT_TIMEOUT_MS = 90_000; const TASK_STATUS_TIMEOUT_MS = 20_000; const NON_RETRYING_REQUEST = { maxRetries: 0 }; +const PENDING_CANCEL_TASKS_KEY = "omniai:pending-task-cancellations"; + +function readPendingCancelTaskIds(): string[] { + if (typeof window === "undefined") return []; + try { + const raw = window.localStorage.getItem(PENDING_CANCEL_TASKS_KEY); + const parsed = raw ? JSON.parse(raw) : []; + return Array.isArray(parsed) ? parsed.filter((id): id is string => typeof id === "string" && id.trim().length > 0) : []; + } catch { + return []; + } +} + +function writePendingCancelTaskIds(taskIds: string[]): void { + if (typeof window === "undefined") return; + try { + const uniqueIds = Array.from(new Set(taskIds.filter(Boolean))); + if (uniqueIds.length) { + window.localStorage.setItem(PENDING_CANCEL_TASKS_KEY, JSON.stringify(uniqueIds)); + } else { + window.localStorage.removeItem(PENDING_CANCEL_TASKS_KEY); + } + } catch { + // Pending cancellation recovery is best-effort. + } +} + +function markTaskCancelPending(taskId: string): void { + writePendingCancelTaskIds([...readPendingCancelTaskIds(), taskId]); +} + +function clearPendingTaskCancel(taskId: string): void { + writePendingCancelTaskIds(readPendingCancelTaskIds().filter((id) => id !== taskId)); +} + +function shouldRetryTaskCancel(error: unknown): boolean { + if (!isServerRequestError(error)) return true; + const status = error.status; + return status === 429 || status === undefined || status >= 500; +} export const aiGenerationClient = { async createImageTask(input: ImageGenInput): Promise { @@ -335,18 +376,48 @@ export const aiGenerationClient = { }, async cancelTask(taskId: string): Promise { + markTaskCancelPending(taskId); try { await serverRequest(`ai/tasks/${taskId}/cancel`, { method: "PATCH", maxRetries: NON_RETRYING_REQUEST.maxRetries, fallbackMessage: "Task cancel failed", }); + clearPendingTaskCancel(taskId); } catch (error) { - if (isOptionalApiRouteMissing(error)) return; + if (isOptionalApiRouteMissing(error) || !shouldRetryTaskCancel(error)) { + clearPendingTaskCancel(taskId); + return; + } throw error; } }, + cancelTaskOnUnload(taskId: string): void { + markTaskCancelPending(taskId); + const url = buildApiUrl(`ai/tasks/${encodeURIComponent(taskId)}/cancel`); + const headers = buildAuthHeaders(); + const body = JSON.stringify({ reason: "page_unload" }); + + try { + void fetch(url, { + method: "PATCH", + headers, + body, + credentials: "include", + keepalive: true, + }); + } catch { + // Page unload cancellation is best-effort. + } + }, + + flushPendingTaskCancellations(): void { + readPendingCancelTaskIds().forEach((taskId) => { + this.cancelTask(taskId).catch(() => {}); + }); + }, + async getTaskStatus(taskId: string): Promise { return serverRequest(`ai/tasks/${taskId}`, { timeoutMs: TASK_STATUS_TIMEOUT_MS, diff --git a/src/api/generationConcurrency.ts b/src/api/generationConcurrency.ts index a4e3b56..9b01996 100644 --- a/src/api/generationConcurrency.ts +++ b/src/api/generationConcurrency.ts @@ -21,6 +21,10 @@ function getEffectiveLimit(): number { return userMaxConcurrency ?? DEFAULT_MAX_ACTIVE_GENERATION_TASKS; } +export function getEffectiveGenerationLimit(): number { + return getEffectiveLimit(); +} + export function getGenerationUserKey(userId?: string | number | null): string { return userId === undefined || userId === null || userId === "" ? "anonymous" : String(userId); } diff --git a/src/features/canvas/CanvasPage.tsx b/src/features/canvas/CanvasPage.tsx index 50ee0b2..0e699ba 100644 --- a/src/features/canvas/CanvasPage.tsx +++ b/src/features/canvas/CanvasPage.tsx @@ -396,7 +396,6 @@ function CanvasPage({ const canvasUploadInputRef = useRef(null); const imageNodeInputRef = useRef(null); const canvasRef = useRef(null); - const videoGenerationInFlightRef = useRef(new Set()); const canvasReferenceUploadPromisesRef = useRef(new Map>()); const canvasDragCounterRef = useRef(0); const [isCanvasDragging, setIsCanvasDragging] = useState(false); @@ -417,7 +416,7 @@ function CanvasPage({ const { textGenerationState, imageGenerationState, videoGenerationState, generationToast, setGenerationToast, - imageGenerationInFlightRef, textGenerationInFlightRef, textGenerationAbortControllersRef, + imageGenerationInFlightRef, videoGenerationInFlightRef, textGenerationInFlightRef, textGenerationAbortControllersRef, canvasGenKeepaliveRestoredRef, setTextGenerationStatus, setImageGenerationStatus, setVideoGenerationStatus, restoreKeepaliveTasks, resetGenerationState, @@ -1887,13 +1886,14 @@ function CanvasPage({ setVideoGenerationStatus(nodeId, { status: "submitting", message: "正在提交视频生成", progress: 8 }); setGenerationToast("视频正在生成"); + let task: Awaited> | null = null; try { const referenceUrls = await resolveConnectedImageReferenceUrls("video", nodeId); if (videoNode.videoMode === "img2video" && referenceUrls.length === 0) { throw new Error("图生视频需要先连接至少一个可用的图片节点"); } let requestModel = resolveVideoRequestModel({ model, referenceUrls }); - const task = await onCreateTask({ + task = await onCreateTask({ title: videoNode.title || "视频节点生成", type: "video", prompt: prompt || "根据参考图片生成视频", @@ -1916,10 +1916,12 @@ function CanvasPage({ if (task.status === "completed" && !task.outputUrl) { throw new Error("视频生成任务已完成,但服务器没有返回结果地址,请稍后重试"); } + const taskId = task.id; + addCanvasGenKeepalive(taskId, nodeId, "video", projectId || ""); setVideoGenerationStatus(nodeId, { status: "running", message: "视频生成中", progress: Math.max(18, Number(task.progress || 0)) }); const outputUrl = task.outputUrl || - (await waitForImageTaskResult(task.id, (status) => { + (await waitForVideoTaskResult(taskId, (status) => { const progress = Math.max(18, Math.min(status.status === "completed" ? 100 : 96, Math.trunc(status.progress || 0))); const statusLabel = status.status === "pending" @@ -1932,11 +1934,12 @@ function CanvasPage({ setVideoGenerationStatus(nodeId, { status: "running", message: statusLabel, progress }); })); setVideoGenerationStatus(nodeId, { status: "success", message: "视频生成完成", progress: 100 }); + removeCanvasGenKeepalive(taskId); const immediateAssetRef = createCanvasAssetRefFromGeneratedResult({ url: outputUrl, mediaType: "video/mp4", resultType: "video", - taskId: task.id, + taskId, originalUrl: outputUrl, }); setVideoNodes((currentNodes) => @@ -1947,7 +1950,7 @@ function CanvasPage({ videoUrl: outputUrl, assetRef: immediateAssetRef, taskRef: { - taskId: task.id, + taskId, status: "completed", resultUrl: outputUrl, updatedAt: new Date().toISOString(), @@ -1961,7 +1964,7 @@ function CanvasPage({ url: outputUrl, mediaType: "video/mp4", resultType: "video", - taskId: task.id, + taskId, originalUrl: outputUrl, }); await delay(420); @@ -1974,7 +1977,7 @@ function CanvasPage({ videoUrl: assetRef.url, assetRef, taskRef: { - taskId: task.id, + taskId, status: "completed", resultUrl: assetRef.url, updatedAt: new Date().toISOString(), @@ -1991,6 +1994,7 @@ function CanvasPage({ }); } finally { videoGenerationInFlightRef.current.delete(nodeId); + if (task?.id) removeCanvasGenKeepalive(task.id); } }; diff --git a/src/features/canvas/useCanvasGeneration.ts b/src/features/canvas/useCanvasGeneration.ts index 61c2e85..8af13ba 100644 --- a/src/features/canvas/useCanvasGeneration.ts +++ b/src/features/canvas/useCanvasGeneration.ts @@ -6,6 +6,7 @@ import type { CanvasVideoGenerationState, CanvasVideoNode, } from "./canvasTypes"; +import { aiGenerationClient } from "../../api/aiGenerationClient"; import { createCanvasAssetRefFromGeneratedResult, persistCanvasGeneratedResultAsset } from "./canvasAssetPersistence"; import { waitForImageTaskResult, waitForVideoTaskResult } from "./canvasUtils"; @@ -41,6 +42,13 @@ export function removeCanvasGenKeepalive(taskId: string): void { saveCanvasGenKeepalive(loadCanvasGenKeepalive().filter((e) => e.taskId !== taskId)); } +export function cancelCanvasGenKeepaliveOnUnload(): void { + const entries = loadCanvasGenKeepalive(); + if (!entries.length) return; + entries.forEach((entry) => aiGenerationClient.cancelTaskOnUnload(entry.taskId)); + saveCanvasGenKeepalive([]); +} + export interface UseCanvasGenerationParams { setImageNodes: Dispatch>; setVideoNodes: Dispatch>; @@ -55,6 +63,7 @@ export function useCanvasGeneration(params: UseCanvasGenerationParams) { const [generationToast, setGenerationToast] = useState(null); const imageGenerationInFlightRef = useRef(new Set()); + const videoGenerationInFlightRef = useRef(new Set()); const textGenerationInFlightRef = useRef(new Set()); const textGenerationAbortControllersRef = useRef(new Map()); const canvasGenKeepaliveRestoredRef = useRef(false); @@ -125,7 +134,7 @@ export function useCanvasGeneration(params: UseCanvasGenerationParams) { imageGenerationInFlightRef.current.delete(entry.nodeId); }); } else if (entry.nodeKind === "video") { - imageGenerationInFlightRef.current.add(entry.nodeId); + videoGenerationInFlightRef.current.add(entry.nodeId); setVideoGenerationStatus(entry.nodeId, { status: "running", message: "正在恢复视频生成", progress: 20 }); void waitForVideoTaskResult(entry.taskId, (status) => { const progress = Math.max(18, Math.min(status.status === "completed" ? 100 : 96, Math.trunc(status.progress || 0))); @@ -154,7 +163,7 @@ export function useCanvasGeneration(params: UseCanvasGenerationParams) { removeCanvasGenKeepalive(entry.taskId); setVideoGenerationStatus(entry.nodeId, { status: "error", message: "视频生成失败" }); }).finally(() => { - imageGenerationInFlightRef.current.delete(entry.nodeId); + videoGenerationInFlightRef.current.delete(entry.nodeId); }); } } @@ -165,11 +174,36 @@ export function useCanvasGeneration(params: UseCanvasGenerationParams) { textGenerationAbortControllersRef.current.clear(); textGenerationInFlightRef.current.clear(); imageGenerationInFlightRef.current.clear(); + videoGenerationInFlightRef.current.clear(); setTextGenerationState({}); setImageGenerationState({}); setVideoGenerationState({}); }; + useEffect(() => { + const handlePageHide = () => { + cancelCanvasGenKeepaliveOnUnload(); + textGenerationAbortControllersRef.current.forEach((controller) => controller.abort()); + textGenerationAbortControllersRef.current.clear(); + textGenerationInFlightRef.current.clear(); + imageGenerationInFlightRef.current.clear(); + videoGenerationInFlightRef.current.clear(); + setTextGenerationState({}); + setImageGenerationState({}); + setVideoGenerationState({}); + }; + const handleOnline = () => { + aiGenerationClient.flushPendingTaskCancellations(); + }; + window.addEventListener("pagehide", handlePageHide); + window.addEventListener("online", handleOnline); + aiGenerationClient.flushPendingTaskCancellations(); + return () => { + window.removeEventListener("pagehide", handlePageHide); + window.removeEventListener("online", handleOnline); + }; + }, []); + return { textGenerationState, imageGenerationState, @@ -177,6 +211,7 @@ export function useCanvasGeneration(params: UseCanvasGenerationParams) { generationToast, setGenerationToast, imageGenerationInFlightRef, + videoGenerationInFlightRef, textGenerationInFlightRef, textGenerationAbortControllersRef, canvasGenKeepaliveRestoredRef, diff --git a/src/features/workbench/WorkbenchPage.tsx b/src/features/workbench/WorkbenchPage.tsx index 47b0e1c..e7d3d79 100644 --- a/src/features/workbench/WorkbenchPage.tsx +++ b/src/features/workbench/WorkbenchPage.tsx @@ -37,7 +37,7 @@ import { import "../../styles/pages/workbench.css"; import type { WebGenerationPreviewTask, WebUserSession } from "../../types"; import { aiGenerationClient } from "../../api/aiGenerationClient"; -import { claimGenerationSlot, getActiveGenerationTaskCount, getGenerationUserKey, releaseGenerationSlot } from "../../api/generationConcurrency"; +import { claimGenerationSlot, getActiveGenerationTaskCount, getEffectiveGenerationLimit, getGenerationUserKey, releaseGenerationSlot } from "../../api/generationConcurrency"; import { preUploadReference, resolvePreUploadedUrl } from "../../api/referenceUploadService"; import { assetClient } from "../../api/assetClient"; import { communityClient } from "../../api/communityClient"; @@ -988,6 +988,54 @@ function WorkbenchPage({ persistKeepaliveTasks(rest); }; + const releaseKeepaliveTaskLocally = useCallback((taskId: string, options?: { cancelServer?: boolean }) => { + const task = keepaliveTasksRef.current[taskId]; + taskAbortControllersRef.current.get(taskId)?.abort(); + taskAbortControllersRef.current.delete(taskId); + removeKeepaliveTask(taskId); + if (task && options?.cancelServer) { + aiGenerationClient.cancelTask(task.taskId).catch(() => {}); + } + syncActiveGenerationUi(); + }, [syncActiveGenerationUi]); + + const releaseKeepaliveTaskAfterNetworkLoss = useCallback((task: WorkbenchKeepaliveTask, progress: number) => { + const latestTask = { + ...task, + progress, + statusLabel: "网络中断,已释放提交按钮", + }; + void patchConversationMessage(task.conversationId, task.assistantMessageId, { + status: "failed", + taskProgress: Math.max(progress, 100), + taskStatusLabel: "网络中断", + body: "网络中断,当前任务已停止等待并释放提交按钮。请确认网络恢复后重新提交任务。", + }); + upsertKeepaliveTask(latestTask); + releaseKeepaliveTaskLocally(task.taskId, { cancelServer: true }); + if (activeConversationIdRef.current === task.conversationId) { + setIsGenerating(false); + setGenerationStatus("网络中断,已释放提交按钮"); + setGenerationProgress(0); + } + }, [patchConversationMessage, releaseKeepaliveTaskLocally]); + + const cancelActiveKeepaliveTasksOnPageExit = useCallback(() => { + const tasks = Object.values(keepaliveTasksRef.current); + if (!tasks.length) return; + tasks.forEach((task) => { + taskAbortControllersRef.current.get(task.taskId)?.abort(); + taskAbortControllersRef.current.delete(task.taskId); + releaseGenerationSlot(task.concurrencySlotId); + aiGenerationClient.cancelTaskOnUnload(task.taskId); + }); + keepaliveTasksRef.current = {}; + persistKeepaliveTasks({}); + setIsGenerating(false); + setGenerationStatus("已释放未完成任务"); + setGenerationProgress(0); + }, []); + const runKeepalivePoll = useCallback( (task: WorkbenchKeepaliveTask) => { if (taskAbortControllersRef.current.has(task.taskId)) return; @@ -1014,6 +1062,10 @@ function WorkbenchPage({ if (abortController.signal.aborted) return; if (attempt > 0) await sleep(3000); if (abortController.signal.aborted) return; + if (typeof navigator !== "undefined" && navigator.onLine === false) { + releaseKeepaliveTaskAfterNetworkLoss(task, lastKnownProgress); + return; + } let status; try { @@ -1027,7 +1079,8 @@ function WorkbenchPage({ taskProgress: 100, taskStatusLabel: "任务异常", }); - removeKeepaliveTask(task.taskId); + releaseKeepaliveTaskLocally(task.taskId, { cancelServer: true }); + onRefreshUsage?.(); return; } continue; @@ -1255,6 +1308,24 @@ function WorkbenchPage({ }; }, [runKeepalivePoll]); + useEffect(() => { + const handlePageHide = () => { + cancelActiveKeepaliveTasksOnPageExit(); + }; + const handleOnline = () => { + Object.values(keepaliveTasksRef.current).forEach((task) => runKeepalivePoll(task)); + syncActiveGenerationUi(); + }; + + window.addEventListener("pagehide", handlePageHide); + window.addEventListener("online", handleOnline); + aiGenerationClient.flushPendingTaskCancellations(); + return () => { + window.removeEventListener("pagehide", handlePageHide); + window.removeEventListener("online", handleOnline); + }; + }, [cancelActiveKeepaliveTasksOnPageExit, runKeepalivePoll, syncActiveGenerationUi]); + useEffect(() => { persistPromptHistory(promptHistory); }, [promptHistory]); @@ -1885,7 +1956,7 @@ function WorkbenchPage({ const trimmedPrompt = (promptOverride ?? inputValue).trim(); if (!trimmedPrompt) return; const userKey = getGenerationUserKey(session?.user.id); - if (activeMode !== "chat" && getActiveGenerationTaskCount(userKey) >= 3) return; + if (activeMode !== "chat" && getActiveGenerationTaskCount(userKey) >= getEffectiveGenerationLimit()) return; setReferencePreviewOpen(false); let conversationId = activeConversationIdRef.current ?? activeConversationId; @@ -2364,8 +2435,11 @@ function WorkbenchPage({ setProjectError("仅支持对视频结果进行超分"); return; } - if (getActiveGenerationTaskCount(getGenerationUserKey(session?.user.id)) >= 3) { - setProjectError(`当前已有 ${getActiveGenerationTaskCount(getGenerationUserKey(session?.user.id))} 个任务进行中(上限3个),请等待任一任务完成后再提交新任务`); + const userKey = getGenerationUserKey(session?.user.id); + const activeCount = getActiveGenerationTaskCount(userKey); + const limit = getEffectiveGenerationLimit(); + if (activeCount >= limit) { + setProjectError(`当前已有 ${activeCount} 个任务进行中(上限${limit}个),请等待任一任务完成后再提交新任务`); return; } if (!isAuthenticated) { @@ -2486,8 +2560,11 @@ function WorkbenchPage({ setProjectError("仅支持对图片结果进行超分"); return; } - if (getActiveGenerationTaskCount(getGenerationUserKey(session?.user.id)) >= 3) { - setProjectError(`当前已有 ${getActiveGenerationTaskCount(getGenerationUserKey(session?.user.id))} 个任务进行中(上限3个),请等待任一任务完成后再提交新任务`); + const userKey = getGenerationUserKey(session?.user.id); + const activeCount = getActiveGenerationTaskCount(userKey); + const limit = getEffectiveGenerationLimit(); + if (activeCount >= limit) { + setProjectError(`当前已有 ${activeCount} 个任务进行中(上限${limit}个),请等待任一任务完成后再提交新任务`); return; } if (!isAuthenticated) { @@ -2660,13 +2737,14 @@ function WorkbenchPage({ }; const activeGenerationCount = getActiveGenerationTaskCount(getGenerationUserKey(session?.user.id)); - const generationLimitReached = activeMode !== "chat" && activeGenerationCount >= 3; + const activeGenerationLimit = getEffectiveGenerationLimit(); + const generationLimitReached = activeMode !== "chat" && activeGenerationCount >= activeGenerationLimit; const promptIsEmpty = !inputValue.trim(); const sendDisabled = promptIsEmpty || generationLimitReached; const sendButtonTitle = promptIsEmpty ? "输入内容后可发送" : generationLimitReached - ? `当前已有 ${activeGenerationCount} 个任务进行中,请等待任一任务完成` + ? `当前已有 ${activeGenerationCount} 个任务进行中(上限 ${activeGenerationLimit} 个),请等待任一任务完成` : billingEstimate.title; const suggestedPrompts = [