152 lines
4.8 KiB
TypeScript
152 lines
4.8 KiB
TypeScript
import { useGenerationStore, type GenerationQueueItem } from "../stores/useGenerationStore";
|
|
import { waitForTask, type TaskProgressEvent } from "../api/taskSubscription";
|
|
import { buildTaskFailureInfo } from "../utils/taskLifecycle";
|
|
|
|
/** Listener that receives the merged snapshot of a queue item each time polling updates it. */
type PollCallback = (item: GenerationQueueItem) => void;

/**
 * Registry of in-flight pollers, keyed by `poll-<queue item id>`.
 * Each value is the abort flag object that `pollTask` hands to `waitForTask`
 * and checks in its completion handlers; `stopAllPolling` flips it to `true`.
 */
const activePollers: Map<string, { current: boolean }> = new Map();

/** Subscribers registered through `subscribeToTaskUpdates`. */
const pollCallbacks: Set<PollCallback> = new Set();
|
|
|
|
/**
 * Register `callback` to be invoked with every queue item that polling updates.
 *
 * @param callback Listener to add; registering the same function twice keeps a single entry.
 * @returns A function that removes exactly this `callback` again.
 */
export function subscribeToTaskUpdates(callback: PollCallback): () => void {
  pollCallbacks.add(callback);

  const unsubscribe = (): void => {
    pollCallbacks.delete(callback);
  };
  return unsubscribe;
}
|
|
|
|
/**
 * Deliver an updated queue item to every subscriber in `pollCallbacks`.
 *
 * Each subscriber is isolated. If one throws, the error is logged with
 * `console.error` and the remaining subscribers are still called. Otherwise a
 * faulty listener would propagate out of `updateTaskAndNotify` into the polling
 * promise chain in `pollTask`. There it would be indistinguishable from a real
 * polling failure, even though the store write already succeeded.
 *
 * @param item The merged queue item snapshot to broadcast.
 */
function notifyCallbacks(item: GenerationQueueItem): void {
  for (const callback of pollCallbacks) {
    try {
      callback(item);
    } catch (error: unknown) {
      console.error("[taskPolling] task update subscriber threw", error);
    }
  }
}
|
|
|
|
/**
 * Collapse a queue item's `type` into the media kind passed to `waitForTask`.
 *
 * - "image" maps to "image".
 * - "video" and "ecommerce-video" map to "video".
 * - Any other value maps to "text".
 */
function getQueueItemKind(item: GenerationQueueItem): "image" | "video" | "text" {
  switch (item.type) {
    case "image":
      return "image";
    case "video":
    case "ecommerce-video":
      return "video";
    default:
      return "text";
  }
}
|
|
|
|
/**
 * Extract the model name from a queue item's optional `params.model`.
 *
 * @returns The value when it is a string; `undefined` when `params` is absent,
 *   `model` is absent, or `model` holds a non-string value.
 */
function getQueueItemModel(item: GenerationQueueItem): string | undefined {
  const model = item.params?.model;
  if (typeof model !== "string") {
    return undefined;
  }
  return model;
}
|
|
|
|
/**
 * Write `patch` to queue item `id` in the generation store, then broadcast the
 * merged item to subscribers via `notifyCallbacks`.
 *
 * @param id    Queue item id to update.
 * @param patch Fields to overwrite.
 * @returns The pre-update snapshot with `patch` spread over it, or `null` when
 *   the queue has no item with that id. In the `null` case nothing is written
 *   and no subscriber is called.
 */
function updateTaskAndNotify(id: string, patch: Partial<GenerationQueueItem>): GenerationQueueItem | null {
  const store = useGenerationStore.getState();
  const existing = store.queue.find((entry) => entry.id === id);
  if (existing === undefined) {
    return null;
  }

  const merged = { ...existing, ...patch };
  store.updateTask(id, patch);
  notifyCallbacks(merged);
  return merged;
}
|
|
|
|
/**
 * Whether a queue status is final.
 *
 * @returns `true` for "completed", "failed" and "cancelled"; `false` for every
 *   other status.
 */
function isTerminalStatus(status: GenerationQueueItem["status"]): boolean {
  switch (status) {
    case "completed":
    case "failed":
    case "cancelled":
      return true;
    default:
      return false;
  }
}
|
|
|
|
/**
 * Start a background poll for one queue item's server task. Every progress
 * event and the final outcome of `waitForTask` are mirrored into the generation
 * store (and to subscribers) through `updateTaskAndNotify`.
 *
 * Does nothing when the item has no `taskId`, or when a poller keyed
 * `poll-<item.id>` is already registered in `activePollers`.
 *
 * Once the store shows the item as missing or terminal ("completed" | "failed" |
 * "cancelled"), the following holds:
 * - Further progress events set the shared abort flag instead of writing.
 * - A late resolution or rejection of `waitForTask` never overwrites the
 *   terminal state. The one exception is back-filling a missing `resultUrl`
 *   on an item that is already "completed".
 *
 * A server-side "cancelled" event is recorded locally as "failed".
 *
 * @param item Queue item snapshot; its `taskId`, `type`, `params.model` and
 *   `createdAt` configure the `waitForTask` call.
 */
