Files
omniai-server/src/aiTaskWorker.js
T
2026-06-08 15:00:19 +08:00

1083 lines
36 KiB
JavaScript
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"use strict";
const crypto = require("node:crypto");
const { EventEmitter } = require("node:events");
const { pool } = require("./db");
const { refundTaskBillingOnFailure } = require("./billing");
const { putObject, isOssConfigured } = require("./ossClient");
// In-process event bus; task updates are emitted as `task:<taskId>` events.
const taskEvents = new EventEmitter();
taskEvents.setMaxListeners(200);
// taskDbId -> { interval, leaseToken } for every poller running in this process.
const activePollers = new Map();
// Delay between two provider polls of the same task.
const POLL_INTERVAL_MS = 3000;
// Default attempt budget returned by getMaxPollAttempts when no more specific limit applies.
const MAX_POLL_ATTEMPTS = 120;
// Attempt budget for image tasks on the "grsai-image" transport (env override, default 60).
const GRS_IMAGE_MAX_POLL_ATTEMPTS = Number(process.env.GRSAI_IMAGE_MAX_POLL_ATTEMPTS || 60);
// Postgres NOTIFY/LISTEN channel used to fan task events out to other processes.
const TASK_EVENT_CHANNEL = "generation_task_events";
// Tag attached to published events so this process can skip its own notifications.
const TASK_EVENT_ORIGIN = `${process.pid}-${crypto.randomUUID()}`;
// Identity written to generation_task_pollers.owner_id by this process.
const POLLER_OWNER_ID = `${process.pid}-${crypto.randomUUID()}`;
// A poller whose heartbeat is older than this may be claimed by another owner.
const POLLER_OWNER_STALE_MS = Number(process.env.TASK_POLLER_OWNER_STALE_MS || 20_000);
// How often startPollerRecovery re-runs recoverRunnablePollers.
const POLLER_RECOVERY_INTERVAL_MS = Number(process.env.TASK_POLLER_RECOVERY_INTERVAL_MS || 30_000);
// Dedicated pool client currently LISTENing on TASK_EVENT_CHANNEL (null when not listening).
let taskEventListenerClient = null;
// In-flight startTaskEventListener promise, so concurrent starts share one attempt.
let taskEventListenerStarting = null;
// Memoised promise of the CREATE TABLE bootstrap done by ensureTaskPollerStore.
let pollerStoreReady = null;
// Interval handle owned by startPollerRecovery / stopPollerRecovery.
let pollerRecoveryTimer = null;
/**
 * Coerces a progress value into an integer percentage between 0 and 100.
 *
 * @param {*} value - Anything `Number()` can coerce.
 * @returns {number|undefined} The rounded, clamped percentage, or `undefined`
 *   when the value does not coerce to a finite number.
 */
function normalizeTaskProgress(value) {
  const parsed = Number(value);
  if (Number.isFinite(parsed)) {
    const rounded = Math.round(parsed);
    const cappedAtHundred = Math.min(100, rounded);
    return Math.max(0, cappedAtHundred);
  }
  return undefined;
}
/**
 * Maps a `generation_tasks` row onto the event payload sent to subscribers.
 *
 * @param {object} row - Row with `id`, `status`, `progress`, `result_url`, `error`.
 * @returns {{taskId: *, status: *, progress: *, resultUrl: *, error: *}}
 *   Event payload; falsy `result_url` / `error` become `null`.
 */
function formatTaskEvent(row) {
  const { id, status, progress } = row;
  const resultUrl = row.result_url || null;
  const error = row.error || null;
  return { taskId: id, status, progress, resultUrl, error };
}
/**
 * Emits an event on the in-process bus under the name `task:<taskId>`.
 * Events without a truthy `taskId` are ignored.
 *
 * @param {{taskId: *}} event - Task event payload.
 * @returns {void}
 */
function emitTaskEvent(event) {
  const hasTaskId = Boolean(event?.taskId);
  if (hasTaskId) {
    taskEvents.emit(`task:${event.taskId}`, event);
  }
}
/**
 * Delivers a task event locally and then broadcasts it through Postgres
 * NOTIFY on TASK_EVENT_CHANNEL, tagged with this process's origin.
 * Broadcast failures are logged and never thrown.
 *
 * @param {{taskId: *}} event - Task event payload; ignored without a truthy `taskId`.
 * @returns {Promise<void>}
 */
async function publishTaskEvent(event) {
  if (!event?.taskId) {
    return;
  }
  emitTaskEvent(event);
  try {
    const message = JSON.stringify({ origin: TASK_EVENT_ORIGIN, event });
    const params = [TASK_EVENT_CHANNEL, message];
    await pool.query("SELECT pg_notify($1, $2)", params);
  } catch (err) {
    console.error(`[aiTaskWorker] task event publish failed for task ${event.taskId}:`, err.message);
  }
}
/**
 * Produces the subset of a provider config that is stored with a poller.
 * Only the allow-listed keys are copied; keys whose value is `undefined`
 * are omitted.
 *
 * @param {*} providerConfig - Provider configuration object.
 * @returns {object} A new object with the allow-listed keys, or `{}` when the
 *   input is not an object.
 */
function serializeProviderConfig(providerConfig) {
  if (providerConfig === null || typeof providerConfig !== "object") {
    return {};
  }
  const persistedKeys = [
    "provider",
    "transport",
    "protocol",
    "baseUrl",
    "endpoint",
    "resultEndpoint",
    "model",
    "requestedModel",
  ];
  const entries = persistedKeys
    .filter((key) => providerConfig[key] !== undefined)
    .map((key) => [key, providerConfig[key]]);
  return Object.fromEntries(entries);
}
/**
 * Reads a provider config that may be stored as a JSON string or already be
 * an object.
 *
 * @param {*} value - JSON text, an object, or a falsy value.
 * @returns {object} The object as-is, the parsed JSON when it is an object,
 *   otherwise `{}` (falsy input, invalid JSON, or a non-object JSON value).
 */
function parseProviderConfig(value) {
  if (!value) return {};
  if (typeof value === "object") return value;
  let decoded;
  try {
    decoded = JSON.parse(value);
  } catch {
    return {};
  }
  if (decoded && typeof decoded === "object") {
    return decoded;
  }
  return {};
}
/**
 * Ensures the `generation_task_pollers` table and its heartbeat index exist.
 *
 * The bootstrap query is started at most once; its promise is memoised in
 * `pollerStoreReady` and returned to every later caller. If the query fails,
 * the memo is cleared so that the next call retries.
 *
 * @returns {Promise<*>} Resolves when the bootstrap query has completed;
 *   rejects with the query error.
 */
async function ensureTaskPollerStore() {
if (pollerStoreReady) return pollerStoreReady;
pollerStoreReady = pool.query(`
CREATE TABLE IF NOT EXISTS generation_task_pollers (
task_id INTEGER PRIMARY KEY REFERENCES generation_tasks(id) ON DELETE CASCADE,
provider_task_id TEXT NOT NULL,
task_type TEXT NOT NULL,
provider_config_json TEXT NOT NULL,
lease_token TEXT,
owner_id TEXT,
owner_heartbeat_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_generation_task_pollers_owner
ON generation_task_pollers(owner_heartbeat_at);
`).catch((err) => {
// Forget the failed attempt so a later call can try again.
pollerStoreReady = null;
throw err;
});
return pollerStoreReady;
}
/**
 * Upserts the poller row for a task so the poller can be recovered later.
 *
 * On conflict the existing row is overwritten, and in both cases this process
 * (POLLER_OWNER_ID) becomes the owner with a fresh heartbeat.
 *
 * @param {*} taskDbId - Primary key of the task in `generation_tasks`.
 * @param {object} state
 * @param {string} state.providerTaskId - Task id issued by the provider.
 * @param {string} state.type - Task type, stored in `task_type`.
 * @param {object} state.providerConfig - Stored through serializeProviderConfig (allow-listed keys only).
 * @param {string} [state.leaseToken] - Key lease token; stored as NULL when falsy.
 * @returns {Promise<void>}
 */
