diff --git a/package.json b/package.json index 157d0ef..040011d 100644 --- a/package.json +++ b/package.json @@ -18,7 +18,8 @@ "test:enterprise-video-pricing": "node scripts/enterpriseVideoPricingContract.test.js", "test:key-manager": "node scripts/keyManagerReleaseContract.test.js", "test:provider-poll-limiter": "node scripts/providerPollLimiterContract.test.js", - "test": "npm run test:community-routes && npm run test:enterprise-video-pricing && npm run test:key-manager && npm run test:provider-poll-limiter" + "test:task-progress-contract": "node scripts/taskProgressContract.test.js", + "test": "npm run test:community-routes && npm run test:enterprise-video-pricing && npm run test:key-manager && npm run test:provider-poll-limiter && npm run test:task-progress-contract" }, "dependencies": { "alipay-sdk": "^4.14.0", diff --git a/scripts/taskProgressContract.test.js b/scripts/taskProgressContract.test.js new file mode 100644 index 0000000..1ebbf60 --- /dev/null +++ b/scripts/taskProgressContract.test.js @@ -0,0 +1,90 @@ +const assert = require("node:assert/strict"); + +const { + DEFAULT_IMAGE_EXPECTED_DURATION_MS, + DEFAULT_VIDEO_EXPECTED_DURATION_MS, + PROGRESS_SOURCE_ESTIMATED, + PROGRESS_SOURCE_REAL, + formatTaskProgressPayload, + getExpectedDurationMs, + parseTaskParams, +} = require("../src/taskProgressContract"); + +const createdAt = "2026-06-10T08:00:00.000Z"; + +{ + const payload = formatTaskProgressPayload({ + id: 101, + status: "completed", + progress: 0, + created_at: createdAt, + params_json: "{}", + result_url: "https://oss.example/result.png", + }); + + assert.equal(payload.taskId, 101); + assert.equal(payload.progress, 100); + assert.equal(payload.progressSource, PROGRESS_SOURCE_REAL); + assert.equal(payload.stage, "\u5b8c\u6210"); + assert.equal(payload.startedAt, createdAt); + assert.equal(payload.resultUrl, "https://oss.example/result.png"); +} + +{ + const payload = formatTaskProgressPayload({ + id: 102, + status: "running", + progress: 43, + progress_source: PROGRESS_SOURCE_REAL, + created_at: createdAt, + params_json: JSON.stringify({ model: "kling-3.0-dashscope", duration: 5 }), + type: "video", + }); + + assert.equal(payload.progress, 43); + assert.equal(payload.progressSource, PROGRESS_SOURCE_REAL); + assert.equal(payload.stage, "\u751f\u6210\u4e2d"); + assert.equal(payload.expectedDurationMs, 300_000); +} + +{ + const payload = formatTaskProgressPayload({ + id: 103, + status: "running", + progress: 0, + created_at: createdAt, + params_json: JSON.stringify({ + requestedModel: "nano-banana-pro", + referenceUrls: ["https://oss.example/a.png", "https://oss.example/b.png"], + }), + type: "image", + }); + + assert.equal(payload.progressSource, PROGRESS_SOURCE_ESTIMATED); + assert.equal(payload.stage, "\u5df2\u63d0\u4ea4"); + assert.equal(payload.expectedDurationMs, 250_000); +} + +{ + const expectedDurationMs = getExpectedDurationMs({ + type: "video", + params_json: JSON.stringify({ model: "kling-3.0-dashscope", duration: 10 }), + }); + + assert.equal(expectedDurationMs, 400_000); +} + +{ + assert.deepEqual(parseTaskParams("{bad json"), {}); + assert.deepEqual(parseTaskParams({ model: "gpt-image-2" }), { model: "gpt-image-2" }); + assert.equal( + getExpectedDurationMs({ type: "image", params_json: "{}" }), + DEFAULT_IMAGE_EXPECTED_DURATION_MS, + ); + assert.equal( + getExpectedDurationMs({ type: "video", params_json: "{}" }), + DEFAULT_VIDEO_EXPECTED_DURATION_MS, + ); +} + +console.log("task progress contract tests passed"); diff --git a/src/aiTaskWorker.js b/src/aiTaskWorker.js index f061ed8..501a9e5 100644 --- a/src/aiTaskWorker.js +++ b/src/aiTaskWorker.js @@ -6,6 +6,12 @@ const { pool } = require("./db"); const { refundTaskBillingOnFailure } = require("./billing"); const { putObject, isOssConfigured } = require("./ossClient"); const { withProviderPollSlot } = require("./providerPollLimiter"); +const { + PROGRESS_SOURCE_ESTIMATED, + PROGRESS_SOURCE_REAL, + formatTaskProgressPayload, + normalizeProgressSource, +} = require("./taskProgressContract"); const taskEvents = new EventEmitter(); taskEvents.setMaxListeners(200); @@ -33,13 +39,7 @@ function normalizeTaskProgress(value) { } function formatTaskEvent(row) { - return { - taskId: row.id, - status: row.status, - progress: row.progress, - resultUrl: row.result_url || null, - error: row.error || null, - }; + return formatTaskProgressPayload(row); } function emitTaskEvent(event) { @@ -261,6 +261,10 @@ async function updateTaskInDb(taskId, updates) { const progress = normalizeTaskProgress(nextUpdates.progress); if (progress !== undefined) { fields.push(`progress = $${idx++}`); values.push(progress); } } + if (nextUpdates.progressSource !== undefined) { + const progressSource = normalizeProgressSource(nextUpdates.progressSource); + if (progressSource) { fields.push(`progress_source = $${idx++}`); values.push(progressSource); } + } if (nextUpdates.resultUrl !== undefined) { fields.push(`result_url = $${idx++}`); values.push(nextUpdates.resultUrl); } if (nextUpdates.error !== undefined) { fields.push(`error = $${idx++}`); values.push(nextUpdates.error); } if (nextUpdates.providerTaskId !== undefined) { fields.push(`provider_task_id = $${idx++}`); values.push(nextUpdates.providerTaskId); } @@ -691,7 +695,7 @@ async function pollGrsaiImage(_taskId, providerTaskId, apiKey, baseUrl, resultEn }); if (!ok) { console.warn(`[grsai-poll] task ${_taskId} fetch not ok, url=${url}`); - return { status: "running", progress: 50 }; + return { status: "running", progressSource: PROGRESS_SOURCE_ESTIMATED }; } const data = asObject(json?.data) || json; @@ -709,17 +713,17 @@ async function pollGrsaiImage(_taskId, providerTaskId, apiKey, baseUrl, resultEn const resultUrl = extractImageUrl(json); console.info(`[grsai-poll] task ${_taskId} status=${status} resultUrl=${resultUrl ? "yes" : "no"} raw=${JSON.stringify(json).slice(0, 300)}`); if (resultUrl) { - return { status: "completed", progress: 100, resultUrl }; + return { status: "completed", progress: 100, progressSource: PROGRESS_SOURCE_REAL, resultUrl }; } if (isCompletedStatus(status)) { const completedUrl = extractImageUrl(json); if (!completedUrl) return { status: "failed", error: "Image generation completed without a result url" }; - return { status: "completed", progress: 100, resultUrl: completedUrl }; + return { status: "completed", progress: 100, progressSource: PROGRESS_SOURCE_REAL, resultUrl: completedUrl }; } if (isFailedStatus(status)) { return { status: "failed", error: extractErrorMessage(json, "Image generation failed") }; } - return { status: "running", progress: Math.min(90, 30 + Math.random() * 40) }; + return { status: "running", progressSource: PROGRESS_SOURCE_ESTIMATED }; } async function pollDashscopeImage(_taskId, providerTaskId, apiKey) { @@ -728,19 +732,19 @@ async function pollDashscopeImage(_taskId, providerTaskId, apiKey) { Authorization: `Bearer ${apiKey}`, Accept: "application/json", }); - if (!ok) return { status: "running", progress: 50 }; + if (!ok) return { status: "running", progressSource: PROGRESS_SOURCE_ESTIMATED }; const output = asObject(json?.output) || {}; const status = normalizeStatus(output.task_status || json?.task_status); const resultUrl = extractImageUrl(json); if (isCompletedStatus(status)) { if (!resultUrl) return { status: "failed", error: "DashScope image generation completed without a result url" }; - return { status: "completed", progress: 100, resultUrl }; + return { status: "completed", progress: 100, progressSource: PROGRESS_SOURCE_REAL, resultUrl }; } if (isFailedStatus(status)) { return { status: "failed", error: extractErrorMessage(json, "DashScope image generation failed") }; } - return { status: "running", progress: Math.min(90, 30 + Math.random() * 40) }; + return { status: "running", progressSource: PROGRESS_SOURCE_ESTIMATED }; } function base64Url(input) { @@ -808,7 +812,7 @@ function getPollRequest(providerTaskId, apiKey, providerConfig) { async function pollVideoTask(_taskId, providerTaskId, apiKey, providerConfig) { const { url, headers } = getPollRequest(providerTaskId, apiKey, providerConfig); const { ok, json } = await fetchJson(url, headers); - if (!ok) return { status: "running", progress: 50 }; + if (!ok) return { status: "running", progressSource: PROGRESS_SOURCE_ESTIMATED }; const data = asObject(json?.data) || json; const output = asObject(json?.output) || {}; @@ -822,13 +826,15 @@ async function pollVideoTask(_taskId, providerTaskId, apiKey, providerConfig) { const resultUrl = extractVideoUrl(json); if (isCompletedStatus(status) || resultUrl) { - return { status: "completed", progress: 100, resultUrl: resultUrl || null }; + return { status: "completed", progress: 100, progressSource: PROGRESS_SOURCE_REAL, resultUrl: resultUrl || null }; } if (isFailedStatus(status)) { return { status: "failed", error: extractErrorMessage(json, "Video generation failed") }; } const progress = Number(data.progress || output.progress); - return { status: "running", progress: Number.isFinite(progress) ? Math.min(95, progress) : Math.min(90, 30 + Math.random() * 30) }; + return Number.isFinite(progress) + ? { status: "running", progress: Math.min(95, progress), progressSource: PROGRESS_SOURCE_REAL } + : { status: "running", progressSource: PROGRESS_SOURCE_ESTIMATED }; } function getMaxPollAttempts(type, providerConfig) { diff --git a/src/dbSetup.js b/src/dbSetup.js index c1ee1c8..dcd5371 100644 --- a/src/dbSetup.js +++ b/src/dbSetup.js @@ -353,6 +353,10 @@ async function migrateGenerationTasksBillingColumns(client) { ); } +async function migrateGenerationTaskProgressContract() { + await addColumnIfMissing("generation_tasks", "progress_source TEXT"); +} + async function ensureModelPriceSeed() { const columns = await getColumnNames("model_prices"); const useMills = columns.includes("input_price_mills"); @@ -519,6 +523,7 @@ async function migrateGenerationTasksSchema(client) { params_json TEXT NOT NULL DEFAULT '{}', result_url VARCHAR(2000), progress INTEGER NOT NULL DEFAULT 0, + progress_source TEXT, error TEXT, dedupe_key VARCHAR(256), source_device_id VARCHAR(128), @@ -959,6 +964,7 @@ async function ensureSchema() { await runMigration("030_generation_tasks_user_status_index", migrateGenerationTasksUserStatusIndex); await runMigration("031_generation_tasks_billing_columns", migrateGenerationTasksBillingColumns); await runMigration("032_ecommerce_video_history", migrateEcommerceVideoHistorySchema); + await runMigration("033_generation_task_progress_contract", migrateGenerationTaskProgressContract); await ensureModelPriceSeed(); } diff --git a/src/routes/ai.js b/src/routes/ai.js index 3666f34..6a9e044 100644 --- a/src/routes/ai.js +++ b/src/routes/ai.js @@ -32,6 +32,10 @@ const { normalizeImageUpscaleFactor, normalizeVideoStyleTransformOptions, } = require("../aiUpscaleHelpers"); +const { + formatTaskProgressPayload, + parseTaskParams, +} = require("../taskProgressContract"); const GRSAI_IMAGE_QUALITY_MODEL_OVERRIDES = new Map([ ["gpt-image-2", "1K"], @@ -433,16 +437,8 @@ function sanitizeUpstreamError(value, fallback = "上游服务暂时不可用, return compact.slice(0, 320); } -function parseTaskParams(value) { - if (!value || typeof value !== "string") return {}; - try { - return JSON.parse(value); - } catch { - return {}; - } -} - function formatAiTaskRow(row) { + const progressPayload = formatTaskProgressPayload(row); return { taskId: String(row.id), projectId: row.project_id, @@ -450,9 +446,13 @@ function formatAiTaskRow(row) { clientQueueId: row.client_queue_id || null, type: row.type, status: row.status, - progress: Number(row.progress || 0), - resultUrl: row.result_url || null, - error: row.error || null, + progress: progressPayload.progress, + progressSource: progressPayload.progressSource, + stage: progressPayload.stage, + startedAt: progressPayload.startedAt, + expectedDurationMs: progressPayload.expectedDurationMs, + resultUrl: progressPayload.resultUrl, + error: progressPayload.error, params: parseTaskParams(row.params_json), createdAt: row.created_at, updatedAt: row.updated_at, @@ -1779,13 +1779,7 @@ function registerAiRoutes(router) { ).catch(() => {}); } - const event = { - taskId: row.id, - status: row.status, - progress: Number(row.progress || 0), - resultUrl: row.result_url || null, - error: row.error || null, - }; + const event = formatTaskProgressPayload(row); emit(event); return { found: true, @@ -1810,13 +1804,7 @@ function registerAiRoutes(router) { }); const row = rows[0]; - const initial = { - taskId: row.id, - status: row.status, - progress: row.progress, - resultUrl: row.result_url || null, - error: row.error || null, - }; + const initial = formatTaskProgressPayload(row); res.write(`data: ${JSON.stringify(initial)}\n\n`); if (["completed", "failed", "cancelled"].includes(row.status)) { diff --git a/src/taskProgressContract.js b/src/taskProgressContract.js new file mode 100644 index 0000000..60b00ee --- /dev/null +++ b/src/taskProgressContract.js @@ -0,0 +1,134 @@ +"use strict"; + +const PROGRESS_SOURCE_REAL = "real"; +const PROGRESS_SOURCE_ESTIMATED = "estimated"; + +const DEFAULT_IMAGE_EXPECTED_DURATION_MS = 120_000; +const DEFAULT_VIDEO_EXPECTED_DURATION_MS = 240_000; +const DEFAULT_SUPER_RESOLUTION_EXPECTED_DURATION_MS = 180_000; + +function parseTaskParams(value) { + if (!value) return {}; + if (typeof value === "object" && !Array.isArray(value)) return value; + if (typeof value !== "string") return {}; + try { + const parsed = JSON.parse(value); + return parsed && typeof parsed === "object" && !Array.isArray(parsed) ? parsed : {}; + } catch { + return {}; + } +} + +function normalizeProgressSource(value) { + const source = String(value || "").trim().toLowerCase(); + if (source === PROGRESS_SOURCE_REAL) return PROGRESS_SOURCE_REAL; + if (source === PROGRESS_SOURCE_ESTIMATED) return PROGRESS_SOURCE_ESTIMATED; + return null; +} + +function inferProgressSource(row) { + const explicit = normalizeProgressSource(row?.progress_source || row?.progressSource); + if (explicit) return explicit; + if (row?.status === "completed") return PROGRESS_SOURCE_REAL; + return PROGRESS_SOURCE_ESTIMATED; +} + +function normalizePositiveNumber(value) { + const numeric = Number(value); + return Number.isFinite(numeric) && numeric > 0 ? numeric : null; +} + +function normalizeProgress(value, status) { + if (status === "completed") return 100; + const numeric = Number(value); + if (!Number.isFinite(numeric)) return 0; + return Math.max(0, Math.min(100, Math.round(numeric))); +} + +function getExpectedImageDurationMs(model, params) { + const normalized = String(model || "").toLowerCase(); + let durationMs = DEFAULT_IMAGE_EXPECTED_DURATION_MS; + + if (normalized.includes("nano-banana-pro")) durationMs = 220_000; + else if (normalized.includes("nano-banana-2")) durationMs = 180_000; + else if (normalized.includes("nano-banana-fast")) durationMs = 90_000; + else if (normalized.includes("wan2.7-image-pro")) durationMs = 180_000; + else if (normalized.includes("wan2.7-image")) durationMs = 120_000; + else if (normalized.includes("gpt-image")) durationMs = 120_000; + + const referenceCount = Array.isArray(params.referenceUrls) ? params.referenceUrls.filter(Boolean).length : 0; + if (referenceCount > 0) durationMs += Math.min(60_000, referenceCount * 15_000); + return durationMs; +} + +function getExpectedVideoDurationMs(model, params) { + const normalized = String(model || "").toLowerCase(); + const seconds = normalizePositiveNumber(params.duration || params.durationSeconds) || 5; + let durationMs = DEFAULT_VIDEO_EXPECTED_DURATION_MS; + + if (normalized.includes("kling")) durationMs = 300_000; + else if (normalized.includes("happyhorse")) durationMs = 240_000; + else if (normalized.includes("wan2.7") || normalized.includes("wanxiang")) durationMs = 240_000; + else if (normalized.includes("vidu") || normalized.includes("pixverse")) durationMs = 240_000; + else if (normalized.includes("aliyun-video-super-resolve") || normalized.includes("video-style-transform")) { + durationMs = DEFAULT_SUPER_RESOLUTION_EXPECTED_DURATION_MS; + } + + if (seconds > 5) durationMs += Math.min(240_000, Math.ceil(seconds - 5) * 20_000); + return durationMs; +} + +function getExpectedDurationMs(rowOrTask) { + const params = parseTaskParams(rowOrTask?.params_json || rowOrTask?.params); + const model = params.requestedModel || params.model || rowOrTask?.model || ""; + + if (params.operation === "image-edit" || params.function || String(model).includes("imageedit")) { + return DEFAULT_SUPER_RESOLUTION_EXPECTED_DURATION_MS; + } + + if (rowOrTask?.type === "video") return getExpectedVideoDurationMs(model, params); + return getExpectedImageDurationMs(model, params); +} + +function deriveTaskStage(row) { + const status = String(row?.status || ""); + if (status === "pending") return "\u6392\u961f\u4e2d"; + if (status === "completed") return "\u5b8c\u6210"; + if (status === "failed") return "\u5931\u8d25"; + if (status === "cancelled") return "\u5df2\u53d6\u6d88"; + if (status !== "running") return "\u5904\u7406\u4e2d"; + + const progress = Number(row?.progress || 0); + if (progress >= 90) return "\u7ed3\u679c\u5904\u7406\u4e2d"; + if (progress >= 15) return "\u751f\u6210\u4e2d"; + return "\u5df2\u63d0\u4ea4"; +} + +function formatTaskProgressPayload(row) { + const progress = normalizeProgress(row.progress, row.status); + return { + taskId: row.id, + status: row.status, + progress, + progressSource: inferProgressSource(row), + stage: deriveTaskStage(row), + startedAt: row.created_at, + expectedDurationMs: getExpectedDurationMs(row), + resultUrl: row.result_url || null, + error: row.error || null, + }; +} + +module.exports = { + DEFAULT_IMAGE_EXPECTED_DURATION_MS, + DEFAULT_SUPER_RESOLUTION_EXPECTED_DURATION_MS, + DEFAULT_VIDEO_EXPECTED_DURATION_MS, + PROGRESS_SOURCE_ESTIMATED, + PROGRESS_SOURCE_REAL, + deriveTaskStage, + formatTaskProgressPayload, + getExpectedDurationMs, + inferProgressSource, + normalizeProgressSource, + parseTaskParams, +};