function pollTask(item: GenerationQueueItem): void {
  const key = `poll-${item.id}`;
  if (activePollers.has(key) || !item.taskId) return;

  const kind = getQueueItemKind(item);
  const abortRef = { current: false };
  activePollers.set(key, abortRef);

  // Always re-read the store: `item` is a snapshot that goes stale as soon as polling starts.
  const readCurrent = () => useGenerationStore.getState().queue.find((i) => i.id === item.id);

  const applyProgress = (event: TaskProgressEvent) => {
    const current = readCurrent();
    if (!current || isTerminalStatus(current.status)) {
      // Item was removed or already settled (e.g. cancelled locally): ask the poll to stop.
      abortRef.current = true;
      return;
    }

    // Keep the last known progress when the event carries no usable number, so a
    // status-only event does not make the progress jump back to 0 (or become NaN).
    const reported = Number(event.progress);
    const patch: Partial<GenerationQueueItem> = {
      progress: Number.isFinite(reported) ? reported : (current.progress ?? 0),
      resultUrl: event.resultUrl || current.resultUrl,
      error: event.error || current.error,
    };

    if (event.status === "completed") {
      patch.status = "completed";
      patch.progress = 100;
    } else if (event.status === "failed" || event.status === "cancelled") {
      // Server-side cancellation is surfaced locally as a failure.
      patch.status = "failed";
      patch.error = buildTaskFailureInfo(event.error).message;
    } else {
      patch.status = "running";
    }

    updateTaskAndNotify(item.id, patch);
  };

  void waitForTask(item.taskId, {
    kind,
    model: getQueueItemModel(item),
    startedAt: item.createdAt || Date.now(),
    abortRef,
    onProgress: applyProgress,
  })
    .then((resultUrl) => {
      if (abortRef.current) return;
      const current = readCurrent();
      if (!current) return;

      if (current.status === "completed") {
        // A "completed" progress event may have settled the item without a URL;
        // back-fill it from the final result instead of silently dropping it.
        if (resultUrl && !current.resultUrl) {
          updateTaskAndNotify(item.id, { resultUrl });
        }
        return;
      }
      if (isTerminalStatus(current.status)) return;

      updateTaskAndNotify(item.id, {
        status: "completed",
        progress: 100,
        resultUrl: resultUrl || current.resultUrl,
      });
    })
    .catch((error: unknown) => {
      if (abortRef.current) return;
      // Never overwrite a settled item. Examples: a local cancel, a completion
      // already recorded by a progress event or by the `.then` above, or an
      // exception thrown after the completion was written.
      const current = readCurrent();
      if (!current || isTerminalStatus(current.status)) return;

      const failure = buildTaskFailureInfo(error instanceof Error ? error.message : String(error));
      updateTaskAndNotify(item.id, {
        status: "failed",
        error: failure.message,
      });
    })
    .finally(() => {
      cleanupPoll(key, abortRef);
    });
}
|
|
|
|
/**
 * Unregister the poller stored under `key`, but only when the registry still
 * holds this exact abort flag object.
 *
 * A different poller registered under the same key is left untouched. For
 * example, `stopAllPolling` may clear the map and a new `pollTask` may then
 * register the same key.
 */
function cleanupPoll(key: string, abortRef: { current: boolean }): void {
  const registered = activePollers.get(key);
  if (registered === abortRef) {
    activePollers.delete(key);
  }
}
|
|
|
|
/**
 * Start a poller for every task returned by the store's `getRunningTasks()`
 * that already has a server-side `taskId`. Tasks without one are skipped.
 * Items that are already being polled are ignored inside `pollTask`.
 */
export function startBackgroundPolling(): void {
  for (const task of useGenerationStore.getState().getRunningTasks()) {
    if (!task.taskId) continue;
    pollTask(task);
  }
}
|
|
|
|
/**
 * Resume polling for queue item `storeId` using server task id `taskId`.
 *
 * Does nothing when the item is not in the store or its status is terminal.
 * The `taskId` is only applied to the snapshot handed to `pollTask`; this
 * function does not write it back to the store.
 */
export function resumeTaskPolling(taskId: string, storeId: string): void {
  const queued = useGenerationStore.getState().queue.find((entry) => entry.id === storeId);
  if (!queued || isTerminalStatus(queued.status)) {
    return;
  }
  pollTask({ ...queued, taskId });
}
|
|
|
|
/**
 * Abort every in-flight poll and empty the poller registry.
 *
 * Each registered abort flag is set to `true`; `pollTask`'s completion handlers
 * check it before writing to the store. All entries are then dropped, so a
 * later `pollTask` call for the same item can register a fresh poller.
 */
export function stopAllPolling(): void {
  for (const abortRef of activePollers.values()) {
    abortRef.current = true;
  }
  activePollers.clear();
}
|
|
|
|
/**
 * Reconcile tasks left in flight by a page reload, then restart polling.
 *
 * For every task returned by the store's `getRunningTasks()`:
 * - With a server `taskId`, its status is reset to "pending".
 * - Without one, nothing can be polled, so it is marked "failed". The error
 *   text tells the user the local slot was released and asks them to resubmit.
 *
 * These writes go directly to the store and do not notify subscribers.
 * `startBackgroundPolling` is then scheduled after 500 ms. When there are no
 * such tasks, the function returns immediately and schedules nothing.
 */
export function recoverAndResumeTasks(): void {
  const inFlight = useGenerationStore.getState().getRunningTasks();
  if (inFlight.length === 0) return;

  for (const task of inFlight) {
    const patch: Partial<GenerationQueueItem> = task.taskId
      ? { status: "pending" }
      : {
          status: "failed",
          error: "页面刷新后任务没有服务端 ID,已释放本地占用,请重新提交。",
        };
    useGenerationStore.getState().updateTask(task.id, patch);
  }

  setTimeout(() => startBackgroundPolling(), 500);
}
|