async function persistPollerState(taskDbId, { providerTaskId, type, providerConfig, leaseToken }) {
await ensureTaskPollerStore();
await pool.query(
`
INSERT INTO generation_task_pollers (
task_id, provider_task_id, task_type, provider_config_json, lease_token,
owner_id, owner_heartbeat_at, updated_at
)
VALUES ($1, $2, $3, $4, $5, $6, NOW(), NOW())
ON CONFLICT (task_id) DO UPDATE SET
provider_task_id = EXCLUDED.provider_task_id,
task_type = EXCLUDED.task_type,
provider_config_json = EXCLUDED.provider_config_json,
lease_token = EXCLUDED.lease_token,
owner_id = EXCLUDED.owner_id,
owner_heartbeat_at = NOW(),
updated_at = NOW()
`,
[
taskDbId,
providerTaskId,
type,
JSON.stringify(serializeProviderConfig(providerConfig)),
leaseToken || null,
POLLER_OWNER_ID,
],
);
}
/**
 * Marks this process as the owner of a task's poller row and bumps its
 * heartbeat and `updated_at` timestamps.
 *
 * @param {*} taskDbId - Primary key of the task.
 * @returns {Promise<void>}
 */
async function refreshPollerHeartbeat(taskDbId) {
  await ensureTaskPollerStore();
  const heartbeatSql =
    "UPDATE generation_task_pollers SET owner_id = $1, owner_heartbeat_at = NOW(), updated_at = NOW() WHERE task_id = $2";
  const params = [POLLER_OWNER_ID, taskDbId];
  await pool.query(heartbeatSql, params);
}
/**
 * Deletes the persisted poller row of a task.
 *
 * @param {*} taskDbId - Primary key of the task.
 * @returns {Promise<void>}
 */
async function clearPollerState(taskDbId) {
  await ensureTaskPollerStore();
  const deleteSql = "DELETE FROM generation_task_pollers WHERE task_id = $1";
  await pool.query(deleteSql, [taskDbId]);
}
/**
 * Looks up the API key behind an active lease.
 *
 * Only leases that are not released (`released_at IS NULL`) and whose key is
 * enabled are considered.
 *
 * @param {string} leaseToken - Lease token to resolve.
 * @returns {Promise<string|null>} The API key; `""` when the stored key is the
 *   literal "pool-slot"; `null` when the token is falsy or no matching row
 *   (or no key value) is found.
 */
async function getLeaseKey(leaseToken) {
if (!leaseToken) return null;
const { rows } = await pool.query(
`
SELECT k.api_key
FROM key_leases l
JOIN api_keys k ON k.id = l.key_id
WHERE l.lease_token = $1
AND l.released_at IS NULL
AND k.enabled = 1
LIMIT 1
`,
[leaseToken],
);
const apiKey = rows[0]?.api_key;
// "pool-slot" is mapped to an empty key rather than "not found".
// NOTE(review): presumably a placeholder for keyless pool slots — confirm with keyManager.
return apiKey === "pool-slot" ? "" : apiKey || null;
}
/**
 * Tries to take ownership of a task's poller row for this process.
 *
 * The UPDATE only matches when the row has no heartbeat or its heartbeat is
 * older than the stale window (POLLER_OWNER_STALE_MS rounded up to seconds,
 * never less than 5 seconds).
 *
 * @param {*} taskId - Primary key of the task.
 * @returns {Promise<object|null>} The updated poller row when the claim
 *   succeeded, otherwise `null`.
 */
async function claimPoller(taskId) {
await ensureTaskPollerStore();
// Passed as text and cast to an interval inside the query.
const staleInterval = `${Math.max(5, Math.ceil(POLLER_OWNER_STALE_MS / 1000))} seconds`;
const { rows } = await pool.query(
`
UPDATE generation_task_pollers
SET owner_id = $1, owner_heartbeat_at = NOW(), updated_at = NOW()
WHERE task_id = $2
AND (
owner_heartbeat_at IS NULL
OR owner_heartbeat_at < NOW() - ($3::text)::interval
)
RETURNING *
`,
[POLLER_OWNER_ID, taskId, staleInterval],
);
return rows[0] || null;
}
/**
 * Inserts a web notification for a task that reached "completed" or "failed".
 *
 * Nothing is inserted when the task lacks `user_id`/`id`, when its status is
 * neither of the two, or when a notification with the same user, type and
 * target already exists (the INSERT is guarded by WHERE NOT EXISTS).
 *
 * @param {object} task - Task row (`id`, `user_id`, `status`, `type`, `error`, `result_url`).
 * @returns {Promise<void>}
 */
async function createTaskLifecycleNotification(task) {
if (!task || !task.user_id || !task.id) return;
const isCompleted = task.status === "completed";
const isFailed = task.status === "failed";
if (!isCompleted && !isFailed) return;
// User-facing label: "video" for video tasks, "image" for everything else.
const typeLabel = task.type === "video" ? "视频" : "图像";
const noticeType = isCompleted ? "task_completed" : "task_failed";
const title = isCompleted ? `${typeLabel}生成已完成` : `${typeLabel}生成失败`;
// Failed tasks show their error text (or a generic message), capped at 500 characters.
const description = isCompleted
? "生成结果已同步到任务历史,可继续编辑或保存到资产库。"
: String(task.error || "任务处理失败,请稍后重试。").slice(0, 500);
await pool.query(
`
INSERT INTO web_notifications (
user_id, type, title, description, target_type, target_id, metadata_json
)
SELECT $1::integer, $2::varchar, $3::varchar, $4::text, 'generation_task', $5::varchar, $6::text
WHERE NOT EXISTS (
SELECT 1
FROM web_notifications
WHERE user_id = $1
AND type = $2::varchar
AND target_type = 'generation_task'
AND target_id = $5::varchar
)
`,
[
task.user_id,
noticeType,
title,
description,
String(task.id),
JSON.stringify({ taskType: task.type, resultUrl: task.result_url || null }),
],
);
}
/**
 * Applies a partial update to a `generation_tasks` row and runs the follow-up
 * side effects.
 *
 * Recognised keys of `updates`: `status`, `progress` (normalised to an integer
 * 0-100; ignored when not numeric), `resultUrl`, `error`, `providerTaskId`.
 * Keys whose value is `undefined` are skipped. When no recognised key yields a
 * column change the function returns without touching the database; the
 * original guard for this case ran after `updated_at` had been queued and
 * therefore never fired.
 *
 * Side effects after a successful update:
 * - the updated row is published as a task event;
 * - a "completed" task with a result URL is copied to OSS in the background;
 * - "completed"/"failed" tasks get a lifecycle notification (errors logged);
 * - "failed" tasks trigger a billing refund (errors logged).
 *
 * @param {*} taskId - Primary key of the task.
 * @param {object} [updates] - Partial update as described above.
 * @returns {Promise<void>}
 */
async function updateTaskInDb(taskId, updates) {
  const nextUpdates = { ...updates };
  const fields = [];
  const values = [];
  const setColumn = (column, value) => {
    values.push(value);
    fields.push(`${column} = $${values.length}`);
  };

  if (nextUpdates.status !== undefined) setColumn("status", nextUpdates.status);
  if (nextUpdates.progress !== undefined) {
    const progress = normalizeTaskProgress(nextUpdates.progress);
    if (progress !== undefined) setColumn("progress", progress);
  }
  if (nextUpdates.resultUrl !== undefined) setColumn("result_url", nextUpdates.resultUrl);
  if (nextUpdates.error !== undefined) setColumn("error", nextUpdates.error);
  if (nextUpdates.providerTaskId !== undefined) setColumn("provider_task_id", nextUpdates.providerTaskId);

  // Nothing to change: check BEFORE the unconditional timestamp columns are added.
  if (fields.length === 0) return;

  const isTerminal = nextUpdates.status === "completed" || nextUpdates.status === "failed";
  if (isTerminal) {
    fields.push("completed_at = NOW()");
  }
  fields.push("updated_at = NOW()");

  values.push(taskId);
  const { rows } = await pool.query(
    `UPDATE generation_tasks SET ${fields.join(", ")} WHERE id = $${values.length} RETURNING *`,
    values,
  );
  const updatedTask = rows[0];
  if (updatedTask) {
    await publishTaskEvent(formatTaskEvent(updatedTask));
  }
  if (nextUpdates.status === "completed" && updatedTask?.result_url) {
    // Fire-and-forget: the copy to OSS must not delay the status update.
    persistTaskResultUrlToOssInBackground(updatedTask);
  }
  if (isTerminal) {
    await createTaskLifecycleNotification(updatedTask).catch((err) => {
      console.error(`[aiTaskWorker] notification error for task ${taskId}:`, err.message);
    });
  }
  if (nextUpdates.status === "failed") {
    await refundTaskBillingOnFailure(taskId).catch((err) => {
      console.error(`[aiTaskWorker] refund error for task ${taskId}:`, err.message);
    });
  }
}
/**
 * Copies a task's result to OSS without blocking the caller, then swaps the
 * stored `result_url` for the durable URL. The swap only applies while the row
 * still holds the URL that was copied. Failures are logged as warnings.
 *
 * @param {object} task - Task row; ignored unless `id` and `result_url` are truthy.
 * @returns {void}
 */
function persistTaskResultUrlToOssInBackground(task) {
  if (!task?.id || !task?.result_url) return;

  const copyAndSwap = async () => {
    const durableUrl = await persistResultUrlToOss(task);
    const unchanged = !durableUrl || durableUrl === task.result_url;
    if (unchanged) return;
    await pool.query(
      "UPDATE generation_tasks SET result_url = $1, updated_at = NOW() WHERE id = $2 AND result_url = $3",
      [durableUrl, task.id, task.result_url],
    );
    console.info(`[aiTaskWorker] task ${task.id} result persisted to OSS after completion`);
  };
  const logFailure = (error) => {
    console.warn(`[aiTaskWorker] background result persistence failed for task ${task.id}:`, error.message);
  };

  Promise.resolve().then(copyAndSwap).catch(logFailure);
}
/**
 * Returns the value when it is a non-null, non-array object.
 *
 * @param {*} value - Any value.
 * @returns {object|undefined} The same object, or `undefined` otherwise.
 */
function asObject(value) {
  if (value === null || typeof value !== "object") return undefined;
  if (Array.isArray(value)) return undefined;
  return value;
}
/**
 * Returns the trimmed value when it is a string with non-whitespace content.
 *
 * @param {*} value - Any value.
 * @returns {string|undefined} The trimmed string, or `undefined` for
 *   non-strings and blank strings.
 */
function asString(value) {
  if (typeof value !== "string") return undefined;
  const trimmed = value.trim();
  return trimmed === "" ? undefined : trimmed;
}
/**
 * Returns the first argument that is a non-blank string, trimmed.
 *
 * @param {...*} values - Candidates, checked in order with asString.
 * @returns {string|undefined} The first usable string, or `undefined`.
 */
function firstString(...values) {
  for (let index = 0; index < values.length; index += 1) {
    const candidate = asString(values[index]);
    if (candidate) {
      return candidate;
    }
  }
  return undefined;
}
/**
 * Chooses a file extension for a downloaded result.
 *
 * Resolution order: a known MIME type (parameters such as `; charset=` are
 * ignored), then a 2-5 character alphanumeric extension at the end of the
 * fallback URL's path, then a default based on the task type.
 *
 * @param {*} contentType - Content-Type header value.
 * @param {string} fallbackUrl - URL whose path may end in an extension.
 * @param {string} taskType - "video" selects the "mp4" default, anything else "png".
 * @returns {string} Lower-case extension without the leading dot.
 */
function mediaExtensionFromContentType(contentType, fallbackUrl, taskType) {
  const extensionByMime = new Map([
    ["image/jpeg", "jpg"],
    ["image/png", "png"],
    ["image/webp", "webp"],
    ["image/gif", "gif"],
    ["video/webm", "webm"],
    ["video/quicktime", "mov"],
    ["video/mp4", "mp4"],
  ]);
  const mime = String(contentType || "").split(";")[0].trim().toLowerCase();
  const known = extensionByMime.get(mime);
  if (known) return known;

  let fromPath;
  try {
    fromPath = new URL(fallbackUrl).pathname.match(/\.([a-z0-9]{2,5})$/i)?.[1];
  } catch {
    // Not a parseable URL: use the task-type default below.
  }
  if (fromPath) return fromPath.toLowerCase();

  return taskType === "video" ? "mp4" : "png";
}
/**
 * Tells whether a Content-Type looks like a text/structured document (JSON,
 * XML, HTML, plain text, or any `+xml` type) rather than binary media.
 *
 * @param {*} contentType - Content-Type header value.
 * @returns {boolean} `true` when the type matches one of those document types.
 */
function isErrorDocumentContentType(contentType) {
  const headerValue = String(contentType || "");
  const documentTypePattern = /(?:application|text)\/(?:json|xml|html|plain)|\+xml/i;
  return documentTypePattern.test(headerValue);
}
/**
 * Tells whether a URL contains the `/users/<id>/generation-results/` path
 * segment that persistResultUrlToOss uses for its object keys.
 *
 * @param {*} value - URL or any value; coerced to a string.
 * @returns {boolean} `true` when the path pattern is present.
 */
function isOwnPersistedResultUrl(value) {
  const candidate = String(value || "");
  const ownResultPath = /\/users\/[^/]+\/generation-results\//i;
  return ownResultPath.test(candidate);
}
/**
 * Copies a provider result URL to OSS on behalf of a task, looking the task's
 * owner and type up in the database first.
 *
 * @param {*} taskId - Primary key of the task.
 * @param {*} resultUrl - Provider URL; must be http(s) and not already an own
 *   persisted URL.
 * @returns {Promise<string|null>} The durable URL, or `null` when the URL is
 *   not eligible, OSS is not configured, the task row lacks `user_id`/`type`,
 *   or the lookup fails (logged as a warning).
 */
async function persistIncomingResultUrl(taskId, resultUrl) {
  const normalizedUrl = String(resultUrl || "").trim();
  const isHttpUrl = /^https?:\/\//i.test(normalizedUrl);
  if (!isHttpUrl || isOwnPersistedResultUrl(normalizedUrl) || !isOssConfigured()) {
    return null;
  }
  try {
    const lookup = await pool.query(
      "SELECT id, user_id, type, result_url FROM generation_tasks WHERE id = $1",
      [taskId],
    );
    const [task] = lookup.rows;
    const isUsable = Boolean(task?.user_id && task?.type);
    if (!isUsable) return null;
    return persistResultUrlToOss({ ...task, result_url: normalizedUrl });
  } catch (error) {
    console.warn(`[aiTaskWorker] result pre-persistence skipped for task ${taskId}:`, error.message);
    return null;
  }
}
/**
 * Downloads a task's result and uploads it to OSS as a public-read object
 * under `users/<userId>/generation-results/`.
 *
 * @param {object} task - Task with `id`, `user_id`, `type` and `result_url`.
 * @returns {Promise<string|null>} URL of the uploaded object, or `null` when
 *   the URL is not http(s), already points at an own persisted result, OSS is
 *   not configured, or the download/upload fails (logged as a warning).
 */
async function persistResultUrlToOss(task) {
  const sourceUrl = String(task?.result_url || "").trim();
  const isHttpUrl = /^https?:\/\//i.test(sourceUrl);
  if (!isHttpUrl || isOwnPersistedResultUrl(sourceUrl) || !isOssConfigured()) {
    return null;
  }
  try {
    const response = await fetch(sourceUrl, { method: "GET" });
    if (!response.ok) {
      throw new Error(`result fetch returned ${response.status}`);
    }
    const defaultContentType = task.type === "video" ? "video/mp4" : "image/png";
    const contentType = response.headers.get("content-type") || defaultContentType;
    if (isErrorDocumentContentType(contentType)) {
      // A JSON/HTML/XML body under a 2xx status is treated as a failed download.
      const body = await response.text().catch(() => "");
      throw new Error(`result fetch returned error document: ${body.slice(0, 120)}`);
    }
    const bytes = Buffer.from(await response.arrayBuffer());
    if (bytes.length === 0) {
      throw new Error("result fetch returned empty content");
    }
    // Keep only characters that are safe inside an object key.
    const safeUserId = String(task.user_id).replace(/[^a-zA-Z0-9_-]/g, "");
    const extension = mediaExtensionFromContentType(contentType, sourceUrl, task.type);
    const objectKey = `users/${safeUserId}/generation-results/${task.id}-${Date.now()}-${crypto.randomUUID()}.${extension}`;
    const { url } = await putObject(objectKey, bytes, contentType, { "x-oss-object-acl": "public-read" });
    return url;
  } catch (error) {
    console.warn(`[aiTaskWorker] result persistence skipped for task ${task?.id}:`, error.message);
    return null;
  }
}
/**
 * Turns a single string-like value into an image reference, if it holds one.
 *
 * Checks, in order: an http(s)/protocol-relative URL or `data:image/` URI
 * (returned as-is), a markdown image link, any http(s) URL embedded in the
 * text, a JSON document containing an image result, and finally a bare base64
 * string longer than 128 characters (wrapped as a PNG data URI).
 *
 * @param {*} value - Candidate value; only non-blank strings are considered.
 * @returns {string|undefined} The image URL / data URI, or `undefined`.
 */
function normalizeImageResultValue(value) {
  const text = firstString(value);
  if (!text) return undefined;

  const isUrl = /^(https?:)?\/\//i.test(text);
  if (isUrl || /^data:image\//i.test(text)) return text;

  const markdownUrl = text.match(/!\[[^\]]*]\((https?:\/\/[^)\s]+)\)/i)?.[1];
  if (markdownUrl) return markdownUrl;

  const embeddedUrl = text.match(/https?:\/\/[^\s"'<>)]+/i)?.[0];
  if (embeddedUrl) return embeddedUrl;

  try {
    const fromJson = firstImageResult(JSON.parse(text));
    if (fromJson) return fromJson;
  } catch {
    // Not JSON (or nothing usable inside): fall through to the base64 check.
  }

  const looksLikeBase64 = text.length > 128 && /^[A-Za-z0-9+/]+={0,2}$/.test(text);
  return looksLikeBase64 ? `data:image/png;base64,${text}` : undefined;
}
/**
 * Searches the arguments, in order, for the first image reference.
 *
 * Arrays are searched element by element; objects are searched through a
 * fixed list of well-known fields (in the order listed below); everything else
 * is handed to normalizeImageResultValue.
 *
 * @param {...*} values - Candidates of any shape.
 * @returns {string|undefined} The first image URL / data URI found.
 */
function firstImageResult(...values) {
  for (const value of values) {
    let found;
    if (Array.isArray(value)) {
      for (const entry of value) {
        found = firstImageResult(entry);
        if (found) break;
      }
    } else if (value && typeof value === "object") {
      found = firstImageResult(
        value.url,
        value.image_url,
        value.imageUrl,
        value.result_url,
        value.resultUrl,
        value.output_url,
        value.outputUrl,
        value.b64_image,
        value.b64_json,
        value.base64,
        value.image,
        value.content,
        value.text,
        value.message,
      );
    } else {
      found = normalizeImageResultValue(value);
    }
    if (found) return found;
  }
  return undefined;
}
/**
 * Depth-first search for the first non-blank string stored under one of the
 * given keys. Keys are tried on the current object before its object-valued
 * children are searched.
 *
 * @param {*} value - Object or array to search; other values yield `undefined`.
 * @param {string[]} keys - Property names to look for, in priority order.
 * @returns {string|undefined} The first matching string (trimmed).
 */
function pickDeep(value, keys) {
  if (!value || typeof value !== "object") return undefined;

  for (const key of keys) {
    const direct = firstString(value[key]);
    if (direct) return direct;
  }

  for (const child of Object.values(value)) {
    if (!child || typeof child !== "object") continue;
    const nested = pickDeep(child, keys);
    if (nested) return nested;
  }
  return undefined;
}
/**
 * Normalises a provider status to a trimmed, lower-case string.
 *
 * @param {*} value - Raw status; falsy values become `""`.
 * @returns {string} Normalised status.
 */
function normalizeStatus(value) {
  const raw = value ? String(value) : "";
  return raw.trim().toLowerCase();
}
/**
 * Tells whether a normalised status is one of the known "finished
 * successfully" spellings.
 *
 * @param {string} status - Status as returned by normalizeStatus.
 * @returns {boolean}
 */
function isCompletedStatus(status) {
  const completedStatuses = new Set([
    "completed",
    "complete",
    "succeeded",
    "success",
    "succeed",
    "done",
    "finished",
    "successed",
  ]);
  return completedStatuses.has(status);
}
/**
 * Tells whether a normalised status is one of the known "ended without a
 * result" spellings (failure, cancellation, expiry, violation).
 *
 * @param {string} status - Status as returned by normalizeStatus.
 * @returns {boolean}
 */
function isFailedStatus(status) {
  const failedStatuses = new Set([
    "failed",
    "failure",
    "fail",
    "canceled",
    "cancelled",
    "expired",
    "error",
    "violation",
  ]);
  return failedStatuses.has(status);
}
/**
 * Finds the provider's task id in a task-creation response.
 *
 * Looks in `output`, then `data` (or the response itself when `data` is not
 * an object), then the top level. Only non-blank string ids are accepted.
 *
 * @param {object} json - Parsed provider response.
 * @returns {string|undefined} The task id, if any.
 */
function extractProviderTaskId(json) {
  const payload = asObject(json?.data) || json;
  const providerOutput = asObject(json?.output) || asObject(payload?.output);
  const candidates = [
    providerOutput?.task_id,
    providerOutput?.taskId,
    payload?.task_id,
    payload?.taskId,
    payload?.id,
    json?.task_id,
    json?.taskId,
    json?.id,
  ];
  return firstString(...candidates);
}
/**
 * Extracts the first image reference (URL or data URI) from a provider
 * response of unknown shape.
 *
 * The response is probed through `data`, `output`, `result`, chat-style
 * `choices[0].message.content`, several `results`/`images`/`urls` arrays, a
 * list of well-known URL fields, and finally a deep key search. The order of
 * the arguments passed to firstImageResult at the end is the precedence.
 *
 * @param {object} json - Parsed provider response.
 * @returns {string|undefined} The image URL / data URI, if one was found.
 */
function extractImageUrl(json) {
const rawData = json?.data;
// `data` falls back to the whole response when `json.data` is not a plain object.
const data = asObject(rawData) || json;
const rawOutput = json?.output ?? data?.output;
const output = asObject(rawOutput) || asObject(data?.output);
const rawResult = data?.result ?? json?.result ?? output?.result;
const result = asObject(rawResult);
// Chat-completion style payloads: take the first available `choices` array.
const choices =
(Array.isArray(output?.choices) && output.choices) ||
(Array.isArray(data?.choices) && data.choices) ||
(Array.isArray(json?.choices) && json.choices) ||
[];
const firstChoice = asObject(choices[0]);
const message = asObject(firstChoice?.message);
const content = Array.isArray(message?.content) ? message.content : [];
const firstContent = asObject(content[0]);
// List-shaped result containers; anything that is not an array is ignored.
const outputResults = Array.isArray(output?.results) ? output.results : [];
const topLevelResults = Array.isArray(json?.results) ? json.results : [];
const dataResults = Array.isArray(data?.results) ? data.results : [];
const resultResults = Array.isArray(result?.results) ? result.results : [];
const dataImages = Array.isArray(data?.images) ? data.images : [];
const dataImageUrls = Array.isArray(data?.image_urls) ? data.image_urls : [];
const dataUrls = Array.isArray(data?.urls) ? data.urls : [];
const outputImages = Array.isArray(output?.images) ? output.images : [];
const outputImageUrls = Array.isArray(output?.image_urls) ? output.image_urls : [];
const resultImages = Array.isArray(result?.images) ? result.images : [];
const resultUrls = Array.isArray(result?.urls) ? result.urls : [];
const candidates = [
...topLevelResults,
...dataResults,
...outputResults,
...resultResults,
...dataImages,
...dataImageUrls,
...dataUrls,
...outputImages,
...outputImageUrls,
...resultImages,
...resultUrls,
];
// The raw containers come first, so their own well-known fields win over the
// narrower lookups below; the deep key search is the last resort.
return firstImageResult(
rawData,
rawOutput,
rawResult,
firstContent?.image,
firstContent?.image_url,
firstContent?.image_url?.url,
message?.content,
firstChoice?.delta?.content,
candidates,
data?.image_url,
data?.imageUrl,
data?.result_url,
data?.resultUrl,
data?.output_url,
data?.outputUrl,
output?.image_url,
output?.imageUrl,
output?.result_url,
output?.resultUrl,
output?.output_url,
output?.outputUrl,
result?.image_url,
result?.imageUrl,
result?.result_url,
result?.resultUrl,
result?.output_url,
result?.outputUrl,
pickDeep(json, ["image", "image_url", "imageUrl", "result_url", "resultUrl", "output_url", "outputUrl", "url", "b64_image", "b64_json", "base64"]),
);
}
/**
 * Extracts an image from a Gemini-style response
 * (`candidates[].content.parts[]`).
 *
 * Inline base64 data anywhere in the response takes precedence and is
 * returned as a data URI (MIME type defaults to image/png). Otherwise the
 * first text part that starts with an http(s) URL containing an image file
 * extension is returned.
 *
 * @param {object} json - Parsed Gemini response.
 * @returns {string|null} Data URI or URL, or `null` when nothing matches.
 */
function extractGeminiImageUrl(json) {
  const candidates = Array.isArray(json?.candidates) ? json.candidates : [];
  // All parts of all candidates, in response order.
  const parts = candidates.flatMap((candidate) => {
    const candidateParts = candidate?.content?.parts;
    return Array.isArray(candidateParts) ? candidateParts : [];
  });

  const inlinePart = parts.find((part) => part?.inlineData?.data);
  if (inlinePart) {
    const { inlineData } = inlinePart;
    const mimeType = inlineData.mimeType || "image/png";
    return `data:${mimeType};base64,${inlineData.data}`;
  }

  const urlPart = parts.find(
    (part) => part?.text && /^https?:\/\/.+\.(png|jpg|jpeg|webp|gif)/i.test(part.text),
  );
  return urlPart ? urlPart.text : null;
}
/**
 * Extracts the first video URL from a provider response of unknown shape.
 *
 * Well-known fields of `output`, `data` and `task_result.videos[0]` are tried
 * first; a deep key search over the whole response is the last resort.
 *
 * @param {object} json - Parsed provider response.
 * @returns {string|undefined} The video URL, if one was found.
 */
function extractVideoUrl(json) {
  const payload = asObject(json?.data) || json;
  const providerOutput = asObject(json?.output) || asObject(payload?.output);
  const taskResult =
    asObject(payload?.task_result) || asObject(json?.task_result) || asObject(providerOutput?.task_result);
  const videoList = Array.isArray(taskResult?.videos) ? taskResult.videos : [];
  const leadVideo = asObject(videoList[0]);
  const deepMatch = pickDeep(json, [
    "video_url",
    "output_video_url",
    "outputVideoUrl",
    "watermark_video_url",
    "videoUrl",
    "download_url",
    "downloadUrl",
    "content_url",
    "contentUrl",
    "url",
  ]);
  const candidates = [
    providerOutput?.video_url,
    providerOutput?.output_video_url,
    providerOutput?.outputVideoUrl,
    providerOutput?.watermark_video_url,
    payload?.video_url,
    payload?.videoUrl,
    payload?.output_video_url,
    payload?.outputVideoUrl,
    payload?.url,
    payload?.result?.url,
    leadVideo?.url,
    leadVideo?.video_url,
    deepMatch,
  ];
  return firstString(...candidates);
}
/**
 * Picks the most specific error text from a provider response.
 *
 * @param {object} json - Parsed provider response.
 * @param {string} fallback - Returned when no non-blank string is found.
 * @returns {string} Error text from `output`, `data`, `error` or the top
 *   level (in that order), or the fallback.
 */
function extractErrorMessage(json, fallback) {
  const payload = asObject(json?.data) || {};
  const providerOutput = asObject(json?.output) || {};
  const errorInfo = asObject(json?.error) || {};
  const message = firstString(
    providerOutput.message,
    providerOutput.code,
    payload.task_status_msg,
    payload.failure_reason,
    payload.message,
    errorInfo.message,
    errorInfo.error,
    json?.message,
    json?.error,
  );
  if (message) {
    return message;
  }
  return fallback;
}
/**
 * Performs a GET request and parses the body as JSON.
 *
 * @param {string} url - Request URL.
 * @param {object} headers - Request headers.
 * @returns {Promise<{ok: boolean, json: *}>} `{ ok: false, json: null }` for a
 *   non-2xx response, otherwise `{ ok: true, json }`. Network and JSON parse
 *   errors propagate to the caller.
 */
async function fetchJson(url, headers) {
  const response = await fetch(url, { method: "GET", headers });
  if (response.ok) {
    const json = await response.json();
    return { ok: true, json };
  }
  return { ok: false, json: null };
}
/**
 * Polls a GRSAI-style image task once and maps the response onto a task update.
 *
 * Decision order:
 * 1. a failed status always yields "failed" — the URL extractor is very
 *    permissive (it even scans message text), so a URL found in a failure
 *    response must not turn the task into a success;
 * 2. any extracted image yields "completed";
 * 3. a completed status without an image yields "failed";
 * 4. otherwise the task is still "running" with a synthetic progress value.
 *
 * @param {*} _taskId - Local task id, used for logging only.
 * @param {string} providerTaskId - Task id issued by the provider.
 * @param {string} apiKey - Bearer token for the provider.
 * @param {string} baseUrl - Provider base URL.
 * @param {string} resultEndpoint - Path of the result endpoint.
 * @returns {Promise<{status: string, progress?: number, resultUrl?: string, error?: string}>}
 *   Update for updateTaskInDb. A non-2xx response is reported as "running".
 */
async function pollGrsaiImage(_taskId, providerTaskId, apiKey, baseUrl, resultEndpoint) {
  const url = `${baseUrl}${resultEndpoint}?id=${encodeURIComponent(providerTaskId)}`;
  const { ok, json } = await fetchJson(url, {
    Authorization: `Bearer ${apiKey}`,
    Accept: "application/json",
  });
  if (!ok) {
    console.warn(`[grsai-poll] task ${_taskId} fetch not ok, url=${url}`);
    return { status: "running", progress: 50 };
  }
  const data = asObject(json?.data) || json;
  const output = asObject(data?.output) || asObject(json?.output) || {};
  // `data` is null when the body is the JSON literal null, hence the optional chaining.
  const status = normalizeStatus(
    output.task_status ||
      output.status ||
      data?.task_status ||
      data?.status ||
      data?.state ||
      json?.task_status ||
      json?.status ||
      json?.state,
  );
  const resultUrl = extractImageUrl(json);
  console.info(`[grsai-poll] task ${_taskId} status=${status} resultUrl=${resultUrl ? "yes" : "no"} raw=${JSON.stringify(json).slice(0, 300)}`);

  if (isFailedStatus(status)) {
    return { status: "failed", error: extractErrorMessage(json, "Image generation failed") };
  }
  if (resultUrl) {
    return { status: "completed", progress: 100, resultUrl };
  }
  if (isCompletedStatus(status)) {
    // The provider says it is done but returned nothing usable.
    return { status: "failed", error: "Image generation completed without a result url" };
  }
  return { status: "running", progress: Math.min(90, 30 + Math.random() * 40) };
}
/**
 * Polls a DashScope image task once and maps the response onto a task update.
 *
 * @param {*} _taskId - Local task id (unused).
 * @param {string} providerTaskId - Task id issued by DashScope.
 * @param {string} apiKey - Bearer token.
 * @returns {Promise<{status: string, progress?: number, resultUrl?: string, error?: string}>}
 *   "completed" with a URL, "failed" (provider failure, or completion without
 *   a URL), or "running". A non-2xx response is reported as "running".
 */
async function pollDashscopeImage(_taskId, providerTaskId, apiKey) {
  const taskUrl = `https://dashscope.aliyuncs.com/api/v1/tasks/${encodeURIComponent(providerTaskId)}`;
  const requestHeaders = {
    Authorization: `Bearer ${apiKey}`,
    Accept: "application/json",
  };
  const response = await fetchJson(taskUrl, requestHeaders);
  if (!response.ok) {
    return { status: "running", progress: 50 };
  }
  const { json } = response;
  const providerOutput = asObject(json?.output) || {};
  const status = normalizeStatus(providerOutput.task_status || json?.task_status);
  const resultUrl = extractImageUrl(json);

  if (isCompletedStatus(status)) {
    return resultUrl
      ? { status: "completed", progress: 100, resultUrl }
      : { status: "failed", error: "DashScope image generation completed without a result url" };
  }
  if (isFailedStatus(status)) {
    return { status: "failed", error: extractErrorMessage(json, "DashScope image generation failed") };
  }
  return { status: "running", progress: Math.min(90, 30 + Math.random() * 40) };
}
/**
 * Encodes input as unpadded, URL-safe base64 (`-` and `_` instead of `+` and
 * `/`, no trailing `=`).
 *
 * @param {string|Buffer|Uint8Array} input - Data accepted by `Buffer.from`.
 * @returns {string} The base64url text.
 */
function base64Url(input) {
  const bytes = Buffer.from(input);
  return bytes.toString("base64url");
}
/**
 * Splits a Kling credential of the form `accessKey:secretKey` at the first
 * colon.
 *
 * @param {*} apiKey - Raw credential.
 * @returns {{accessKey: string, secretKey: string}|null} Both parts trimmed,
 *   or `null` when there is no colon, the colon is the first character, or
 *   either part is blank.
 */
function parseKlingCredential(apiKey) {
  const raw = String(apiKey || "");
  const separatorAt = raw.indexOf(":");
  if (separatorAt < 1) return null;

  const accessKey = raw.substring(0, separatorAt).trim();
  const secretKey = raw.substring(separatorAt + 1).trim();
  if (!accessKey || !secretKey) return null;
  return { accessKey, secretKey };
}
/**
 * Builds an HS256-signed JWT for Kling.
 *
 * The token's issuer is the access key; it expires 1800 seconds after
 * creation and is valid from 5 seconds before creation.
 *
 * @param {string} accessKey - Used as the `iss` claim.
 * @param {string} secretKey - HMAC-SHA256 signing key.
 * @returns {string} The compact `header.payload.signature` token.
 */
function createKlingJwt(accessKey, secretKey) {
  const issuedAt = Math.floor(Date.now() / 1000);
  const encodedHeader = base64Url(JSON.stringify({ alg: "HS256", typ: "JWT" }));
  const encodedPayload = base64Url(
    JSON.stringify({ iss: accessKey, exp: issuedAt + 1800, nbf: issuedAt - 5 }),
  );
  const signingInput = `${encodedHeader}.${encodedPayload}`;
  const rawSignature = crypto.createHmac("sha256", secretKey).update(signingInput).digest();
  return `${signingInput}.${base64Url(rawSignature)}`;
}
/**
 * Builds the URL and headers used to poll a video task, based on the
 * provider protocol.
 *
 * - "wan-i2v", "wan-s2v", "kling-dashscope" and any "happyhorse-*" protocol
 *   use `<baseUrl>/api/v1/tasks/<id>`;
 * - "kling-omni" uses `<baseUrl>/v1/videos/omni-video/<id>` and signs a JWT
 *   when the key has the `accessKey:secretKey` form (otherwise the key is sent
 *   as-is);
 * - everything else uses `<baseUrl><endpoint>/<id>`.
 *
 * @param {string} providerTaskId - Task id issued by the provider.
 * @param {string} apiKey - API key or Kling credential.
 * @param {object} providerConfig - Reads `protocol`, `baseUrl`, `endpoint`.
 * @returns {{url: string, headers: {Authorization: string, Accept: string}}}
 */
function getPollRequest(providerTaskId, apiKey, providerConfig) {
  const { protocol } = providerConfig;
  const baseUrl = providerConfig.baseUrl || "";
  const encodedId = encodeURIComponent(providerTaskId);
  const bearerHeaders = (token) => ({ Authorization: `Bearer ${token}`, Accept: "application/json" });

  const tasksApiProtocols = ["wan-i2v", "wan-s2v", "kling-dashscope"];
  const usesTasksApi =
    tasksApiProtocols.includes(protocol) || String(protocol || "").startsWith("happyhorse-");
  if (usesTasksApi) {
    return { url: `${baseUrl}/api/v1/tasks/${encodedId}`, headers: bearerHeaders(apiKey) };
  }

  if (protocol === "kling-omni") {
    const credential = parseKlingCredential(apiKey);
    const token = credential ? createKlingJwt(credential.accessKey, credential.secretKey) : apiKey;
    return { url: `${baseUrl}/v1/videos/omni-video/${encodedId}`, headers: bearerHeaders(token) };
  }

  return { url: `${baseUrl}${providerConfig.endpoint}/${encodedId}`, headers: bearerHeaders(apiKey) };
}
/**
 * Polls a video task once and maps the response onto a task update.
 *
 * A failed status is evaluated before anything else: the URL extractor ends
 * with a deep search for generic keys such as `url`, so a failure response
 * that merely contains some URL must not be reported as a finished video.
 *
 * @param {*} _taskId - Local task id (unused).
 * @param {string} providerTaskId - Task id issued by the provider.
 * @param {string} apiKey - API key or Kling credential.
 * @param {object} providerConfig - Provider settings, see getPollRequest.
 * @returns {Promise<{status: string, progress?: number, resultUrl?: string|null, error?: string}>}
 *   "failed", "completed" (with `resultUrl: null` when the provider reports
 *   completion without a URL), or "running" with the provider's progress
 *   capped at 95 (a synthetic value when none is reported). A non-2xx
 *   response is reported as "running".
 */
async function pollVideoTask(_taskId, providerTaskId, apiKey, providerConfig) {
  const { url, headers } = getPollRequest(providerTaskId, apiKey, providerConfig);
  const { ok, json } = await fetchJson(url, headers);
  if (!ok) return { status: "running", progress: 50 };

  const data = asObject(json?.data) || json;
  const output = asObject(json?.output) || {};
  // `data` is null when the body is the JSON literal null, hence the optional chaining.
  const status = normalizeStatus(
    output.task_status ||
      data?.task_status ||
      data?.status ||
      json?.task_status ||
      json?.status,
  );

  if (isFailedStatus(status)) {
    return { status: "failed", error: extractErrorMessage(json, "Video generation failed") };
  }

  const resultUrl = extractVideoUrl(json);
  if (isCompletedStatus(status) || resultUrl) {
    return { status: "completed", progress: 100, resultUrl: resultUrl || null };
  }

  const progress = Number(data?.progress || output.progress);
  return {
    status: "running",
    progress: Number.isFinite(progress) ? Math.min(95, progress) : Math.min(90, 30 + Math.random() * 30),
  };
}
/**
 * Returns how many poll ticks a task may use before it is timed out.
 *
 * @param {string} type - Task type ("image", "video", ...).
 * @param {object} [providerConfig] - Reads `transport`.
 * @returns {number} The configured GRSAI image limit (40 when the configured
 *   value is not a positive finite number) for image tasks on the
 *   "grsai-image" transport, 400 for video tasks, MAX_POLL_ATTEMPTS otherwise.
 */
function getMaxPollAttempts(type, providerConfig) {
  const isGrsaiImage = type === "image" && providerConfig?.transport === "grsai-image";
  if (isGrsaiImage) {
    const isValidLimit = Number.isFinite(GRS_IMAGE_MAX_POLL_ATTEMPTS) && GRS_IMAGE_MAX_POLL_ATTEMPTS > 0;
    return isValidLimit ? Math.trunc(GRS_IMAGE_MAX_POLL_ATTEMPTS) : 40;
  }
  return type === "video" ? 400 : MAX_POLL_ATTEMPTS;
}
/**
 * Starts polling the provider for a task every POLL_INTERVAL_MS until the task
 * finishes, fails, is cancelled, or runs out of attempts.
 *
 * Differences from a plain `setInterval(async ...)` loop:
 * - a tick is skipped while the previous one is still awaiting, so a slow
 *   provider/database can never cause two ticks to finalise the same task
 *   (double lease release, double refund, duplicate events); skipped ticks
 *   still count towards the attempt budget so the timeout stays time-based;
 * - every error of a tick — including the timeout path — is caught and logged
 *   instead of surfacing as an unhandled promise rejection.
 *
 * Does nothing when the task already has an active poller in this process.
 *
 * @param {*} taskDbId - Primary key of the task.
 * @param {object} options
 * @param {string} options.providerTaskId - Task id issued by the provider.
 * @param {string} options.apiKey - Provider credential.
 * @param {string} options.type - "image" selects an image poller, anything else the video poller.
 * @param {object} options.providerConfig - Provider settings (`transport`, `baseUrl`, `resultEndpoint`, ...).
 * @param {string} [options.leaseToken] - Key lease released when polling ends.
 * @param {object} [options.keyManager] - Object exposing `releaseKey(leaseToken)`.
 * @param {Function} [options.onTaskFailed] - Async fallback called with the
 *   failure reason; a truthy result means the failure was handled and the task
 *   must not be marked failed here.
 * @param {boolean} [options.skipPersist=false] - Skip writing the poller row
 *   (used when recovering an already persisted poller).
 * @returns {void}
 */
function startPolling(taskDbId, { providerTaskId, apiKey, type, providerConfig, leaseToken, keyManager, onTaskFailed, skipPersist = false }) {
  if (activePollers.has(taskDbId)) return;
  if (!skipPersist) {
    persistPollerState(taskDbId, { providerTaskId, type, providerConfig, leaseToken }).catch((err) => {
      console.error(`[aiTaskWorker] failed to persist poller state for task ${taskDbId}:`, err.message);
    });
  }

  const maxPollAttempts = getMaxPollAttempts(type, providerConfig);
  let attempts = 0;
  let tickInFlight = false;
  let interval = null;

  // Stops the timer and forgets the in-process poller entry.
  const stopTimer = () => {
    clearInterval(interval);
    activePollers.delete(taskDbId);
  };

  // Releases the key lease, ignoring release errors (best effort).
  const releaseLease = async () => {
    if (leaseToken && keyManager) await keyManager.releaseKey(leaseToken).catch(() => {});
  };

  // Runs the caller's failure fallback; resolves truthy when it took over the task.
  const runFallback = async (reason) => {
    if (typeof onTaskFailed !== "function") return false;
    try {
      return await onTaskFailed(reason);
    } catch (fallbackErr) {
      console.error(`[aiTaskWorker] fallback error for task ${taskDbId}:`, fallbackErr.message);
      return false;
    }
  };

  // Asks the provider for the current state of the task.
  const pollProvider = () => {
    if (type !== "image") {
      return pollVideoTask(taskDbId, providerTaskId, apiKey, providerConfig);
    }
    if (providerConfig.transport === "dashscope-image") {
      return pollDashscopeImage(taskDbId, providerTaskId, apiKey);
    }
    return pollGrsaiImage(taskDbId, providerTaskId, apiKey, providerConfig.baseUrl, providerConfig.resultEndpoint || "/result");
  };

  // One poll cycle: timeout check, cancellation check, provider poll, persistence.
  const runTick = async () => {
    if (attempts > maxPollAttempts) {
      stopTimer();
      await releaseLease();
      if (await runFallback("Task timed out")) return;
      await updateTaskInDb(taskDbId, { status: "failed", error: "Task timed out" });
      await clearPollerState(taskDbId).catch(() => {});
      return;
    }

    // Stop quietly when the task was deleted or cancelled by the user.
    const { rows: [taskRow] } = await pool.query("SELECT status FROM generation_tasks WHERE id = $1", [taskDbId]);
    if (!taskRow || taskRow.status === "cancelled") {
      stopTimer();
      await clearPollerState(taskDbId).catch(() => {});
      await releaseLease();
      return;
    }

    await refreshPollerHeartbeat(taskDbId).catch(() => {});
    const result = await pollProvider();
    const isTerminal = result.status === "completed" || result.status === "failed";
    if (isTerminal) {
      stopTimer();
      await releaseLease();
      if (result.status === "failed" && (await runFallback(result.error || "Task failed"))) return;
    }
    await updateTaskInDb(taskDbId, result);
    if (isTerminal) {
      await clearPollerState(taskDbId).catch(() => {});
    }
  };

  interval = setInterval(() => {
    attempts++;
    // Previous tick is still awaiting: skip instead of running concurrently.
    if (tickInFlight) return;
    tickInFlight = true;
    runTick()
      .catch((err) => {
        console.error(`[aiTaskWorker] poll error for task ${taskDbId}:`, err.message);
      })
      .finally(() => {
        tickInFlight = false;
      });
  }, POLL_INTERVAL_MS);
  activePollers.set(taskDbId, { interval, leaseToken });
}
/**
 * Stops this process's poller for a task (if any) and deletes the persisted
 * poller row. Errors while deleting the row are ignored.
 *
 * @param {*} taskDbId - Primary key of the task.
 * @returns {void}
 */
function stopPolling(taskDbId) {
  const activeEntry = activePollers.get(taskDbId);
  if (activeEntry) {
    clearInterval(activeEntry.interval);
    activePollers.delete(taskDbId);
  }
  const ignoreError = () => {};
  clearPollerState(taskDbId).catch(ignoreError);
}
/**
 * Reports how many pollers are currently running in this process.
 *
 * @returns {number} Number of entries in `activePollers`.
 */
function getActiveCount() {
  const { size } = activePollers;
  return size;
}
/**
 * Adopts pollers whose owner stopped sending heartbeats.
 *
 * Selects up to 20 poller rows of tasks that are still 'pending' or 'running'
 * and whose heartbeat is missing or older than the stale window, claims each
 * one for this process and restarts polling with the persisted settings.
 * Tasks already polled here, claims lost to another owner, and tasks whose
 * key lease can no longer be resolved are skipped.
 *
 * @returns {Promise<void>}
 */
async function recoverRunnablePollers() {
await ensureTaskPollerStore();
// Same stale window as claimPoller: at least 5 seconds, passed as interval text.
const staleInterval = `${Math.max(5, Math.ceil(POLLER_OWNER_STALE_MS / 1000))} seconds`;
const { rows } = await pool.query(
`
SELECT p.task_id
FROM generation_task_pollers p
JOIN generation_tasks t ON t.id = p.task_id
WHERE t.status IN ('pending', 'running')
AND (
p.owner_heartbeat_at IS NULL
OR p.owner_heartbeat_at < NOW() - ($1::text)::interval
)
ORDER BY p.owner_heartbeat_at NULLS FIRST, p.updated_at ASC
LIMIT 20
`,
[staleInterval],
);
for (const row of rows) {
const taskId = row.task_id;
if (activePollers.has(taskId)) continue;
// The claim re-checks staleness in the UPDATE itself; null means it did not match.
const poller = await claimPoller(taskId);
if (!poller || activePollers.has(taskId)) continue;
const apiKey = await getLeaseKey(poller.lease_token);
// `== null` on purpose: an empty-string key (see getLeaseKey) is still usable.
if (apiKey == null) {
console.warn(`[aiTaskWorker] cannot recover task ${taskId}: active lease not found`);
continue;
}
console.info(`[aiTaskWorker] recovering poller for task ${taskId}`);
startPolling(taskId, {
providerTaskId: poller.provider_task_id,
apiKey,
type: poller.task_type,
providerConfig: parseProviderConfig(poller.provider_config_json),
leaseToken: poller.lease_token,
// Required lazily here rather than at the top of the file.
// NOTE(review): presumably to avoid a circular import with keyManager — confirm.
keyManager: require("./keyManager"),
// The row already exists (it was just claimed), so do not rewrite it.
skipPersist: true,
});
}
}
// --- Periodic stale task cleanup ---
// Runs every 5 minutes, marks tasks stuck in 'pending'/'running' for too long as 'failed'.
// This catches cases where the worker crashed, the provider API never responded,
// or the cancel request failed silently on the client side.
const STALE_TASK_CLEANUP_INTERVAL_MS = 5 * 60 * 1000;
// Interval handle owned by startStaleTaskCleanup / stopStaleTaskCleanup.
let staleTaskCleanupTimer = null;
/**
 * Fails every 'pending'/'running' task whose latest activity (the newest of
 * `updated_at` and `last_poll_at`, falling back to `created_at`) is more than
 * 10 minutes old, then publishes a "failed" event, stops the local poller and
 * deletes the poller row for each of them. Errors are logged, never thrown.
 *
 * NOTE(review): unlike the failure path of updateTaskInDb, this does not set
 * `completed_at`, create a notification, refund billing or release the key
 * lease — confirm that this is intended.
 *
 * @returns {Promise<void>}
 */
async function runStaleTaskCleanup() {
try {
const { rows } = await pool.query(
`UPDATE generation_tasks
SET status = 'failed', error = '任务超时自动释放', updated_at = NOW()
WHERE status IN ('pending', 'running')
AND GREATEST(updated_at, COALESCE(last_poll_at, created_at)) < NOW() - INTERVAL '10 minutes'
RETURNING id`,
);
for (const row of rows) {
// Same error text as the one written to the row above.
await publishTaskEvent({
taskId: row.id,
status: "failed",
progress: null,
resultUrl: null,
error: "任务超时自动释放",
});
// Also stop any active poller for this task
const poller = activePollers.get(row.id);
if (poller) {
clearInterval(poller.interval);
activePollers.delete(row.id);
}
await clearPollerState(row.id).catch(() => {});
}
if (rows.length > 0) {
console.log(`[aiTaskWorker] Cleaned up ${rows.length} stale task(s)`);
}
} catch (err) {
console.error("[aiTaskWorker] Stale task cleanup failed:", err.message);
}
}
/**
 * Starts listening for task events published by other processes.
 *
 * Checks a dedicated client out of the pool, subscribes it to
 * TASK_EVENT_CHANNEL and re-emits every notification that did not originate
 * from this process on the local event bus. Concurrent calls share one start
 * attempt; a call while a listener is active is a no-op.
 *
 * Failure handling:
 * - if LISTEN fails, the client is destroyed instead of being leaked;
 * - on a connection error the client is destroyed (not handed back to the
 *   pool) and a restart is scheduled after 5 seconds, unless the listener was
 *   stopped or replaced in the meantime.
 *
 * @returns {Promise<void>} Resolves once the subscription is active; rejects
 *   when connecting or subscribing fails.
 */
async function startTaskEventListener() {
  if (taskEventListenerClient) return;
  if (taskEventListenerStarting) return taskEventListenerStarting;
  taskEventListenerStarting = (async () => {
    const client = await pool.connect();
    let released = false;
    let listening = false;

    // Gives the client up exactly once. A truthy `err` asks the pool to
    // destroy the connection instead of reusing it.
    const releaseClient = (err) => {
      if (released) return;
      released = true;
      // Never clear a reference that belongs to a newer listener.
      if (taskEventListenerClient === client) {
        taskEventListenerClient = null;
      }
      try {
        client.release(err);
      } catch {
        // The client was already released (e.g. by stopTaskEventListener).
      }
    };

    client.on("notification", (message) => {
      if (message.channel !== TASK_EVENT_CHANNEL || !message.payload) return;
      try {
        const payload = JSON.parse(message.payload);
        // Events published by this process were already emitted locally.
        if (payload?.origin === TASK_EVENT_ORIGIN) return;
        emitTaskEvent(payload?.event || payload);
      } catch (err) {
        console.error("[aiTaskWorker] task event notification parse failed:", err.message);
      }
    });

    client.on("error", (err) => {
      console.error("[aiTaskWorker] task event listener error:", err.message);
      // Once subscribed, a different module reference means this listener was
      // stopped or replaced: clean up, but do not resurrect it.
      const superseded = listening && taskEventListenerClient !== client;
      releaseClient(err);
      if (superseded) return;
      setTimeout(() => {
        startTaskEventListener().catch((restartErr) => {
          console.error("[aiTaskWorker] task event listener restart failed:", restartErr.message);
        });
      }, 5000).unref?.();
    });

    try {
      await client.query(`LISTEN ${TASK_EVENT_CHANNEL}`);
    } catch (err) {
      // Without this the checked-out client would never return to the pool.
      releaseClient(err);
      throw err;
    }
    listening = true;
    taskEventListenerClient = client;
    console.log(`[aiTaskWorker] listening for task events on ${TASK_EVENT_CHANNEL}`);
  })();
  try {
    await taskEventListenerStarting;
  } finally {
    taskEventListenerStarting = null;
  }
}
/**
 * Stops the task event listener: clears the module reference, unsubscribes
 * from TASK_EVENT_CHANNEL (errors ignored) and releases the client.
 * Does nothing when no listener is active.
 *
 * @returns {Promise<void>}
 */
async function stopTaskEventListener() {
  const activeClient = taskEventListenerClient;
  taskEventListenerClient = null;
  if (activeClient) {
    try {
      await activeClient.query(`UNLISTEN ${TASK_EVENT_CHANNEL}`);
    } catch {
      // Best effort: the client is released regardless.
    }
    activeClient.release();
  }
}
/**
 * Schedules the stale task cleanup: every STALE_TASK_CLEANUP_INTERVAL_MS, plus
 * one extra run 10 seconds after this call. Does nothing when the schedule is
 * already active.
 *
 * @returns {void}
 */
function startStaleTaskCleanup() {
  if (!staleTaskCleanupTimer) {
    staleTaskCleanupTimer = setInterval(runStaleTaskCleanup, STALE_TASK_CLEANUP_INTERVAL_MS);
    // First run shortly after startup instead of waiting a full interval.
    setTimeout(runStaleTaskCleanup, 10_000);
  }
}
/**
 * Cancels the periodic stale task cleanup, if it is scheduled.
 *
 * @returns {void}
 */
function stopStaleTaskCleanup() {
  if (!staleTaskCleanupTimer) return;
  clearInterval(staleTaskCleanupTimer);
  staleTaskCleanupTimer = null;
}
/**
 * Starts poller recovery: one pass right away (after making sure the poller
 * table exists) and then one pass every POLLER_RECOVERY_INTERVAL_MS. Failures
 * are logged. Does nothing when recovery is already scheduled.
 *
 * @returns {void}
 */
function startPollerRecovery() {
  if (pollerRecoveryTimer) return;

  const runInitialRecovery = async () => {
    await ensureTaskPollerStore();
    await recoverRunnablePollers();
  };
  runInitialRecovery().catch((err) => {
    console.error("[aiTaskWorker] initial poller recovery failed:", err.message);
  });

  const runPeriodicRecovery = () => {
    recoverRunnablePollers().catch((err) => {
      console.error("[aiTaskWorker] poller recovery failed:", err.message);
    });
  };
  pollerRecoveryTimer = setInterval(runPeriodicRecovery, POLLER_RECOVERY_INTERVAL_MS);
}
/**
 * Cancels the periodic poller recovery, if it is scheduled.
 *
 * @returns {void}
 */
function stopPollerRecovery() {
  if (!pollerRecoveryTimer) return;
  clearInterval(pollerRecoveryTimer);
  pollerRecoveryTimer = null;
}
// Public API of the worker: poller control, task persistence, response
// extraction helpers, Kling credential helpers, the task event bus, and the
// start/stop switches of the background services.
module.exports = {
startPolling,
stopPolling,
updateTaskInDb,
getActiveCount,
extractProviderTaskId,
extractImageUrl,
extractGeminiImageUrl,
extractVideoUrl,
parseKlingCredential,
createKlingJwt,
taskEvents,
startTaskEventListener,
stopTaskEventListener,
startPollerRecovery,
stopPollerRecovery,
startStaleTaskCleanup,
stopStaleTaskCleanup,
};