diff --git a/src/api/publicPricingClient.test.ts b/src/api/publicPricingClient.test.ts new file mode 100644 index 0000000..c65d825 --- /dev/null +++ b/src/api/publicPricingClient.test.ts @@ -0,0 +1,142 @@ +import { describe, expect, it } from "../test/testHarness"; + +import { + normalizeEnterpriseVideoPricingConfig, + normalizePublicModelPrice, + normalizePublicModelPrices, + normalizePublicPricingPayload, +} from "./publicPricingClient"; + +describe("publicPricingClient", () => { + it("normalizes camelCase public model price payloads", () => { + expect( + normalizePublicModelPrice({ + id: 1, + modelKey: "gpt-4o", + displayName: "GPT-4o", + category: "text", + pricingType: "token", + inputPriceMills: 27, + outputPriceMills: 108, + flatPriceMills: null, + currency: "CNY", + enabled: true, + }), + ).toEqual({ + id: 1, + modelKey: "gpt-4o", + displayName: "GPT-4o", + category: "text", + pricingType: "token", + inputPriceMills: 27, + outputPriceMills: 108, + flatPriceMills: null, + currency: "CNY", + enabled: true, + createdAt: undefined, + updatedAt: undefined, + }); + }); + + it("normalizes snake_case public model price payloads inside containers", () => { + expect( + normalizePublicModelPrices({ + prices: [ + { + model_key: "deepseek-chat", + display_name: "DeepSeek Chat", + pricing_type: "token", + input_price_mills: "2", + output_price_mills: "8", + flat_price_mills: "0", + enabled: 1, + }, + { display_name: "missing key" }, + ], + }), + ).toEqual([ + { + id: undefined, + modelKey: "deepseek-chat", + displayName: "DeepSeek Chat", + category: undefined, + pricingType: "token", + inputPriceMills: 2, + outputPriceMills: 8, + flatPriceMills: 0, + currency: "CNY", + enabled: true, + createdAt: undefined, + updatedAt: undefined, + }, + ]); + }); + + it("normalizes public pricing payloads with model prices and enterprise video pricing", () => { + expect( + normalizePublicPricingPayload({ + modelPrices: [ + { + modelKey: "qwen-turbo", + pricingType: "token", + inputPriceMills: 2, + outputPriceMills: 6, + }, + ], + enterpriseVideoPricing: { + currency: "CNY", + creditsPerCny: 100, + billingUnit: "per_second", + defaultResolution: "1080P", + resolutions: ["720P", "1080P"], + rules: [ + { + id: "happyhorse", + modelIncludes: ["happyhorse"], + rates: { "720P": 0.72, "1080P": 1.28 }, + }, + ], + }, + }), + ).toEqual({ + modelPrices: [ + { + id: undefined, + modelKey: "qwen-turbo", + displayName: undefined, + category: undefined, + pricingType: "token", + inputPriceMills: 2, + outputPriceMills: 6, + flatPriceMills: null, + currency: "CNY", + enabled: true, + createdAt: undefined, + updatedAt: undefined, + }, + ], + enterpriseVideoPricing: { + currency: "CNY", + creditsPerCny: 100, + billingUnit: "per_second", + defaultResolution: "1080P", + resolutions: ["720P", "1080P"], + rules: [ + { + id: "happyhorse", + modelIncludes: ["happyhorse"], + rates: { "720P": 0.72, "1080P": 1.28 }, + }, + ], + }, + }); + }); + + it("rejects malformed enterprise video pricing configs", () => { + expect( + normalizeEnterpriseVideoPricingConfig({ + rules: [{ id: "broken", modelIncludes: [], rates: {} }], + }), + ).toEqual(null); + }); +}); diff --git a/src/api/publicPricingClient.ts b/src/api/publicPricingClient.ts new file mode 100644 index 0000000..dc20e3d --- /dev/null +++ b/src/api/publicPricingClient.ts @@ -0,0 +1,236 @@ +import { isOptionalApiRouteMissing } from "./apiErrorUtils"; +import { isRecord, serverRequest } from "./serverConnection"; +import type { EnterpriseVideoPricingConfig, EnterpriseVideoPricingRule } from "../utils/enterpriseVideoPolicy"; + +export interface PublicModelPrice { + id?: number | string; + modelKey: string; + displayName?: string; + category?: string; + pricingType?: string; + inputPriceMills: number | null; + outputPriceMills: number | null; + flatPriceMills: number | null; + currency: string; + enabled: boolean; + createdAt?: string; + updatedAt?: string; +} + +export interface PublicPricingPayload { + modelPrices: PublicModelPrice[]; + enterpriseVideoPricing: EnterpriseVideoPricingConfig | null; +} + +function readString( + record: Record, + keys: string[], +): string | undefined { + for (const key of keys) { + const value = record[key]; + if (typeof value === "string" && value.trim()) return value.trim(); + } + return undefined; +} + +function readNumber( + record: Record, + keys: string[], +): number | null { + for (const key of keys) { + const value = record[key]; + const parsed = + typeof value === "number" + ? value + : typeof value === "string" + ? Number(value) + : NaN; + if (Number.isFinite(parsed)) return parsed; + } + return null; +} + +function readBoolean( + record: Record, + keys: string[], + fallback: boolean, +): boolean { + for (const key of keys) { + const value = record[key]; + if (typeof value === "boolean") return value; + if (typeof value === "number") return value !== 0; + if (typeof value === "string") { + const normalized = value.trim().toLowerCase(); + if (["1", "true", "yes", "enabled"].includes(normalized)) return true; + if (["0", "false", "no", "disabled"].includes(normalized)) return false; + } + } + return fallback; +} + +function readStringArray(record: Record, keys: string[]): string[] { + for (const key of keys) { + const value = record[key]; + if (!Array.isArray(value)) continue; + return value + .map((item) => (typeof item === "string" ? item.trim() : "")) + .filter(Boolean); + } + return []; +} + +function normalizeRateMap(raw: unknown): Record | null { + if (!isRecord(raw)) return null; + const result: Record = {}; + for (const [key, value] of Object.entries(raw)) { + const parsed = typeof value === "number" ? value : typeof value === "string" ? Number(value) : NaN; + if (Number.isFinite(parsed) && parsed >= 0) result[key] = parsed; + } + return Object.keys(result).length ? result : null; +} + +function normalizeEnterpriseVideoPricingRule(raw: unknown): EnterpriseVideoPricingRule | null { + if (!isRecord(raw)) return null; + const id = readString(raw, ["id", "key", "name"]); + const modelIncludes = readStringArray(raw, ["modelIncludes", "model_includes", "modelPatterns", "model_patterns"]); + const rates = normalizeRateMap(raw.rates); + if (!id || modelIncludes.length === 0 || !rates) return null; + + const when = isRecord(raw.when) + ? { + ...(typeof raw.when.muted === "boolean" ? { muted: raw.when.muted } : {}), + ...(typeof raw.when.hasReferenceVideo === "boolean" + ? { hasReferenceVideo: raw.when.hasReferenceVideo } + : {}), + } + : undefined; + + return { + id, + modelIncludes, + ...(when && Object.keys(when).length ? { when } : {}), + rates, + }; +} + +export function normalizePublicModelPrice( + raw: unknown, +): PublicModelPrice | null { + if (!isRecord(raw)) return null; + + const modelKey = readString(raw, ["modelKey", "model_key", "key", "model"]); + if (!modelKey) return null; + + const displayName = readString(raw, ["displayName", "display_name", "name"]); + const category = readString(raw, ["category", "type"]); + const pricingType = readString(raw, ["pricingType", "pricing_type"]); + const currency = readString(raw, ["currency"]) || "CNY"; + const createdAt = readString(raw, ["createdAt", "created_at"]); + const updatedAt = readString(raw, ["updatedAt", "updated_at"]); + const idValue = raw.id; + + return { + id: + typeof idValue === "number" || typeof idValue === "string" + ? idValue + : undefined, + modelKey, + displayName, + category, + pricingType, + inputPriceMills: readNumber(raw, ["inputPriceMills", "input_price_mills"]), + outputPriceMills: readNumber(raw, [ + "outputPriceMills", + "output_price_mills", + ]), + flatPriceMills: readNumber(raw, ["flatPriceMills", "flat_price_mills"]), + currency, + enabled: readBoolean(raw, ["enabled", "is_enabled"], true), + createdAt, + updatedAt, + }; +} + +export function normalizePublicModelPrices( + payload: unknown, +): PublicModelPrice[] { + const rawPrices = Array.isArray(payload) + ? payload + : isRecord(payload) && Array.isArray(payload.prices) + ? payload.prices + : isRecord(payload) && Array.isArray(payload.modelPrices) + ? payload.modelPrices + : isRecord(payload) && Array.isArray(payload.model_prices) + ? payload.model_prices + : isRecord(payload) && Array.isArray(payload.models) + ? payload.models + : []; + + return rawPrices + .map((item) => normalizePublicModelPrice(item)) + .filter((item): item is PublicModelPrice => Boolean(item)); +} + +export function normalizeEnterpriseVideoPricingConfig(raw: unknown): EnterpriseVideoPricingConfig | null { + if (!isRecord(raw)) return null; + const rules = Array.isArray(raw.rules) + ? raw.rules + .map((item) => normalizeEnterpriseVideoPricingRule(item)) + .filter((item): item is EnterpriseVideoPricingRule => Boolean(item)) + : []; + if (rules.length === 0) return null; + + const creditsPerCny = readNumber(raw, ["creditsPerCny", "credits_per_cny"]); + const defaultResolution = readString(raw, ["defaultResolution", "default_resolution"]); + const billingUnit = readString(raw, ["billingUnit", "billing_unit"]); + const currency = readString(raw, ["currency"]); + const resolutions = readStringArray(raw, ["resolutions", "supportedResolutions", "supported_resolutions"]); + + return { + ...(currency ? { currency } : {}), + ...(creditsPerCny !== null ? { creditsPerCny } : {}), + ...(billingUnit ? { billingUnit } : {}), + ...(defaultResolution ? { defaultResolution } : {}), + ...(resolutions.length ? { resolutions } : {}), + rules, + }; +} + +export function normalizePublicPricingPayload(payload: unknown): PublicPricingPayload { + const enterpriseVideoPricingRaw = + isRecord(payload) && (payload.enterpriseVideoPricing ?? payload.enterprise_video_pricing); + + return { + modelPrices: normalizePublicModelPrices(payload), + enterpriseVideoPricing: normalizeEnterpriseVideoPricingConfig(enterpriseVideoPricingRaw), + }; +} + +let cachedPricing: PublicPricingPayload | null = null; +let pricesRouteMissing = false; + +export const publicPricingClient = { + async getPricing(): Promise { + if (cachedPricing) return cachedPricing; + if (pricesRouteMissing) return { modelPrices: [], enterpriseVideoPricing: null }; + + try { + const payload = await serverRequest("prices", { + fallbackMessage: "Model prices request failed", + }); + cachedPricing = normalizePublicPricingPayload(payload); + return cachedPricing; + } catch (error) { + if (isOptionalApiRouteMissing(error)) { + pricesRouteMissing = true; + return { modelPrices: [], enterpriseVideoPricing: null }; + } + throw error; + } + }, + + async getPrices(): Promise { + const pricing = await publicPricingClient.getPricing(); + return pricing.modelPrices; + }, +}; diff --git a/src/features/canvas/CanvasPage.tsx b/src/features/canvas/CanvasPage.tsx index c424ba6..8ba9351 100644 --- a/src/features/canvas/CanvasPage.tsx +++ b/src/features/canvas/CanvasPage.tsx @@ -371,12 +371,19 @@ function CanvasPage({ const textNodeIdRef = useRef(9); const imageNodeIdRef = useRef(1); const videoNodeIdRef = useRef(1); + const objectUrlsRef = useRef(new Set()); + const trackObjectUrl = (file: Blob) => { + const url = URL.createObjectURL(file); + objectUrlsRef.current.add(url); + return url; + }; const { pushSnapshot, undo, redo } = useCanvasHistory(); const { textGenerationState, imageGenerationState, videoGenerationState, generationToast, setGenerationToast, imageGenerationInFlightRef, videoGenerationInFlightRef, textGenerationInFlightRef, textGenerationAbortControllersRef, + imageGenerationAbortRef, videoGenerationAbortRef, canvasGenKeepaliveRestoredRef, setTextGenerationStatus, setImageGenerationStatus, setVideoGenerationStatus, restoreKeepaliveTasks, resetGenerationState, @@ -527,6 +534,7 @@ function CanvasPage({ const autoSaveStatusTimerRef = useRef(null); useEffect(() => { + const objectUrls = objectUrlsRef.current; return () => { if (canvasAutoSaveTimerRef.current !== null) window.clearTimeout(canvasAutoSaveTimerRef.current); if (canvasAutoSaveRetryTimerRef.current !== null) window.clearTimeout(canvasAutoSaveRetryTimerRef.current); @@ -535,6 +543,8 @@ function CanvasPage({ if (canvasAutoSaveIdleHandleRef.current !== null && "cancelIdleCallback" in window) { window.cancelIdleCallback(canvasAutoSaveIdleHandleRef.current); } + objectUrls.forEach((url) => URL.revokeObjectURL(url)); + objectUrls.clear(); }; }, []); @@ -1691,12 +1701,15 @@ function CanvasPage({ const quality = resolveImageQuality(model, imageNode.imageSize || ""); imageGenerationInFlightRef.current.add(nodeId); + const abortRef = { current: false }; + imageGenerationAbortRef.current.set(nodeId, abortRef); setImageGenerationStatus(nodeId, { status: "submitting", message: "正在提交生成", progress: 8 }); setGenerationToast("图片正在生成"); let task: Awaited> | null = null; try { const referenceUrls = await resolveConnectedImageReferenceUrls("image", nodeId, imageNode); + if (abortRef.current) return; const taskInput: CreatePreviewTaskInput = { title: imageNode.title || "图片节点生成", type: "image", @@ -1732,7 +1745,8 @@ function CanvasPage({ ? "图片生成完成" : "图片生成失败"; setImageGenerationStatus(nodeId, { status: "running", message: statusLabel, progress }); - })); + }, abortRef)); + if (abortRef.current || !outputUrl) return; setImageGenerationStatus(nodeId, { status: "success", message: "生成完成", progress: 100 }); removeCanvasGenKeepalive(task.id); const immediateAssetRef = createCanvasAssetRefFromGeneratedResult({ @@ -1794,13 +1808,15 @@ function CanvasPage({ ); } } catch (error) { + if (abortRef.current) return; setImageGenerationStatus(nodeId, { status: "error", message: error instanceof Error ? error.message : "图片生成失败", }); } finally { imageGenerationInFlightRef.current.delete(nodeId); - if (task?.id) removeCanvasGenKeepalive(task.id); + imageGenerationAbortRef.current.delete(nodeId); + if (task?.id && !abortRef.current) removeCanvasGenKeepalive(task.id); } }; @@ -1843,12 +1859,15 @@ function CanvasPage({ const duration = Number(videoNode.duration) || 4; videoGenerationInFlightRef.current.add(nodeId); + const abortRef = { current: false }; + videoGenerationAbortRef.current.set(nodeId, abortRef); setVideoGenerationStatus(nodeId, { status: "submitting", message: "正在提交视频生成", progress: 8 }); setGenerationToast("视频正在生成"); let task: Awaited> | null = null; try { const referenceUrls = await resolveConnectedImageReferenceUrls("video", nodeId); + if (abortRef.current) return; if (videoNode.videoMode === "img2video" && referenceUrls.length === 0) { throw new Error("图生视频需要先连接至少一个可用的图片节点"); } @@ -1892,7 +1911,8 @@ function CanvasPage({ ? "视频生成完成" : "视频生成失败"; setVideoGenerationStatus(nodeId, { status: "running", message: statusLabel, progress }); - })); + }, abortRef)); + if (abortRef.current || !outputUrl) return; setVideoGenerationStatus(nodeId, { status: "success", message: "视频生成完成", progress: 100 }); removeCanvasGenKeepalive(taskId); const immediateAssetRef = createCanvasAssetRefFromGeneratedResult({ @@ -1948,13 +1968,15 @@ function CanvasPage({ ); } } catch (error) { + if (abortRef.current) return; setVideoGenerationStatus(nodeId, { status: "error", message: error instanceof Error ? error.message : "视频生成失败", }); } finally { videoGenerationInFlightRef.current.delete(nodeId); - if (task?.id) removeCanvasGenKeepalive(task.id); + videoGenerationAbortRef.current.delete(nodeId); + if (task?.id && !abortRef.current) removeCanvasGenKeepalive(task.id); } }; @@ -1965,7 +1987,7 @@ function CanvasPage({ const file = event.target.files?.[0]; event.target.value = ""; if (!file) return; - const imageUrl = URL.createObjectURL(file); + const imageUrl = trackObjectUrl(file); if (pendingImageToImageNodeId) { const sourceNode = imageNodes.find((node) => node.id === pendingImageToImageNodeId); if (sourceNode) { @@ -2047,7 +2069,7 @@ function CanvasPage({ let offsetX = 0; let offsetY = 0; for (const file of files) { - const imageUrl = URL.createObjectURL(file); + const imageUrl = trackObjectUrl(file); addImageNode(imageUrl, file.name, { x: dropPosition.x + offsetX, y: dropPosition.y + offsetY, @@ -2103,7 +2125,7 @@ function CanvasPage({ let offsetX = 0; let offsetY = 0; for (const file of files) { - const imageUrl = URL.createObjectURL(file); + const imageUrl = trackObjectUrl(file); addImageNode(imageUrl, file.name, { x: sourceNode.position.x + sourceNode.size.width + 40 + offsetX, y: sourceNode.position.y + offsetY, @@ -5279,7 +5301,7 @@ function CanvasPage({ onChange={(event) => { const file = event.target.files?.[0]; if (!file) return; - setAssetCoverUrl(URL.createObjectURL(file)); + setAssetCoverUrl(trackObjectUrl(file)); setCoverSourceOpen(false); }} /> diff --git a/src/features/canvas/canvasUtils.ts b/src/features/canvas/canvasUtils.ts index 08a2ab6..72fc37c 100644 --- a/src/features/canvas/canvasUtils.ts +++ b/src/features/canvas/canvasUtils.ts @@ -252,28 +252,40 @@ export function blobToDataUrl(blob: Blob) { }); } -export async function waitForImageTaskResult(taskId: string, onStatus?: (status: AiTaskStatus) => void) { +export async function waitForImageTaskResult( + taskId: string, + onStatus?: (status: AiTaskStatus) => void, + abortRef?: { current: boolean }, +) { const resultUrl = await waitForTask(taskId, { kind: "image", + abortRef, onProgress: (e) => { if (onStatus) { onStatus({ taskId, status: e.status, progress: e.progress, resultUrl: e.resultUrl ?? undefined, error: e.error ?? undefined } as AiTaskStatus); } }, }); + if (abortRef?.current) return ""; if (!resultUrl) throw new Error("生成任务已完成,但服务器没有返回结果地址,请稍后重试"); return resultUrl; } -export async function waitForVideoTaskResult(taskId: string, onStatus?: (status: AiTaskStatus) => void) { +export async function waitForVideoTaskResult( + taskId: string, + onStatus?: (status: AiTaskStatus) => void, + abortRef?: { current: boolean }, +) { const resultUrl = await waitForTask(taskId, { kind: "video", + abortRef, onProgress: (e) => { if (onStatus) { onStatus({ taskId, status: e.status, progress: e.progress, resultUrl: e.resultUrl ?? undefined, error: e.error ?? undefined } as AiTaskStatus); } }, }); + if (abortRef?.current) return ""; if (!resultUrl) throw new Error("视频生成任务已完成,但服务器没有返回结果地址,请稍后重试"); return resultUrl; } diff --git a/src/features/canvas/useCanvasGeneration.ts b/src/features/canvas/useCanvasGeneration.ts index 5f5b135..222ffb0 100644 --- a/src/features/canvas/useCanvasGeneration.ts +++ b/src/features/canvas/useCanvasGeneration.ts @@ -1,4 +1,4 @@ -import { type Dispatch, type SetStateAction, useEffect, useRef, useState } from "react"; +import { type Dispatch, type SetStateAction, useCallback, useEffect, useRef, useState } from "react"; import type { CanvasImageGenerationState, CanvasImageNode, @@ -66,6 +66,8 @@ export function useCanvasGeneration(params: UseCanvasGenerationParams) { const videoGenerationInFlightRef = useRef(new Set()); const textGenerationInFlightRef = useRef(new Set()); const textGenerationAbortControllersRef = useRef(new Map()); + const imageGenerationAbortRef = useRef(new Map()); + const videoGenerationAbortRef = useRef(new Map()); const canvasGenKeepaliveRestoredRef = useRef(false); const setTextGenerationStatus = (nodeId: string, state: CanvasTextGenerationState) => { @@ -80,6 +82,15 @@ export function useCanvasGeneration(params: UseCanvasGenerationParams) { setVideoGenerationState((current) => ({ ...current, [nodeId]: state })); }; + const abortAllGenerationPollers = useCallback(() => { + textGenerationAbortControllersRef.current.forEach((c) => c.abort()); + textGenerationAbortControllersRef.current.clear(); + imageGenerationAbortRef.current.forEach((ref) => { ref.current = true; }); + imageGenerationAbortRef.current.clear(); + videoGenerationAbortRef.current.forEach((ref) => { ref.current = true; }); + videoGenerationAbortRef.current.clear(); + }, []); + // Toast auto-dismiss useEffect(() => { if (!generationToast) return undefined; @@ -103,11 +114,14 @@ export function useCanvasGeneration(params: UseCanvasGenerationParams) { } if (entry.nodeKind === "image") { imageGenerationInFlightRef.current.add(entry.nodeId); + const abortRef = { current: false }; + imageGenerationAbortRef.current.set(entry.nodeId, abortRef); setImageGenerationStatus(entry.nodeId, { status: "running", message: "正在恢复图片生成", progress: 20 }); void waitForImageTaskResult(entry.taskId, (status) => { const progress = Math.max(18, Math.min(status.status === "completed" ? 100 : 96, Math.trunc(status.progress || 0))); setImageGenerationStatus(entry.nodeId, { status: "running", message: "图片生成中", progress }); - }).then(async (outputUrl) => { + }, abortRef).then(async (outputUrl) => { + if (abortRef.current || !outputUrl) return; removeCanvasGenKeepalive(entry.taskId); setImageGenerationStatus(entry.nodeId, { status: "success", message: "生成完成", progress: 100 }); const ref = createCanvasAssetRefFromGeneratedResult({ @@ -128,18 +142,23 @@ export function useCanvasGeneration(params: UseCanvasGenerationParams) { )); } }).catch(() => { + if (abortRef.current) return; removeCanvasGenKeepalive(entry.taskId); setImageGenerationStatus(entry.nodeId, { status: "error", message: "图片生成失败" }); }).finally(() => { imageGenerationInFlightRef.current.delete(entry.nodeId); + imageGenerationAbortRef.current.delete(entry.nodeId); }); } else if (entry.nodeKind === "video") { videoGenerationInFlightRef.current.add(entry.nodeId); + const abortRef = { current: false }; + videoGenerationAbortRef.current.set(entry.nodeId, abortRef); setVideoGenerationStatus(entry.nodeId, { status: "running", message: "正在恢复视频生成", progress: 20 }); void waitForVideoTaskResult(entry.taskId, (status) => { const progress = Math.max(18, Math.min(status.status === "completed" ? 100 : 96, Math.trunc(status.progress || 0))); setVideoGenerationStatus(entry.nodeId, { status: "running", message: "视频生成中", progress }); - }).then(async (outputUrl) => { + }, abortRef).then(async (outputUrl) => { + if (abortRef.current || !outputUrl) return; removeCanvasGenKeepalive(entry.taskId); setVideoGenerationStatus(entry.nodeId, { status: "success", message: "生成完成", progress: 100 }); const ref = createCanvasAssetRefFromGeneratedResult({ @@ -160,18 +179,19 @@ export function useCanvasGeneration(params: UseCanvasGenerationParams) { )); } }).catch(() => { + if (abortRef.current) return; removeCanvasGenKeepalive(entry.taskId); setVideoGenerationStatus(entry.nodeId, { status: "error", message: "视频生成失败" }); }).finally(() => { videoGenerationInFlightRef.current.delete(entry.nodeId); + videoGenerationAbortRef.current.delete(entry.nodeId); }); } } }; const resetGenerationState = () => { - textGenerationAbortControllersRef.current.forEach((c) => c.abort()); - textGenerationAbortControllersRef.current.clear(); + abortAllGenerationPollers(); textGenerationInFlightRef.current.clear(); imageGenerationInFlightRef.current.clear(); videoGenerationInFlightRef.current.clear(); @@ -180,11 +200,18 @@ export function useCanvasGeneration(params: UseCanvasGenerationParams) { setVideoGenerationState({}); }; + // Stop all in-flight front-end polling/setState when the canvas unmounts (route change). + // Keepalive records are intentionally preserved so restoreKeepaliveTasks can resume on return. + useEffect(() => { + return () => { + abortAllGenerationPollers(); + }; + }, [abortAllGenerationPollers]); + useEffect(() => { const handlePageHide = () => { cancelCanvasGenKeepaliveOnUnload(); - textGenerationAbortControllersRef.current.forEach((controller) => controller.abort()); - textGenerationAbortControllersRef.current.clear(); + abortAllGenerationPollers(); textGenerationInFlightRef.current.clear(); imageGenerationInFlightRef.current.clear(); videoGenerationInFlightRef.current.clear(); @@ -202,7 +229,7 @@ export function useCanvasGeneration(params: UseCanvasGenerationParams) { window.removeEventListener("pagehide", handlePageHide); window.removeEventListener("online", handleOnline); }; - }, []); + }, [abortAllGenerationPollers]); return { textGenerationState, @@ -214,6 +241,8 @@ export function useCanvasGeneration(params: UseCanvasGenerationParams) { videoGenerationInFlightRef, textGenerationInFlightRef, textGenerationAbortControllersRef, + imageGenerationAbortRef, + videoGenerationAbortRef, canvasGenKeepaliveRestoredRef, setTextGenerationStatus, setImageGenerationStatus, diff --git a/src/features/workbench/WorkbenchPage.tsx b/src/features/workbench/WorkbenchPage.tsx index ddf97aa..839a8c5 100644 --- a/src/features/workbench/WorkbenchPage.tsx +++ b/src/features/workbench/WorkbenchPage.tsx @@ -40,6 +40,7 @@ import { useGenerationTasks } from "../../hooks/useGenerationTasks"; import { conversationClient, type ConversationSummary } from "../../api/conversationClient"; import { modelCapabilitiesClient } from "../../api/modelCapabilitiesClient"; +import { publicPricingClient, type PublicModelPrice } from "../../api/publicPricingClient"; import type { CreatePreviewTaskInput } from "../../api/webGenerationGateway"; import type { WebProjectSummary } from "../../types"; import { @@ -58,6 +59,8 @@ import { import { translateTaskError } from "../../utils/translateTaskError"; import { buildLocalTimeoutMessage, + FALLBACK_TEXT_TOKEN_CREDIT_RATE, + formatTextTokenCreditRule, getTaskTimeoutPolicy, isTaskLocallyTimedOut, } from "../../utils/taskLifecycle"; @@ -69,7 +72,12 @@ import { import { isViduModel } from "../../utils/viduRouting"; import { isPixverseModel } from "../../utils/pixverseRouting"; import { resolveVideoRequestModel } from "../../utils/resolveVideoModel"; -import { calculateEnterpriseVideoCredits, ENTERPRISE_DEFAULT_VIDEO_MODEL } from "../../utils/enterpriseVideoPolicy"; +import { + calculateEnterpriseVideoCredits, + ENTERPRISE_DEFAULT_VIDEO_MODEL, + type EnterpriseVideoPricingConfig, +} from "../../utils/enterpriseVideoPolicy"; +import { resolveTextTokenCreditRate } from "../../utils/modelPricing"; import { getImageQualityOptionsForContext, getDefaultImageQuality, @@ -404,9 +412,32 @@ function WorkbenchPage({ const [videoQuality, setVideoQuality] = useState(() => getDefaultVideoQuality(VIDEO_MODEL_OPTIONS[0].value)); const [chatModel, setChatModel] = useState(CHAT_MODEL_OPTIONS[0].value); + const [modelPrices, setModelPrices] = useState([]); + const [enterpriseVideoPricing, setEnterpriseVideoPricing] = useState(null); const [thinkingSpeed, setThinkingSpeed] = useState(THINKING_SPEED_OPTIONS[0].value); const [thinkingDepth, setThinkingDepth] = useState(THINKING_DEPTH_OPTIONS[0].value); + useEffect(() => { + let cancelled = false; + + publicPricingClient + .getPricing() + .then((pricing) => { + if (cancelled) return; + setModelPrices(pricing.modelPrices); + setEnterpriseVideoPricing(pricing.enterpriseVideoPricing); + }) + .catch(() => { + if (cancelled) return; + setModelPrices([]); + setEnterpriseVideoPricing(null); + }); + + return () => { + cancelled = true; + }; + }, []); + useEffect(() => { let cancelled = false; @@ -525,6 +556,10 @@ function WorkbenchPage({ const videoQualityLabel = getVideoQualityLabel(videoModel, videoQuality); const imageSettingsSummary = `${imageRatio} / ${imageQuality}`; + const selectedChatTokenRate = useMemo( + () => resolveTextTokenCreditRate(modelPrices, chatModel) || FALLBACK_TEXT_TOKEN_CREDIT_RATE, + [chatModel, modelPrices], + ); const billingEstimate = useMemo(() => { if (activeMode === "image") { return { @@ -541,7 +576,7 @@ function WorkbenchPage({ durationSeconds, muted: false, hasReferenceVideo: referenceItems.some((item) => item.kind === "video"), - }); + }, enterpriseVideoPricing || undefined); return { label: `预计 ${formatCreditValue(credits)} 积分`, title: `${activeModel},${videoQualityLabel},${durationSeconds} 秒,预计 ${formatCreditValue(credits)} 积分`, @@ -553,16 +588,20 @@ function WorkbenchPage({ }; } } + const textBillingPrefix = + selectedChatTokenRate.source === "server" ? "文本计费" : "服务端价格暂不可用,按默认预估"; return { label: "按 Token 结算", - title: "文本对话按输入、输出 Token 实际用量结算,完成后显示本次积分", + title: `${textBillingPrefix}:${activeModel},${formatTextTokenCreditRule(selectedChatTokenRate)}`, }; }, [ activeMode, activeModel, activeModelValue, imageSettingsSummary, + enterpriseVideoPricing, referenceItems, + selectedChatTokenRate, videoDuration, videoQuality, videoQualityLabel, diff --git a/src/utils/enterpriseVideoPolicy.test.ts b/src/utils/enterpriseVideoPolicy.test.ts index 31e34e9..0123257 100644 --- a/src/utils/enterpriseVideoPolicy.test.ts +++ b/src/utils/enterpriseVideoPolicy.test.ts @@ -2,6 +2,7 @@ import { describe, expect, it } from "../test/testHarness"; import { calculateEnterpriseVideoCredits, + type EnterpriseVideoPricingConfig, getEnterpriseVideoCreditRate, normalizeEnterpriseResolution, } from "./enterpriseVideoPolicy"; @@ -45,4 +46,40 @@ describe("enterpriseVideoPolicy", () => { }), ).toBe(1); }); + + it("uses server-provided pricing config before fallback pricing", () => { + const serverPricing: EnterpriseVideoPricingConfig = { + creditsPerCny: 100, + defaultResolution: "1080P", + rules: [ + { + id: "happyhorse-server", + modelIncludes: ["happyhorse"], + rates: { "720P": 2, "1080P": 3 }, + }, + ], + }; + + expect( + getEnterpriseVideoCreditRate( + { + model: "happyhorse-1.0", + resolution: "1080P", + durationSeconds: 5, + }, + serverPricing, + ), + ).toBe(3); + + expect( + calculateEnterpriseVideoCredits( + { + model: "happyhorse-1.0", + resolution: "1080P", + durationSeconds: 5, + }, + serverPricing, + ), + ).toBe(1500); + }); }); diff --git a/src/utils/enterpriseVideoPolicy.ts b/src/utils/enterpriseVideoPolicy.ts index 359aa31..074a188 100644 --- a/src/utils/enterpriseVideoPolicy.ts +++ b/src/utils/enterpriseVideoPolicy.ts @@ -50,50 +50,119 @@ export interface EnterpriseVideoPricingInput { hasReferenceVideo?: boolean; } +export interface EnterpriseVideoPricingRule { + id: string; + modelIncludes: string[]; + when?: { + muted?: boolean; + hasReferenceVideo?: boolean; + }; + rates: Record; +} + +export interface EnterpriseVideoPricingConfig { + currency?: string; + creditsPerCny?: number; + billingUnit?: "per_second" | string; + defaultResolution?: string; + resolutions?: string[]; + rules: EnterpriseVideoPricingRule[]; +} + +export const FALLBACK_ENTERPRISE_VIDEO_PRICING_CONFIG: EnterpriseVideoPricingConfig = { + currency: "CNY", + creditsPerCny: CREDITS_PER_CNY, + billingUnit: "per_second", + defaultResolution: ENTERPRISE_DEFAULT_VIDEO_RESOLUTION, + resolutions: ["720P", "1080P"], + rules: [ + { + id: "happyhorse", + modelIncludes: ["happyhorse"], + rates: { "720P": 0.72, "1080P": 1.28 }, + }, + { + id: "wanxiang-i2v", + modelIncludes: ["wan2.7-i2v", "wanxiang"], + rates: { "720P": 0.6, "1080P": 1 }, + }, + { + id: "wan-animate-s2v", + modelIncludes: ["animate-mix", "s2v"], + rates: { "720P": 0.6, "1080P": 1 }, + }, + { + id: "kling-muted-reference", + modelIncludes: ["kling"], + when: { muted: true, hasReferenceVideo: true }, + rates: { "720P": 0.9, "1080P": 1.2 }, + }, + { + id: "kling-muted", + modelIncludes: ["kling"], + when: { muted: true, hasReferenceVideo: false }, + rates: { "720P": 0.6, "1080P": 0.8 }, + }, + { + id: "kling-default", + modelIncludes: ["kling"], + rates: { "720P": 0.9, "1080P": 1.2 }, + }, + { + id: "vidu", + modelIncludes: ["vidu"], + rates: { "720P": 0.6, "1080P": 1 }, + }, + { + id: "pixverse", + modelIncludes: ["pixverse"], + rates: { "720P": 0.6, "1080P": 1 }, + }, + ], +}; + export function normalizeEnterpriseResolution(value: string): "720P" | "1080P" { return String(value || "").toUpperCase() === "720P" ? "720P" : "1080P"; } -export function getEnterpriseVideoCreditRate(input: EnterpriseVideoPricingInput): number { +function enterpriseVideoPricingRuleMatches( + rule: EnterpriseVideoPricingRule, + input: EnterpriseVideoPricingInput, + model: string, +): boolean { + if (!rule.modelIncludes.some((pattern) => model.includes(String(pattern || "").toLowerCase()))) return false; + if (!rule.when) return true; + if ("muted" in rule.when && Boolean(input.muted) !== rule.when.muted) return false; + if ("hasReferenceVideo" in rule.when && Boolean(input.hasReferenceVideo) !== rule.when.hasReferenceVideo) { + return false; + } + return true; +} + +export function getEnterpriseVideoCreditRate( + input: EnterpriseVideoPricingInput, + config: EnterpriseVideoPricingConfig = FALLBACK_ENTERPRISE_VIDEO_PRICING_CONFIG, +): number { const resolution = normalizeEnterpriseResolution(input.resolution); const model = String(input.model || "").toLowerCase(); + const fallbackResolution = normalizeEnterpriseResolution( + config.defaultResolution || ENTERPRISE_DEFAULT_VIDEO_RESOLUTION, + ); + const rule = config.rules.find((candidate) => enterpriseVideoPricingRuleMatches(candidate, input, model)); - if (model.includes("happyhorse")) { - return resolution === "720P" ? 0.72 : 1.28; - } - - if (model.includes("wan2.7-i2v") || model.includes("wanxiang")) { - return resolution === "720P" ? 0.6 : 1; - } - - if (model.includes("animate-mix")) { - return resolution === "720P" ? 0.6 : 1; - } - - if (model.includes("s2v")) { - return resolution === "720P" ? 0.6 : 1; - } - - if (model.includes("vidu")) { - return resolution === "720P" ? 0.6 : 1.0; - } - - if (model.includes("pixverse")) { - return resolution === "720P" ? 0.6 : 1.0; - } - - if (model.includes("kling")) { - if (input.muted) { - if (input.hasReferenceVideo) return resolution === "720P" ? 0.9 : 1.2; - return resolution === "720P" ? 0.6 : 0.8; - } - return resolution === "720P" ? 0.9 : 1.2; + if (rule) { + const rate = rule.rates[resolution] ?? rule.rates[fallbackResolution]; + if (Number.isFinite(rate) && rate >= 0) return rate; } throw new Error(`Unsupported enterprise video model: ${input.model}`); } -export function calculateEnterpriseVideoCredits(input: EnterpriseVideoPricingInput): number { +export function calculateEnterpriseVideoCredits( + input: EnterpriseVideoPricingInput, + config: EnterpriseVideoPricingConfig = FALLBACK_ENTERPRISE_VIDEO_PRICING_CONFIG, +): number { const duration = Math.max(1, Math.ceil(Number(input.durationSeconds) || 1)); - return Number((getEnterpriseVideoCreditRate(input) * duration * CREDITS_PER_CNY).toFixed(2)); + const creditsPerCny = Number(config.creditsPerCny || CREDITS_PER_CNY); + return Number((getEnterpriseVideoCreditRate(input, config) * duration * creditsPerCny).toFixed(2)); } diff --git a/src/utils/modelPricing.test.ts b/src/utils/modelPricing.test.ts new file mode 100644 index 0000000..e235523 --- /dev/null +++ b/src/utils/modelPricing.test.ts @@ -0,0 +1,60 @@ +import { describe, expect, it } from "../test/testHarness"; + +import { + millsPerThousandTokensToCreditsPerMillion, + modelPriceToTextTokenCreditRate, + resolveTextTokenCreditRate, +} from "./modelPricing"; + +describe("modelPricing", () => { + it("converts backend mills per thousand tokens to credits per million tokens", () => { + expect(millsPerThousandTokensToCreditsPerMillion(27)).toBe(2_700); + expect(millsPerThousandTokensToCreditsPerMillion(108)).toBe(10_800); + }); + + it("converts a token model price row to a text token credit rate", () => { + expect( + modelPriceToTextTokenCreditRate({ + modelKey: "gpt-4o", + inputPriceMills: 27, + outputPriceMills: 108, + flatPriceMills: null, + currency: "CNY", + enabled: true, + }), + ).toEqual({ + inputCreditsPerMillion: 2_700, + outputCreditsPerMillion: 10_800, + source: "server", + modelKey: "gpt-4o", + }); + }); + + it("resolves token pricing by exact or fuzzy model key without accepting flat prices", () => { + const prices = [ + { + modelKey: "gemini-3-pro-image", + inputPriceMills: null, + outputPriceMills: null, + flatPriceMills: 200, + currency: "CNY", + enabled: true, + }, + { + modelKey: "gemini-3.1-pro", + inputPriceMills: 12, + outputPriceMills: 48, + flatPriceMills: null, + currency: "CNY", + enabled: true, + }, + ]; + + expect(resolveTextTokenCreditRate(prices, "gemini")).toEqual({ + inputCreditsPerMillion: 1_200, + outputCreditsPerMillion: 4_800, + source: "server", + modelKey: "gemini-3.1-pro", + }); + }); +}); diff --git a/src/utils/modelPricing.ts b/src/utils/modelPricing.ts new file mode 100644 index 0000000..924c5ed --- /dev/null +++ b/src/utils/modelPricing.ts @@ -0,0 +1,104 @@ +import type { PublicModelPrice } from "../api/publicPricingClient"; +import type { TextTokenCreditRate } from "./taskLifecycle"; + +const TOKENS_PER_MILLION = 1_000_000; +const BACKEND_TOKEN_PRICE_UNIT = 1_000; +const CREDITS_PER_CNY = 100; +const MILLS_PER_CNY = 1_000; +const CREDITS_PER_MILL = CREDITS_PER_CNY / MILLS_PER_CNY; + +function isUsablePrice(value: number | null | undefined): value is number { + return typeof value === "number" && Number.isFinite(value) && value >= 0; +} + +function normalizeModelKey(value: string): string { + return value.trim().toLowerCase(); +} + +function compactModelKey(value: string): string { + return normalizeModelKey(value).replace(/[^a-z0-9]+/g, ""); +} + +function addCandidate( + candidates: PublicModelPrice[], + seen: Set, + price: PublicModelPrice, +): void { + const key = normalizeModelKey(price.modelKey); + if (seen.has(key)) return; + seen.add(key); + candidates.push(price); +} + +export function millsPerThousandTokensToCreditsPerMillion( + priceMills: number, +): number { + if (!isUsablePrice(priceMills)) return 0; + return ( + priceMills * + (TOKENS_PER_MILLION / BACKEND_TOKEN_PRICE_UNIT) * + CREDITS_PER_MILL + ); +} + +export function modelPriceToTextTokenCreditRate( + price: PublicModelPrice, +): TextTokenCreditRate | null { + if ( + !isUsablePrice(price.inputPriceMills) || + !isUsablePrice(price.outputPriceMills) + ) + return null; + + return { + inputCreditsPerMillion: millsPerThousandTokensToCreditsPerMillion( + price.inputPriceMills, + ), + outputCreditsPerMillion: millsPerThousandTokensToCreditsPerMillion( + price.outputPriceMills, + ), + source: "server", + modelKey: price.modelKey, + }; +} + +export function resolveTextTokenCreditRate( + prices: PublicModelPrice[], + modelKey: string | null | undefined, +): TextTokenCreditRate | null { + const normalizedTarget = normalizeModelKey(modelKey || ""); + if (!normalizedTarget) return null; + + const compactTarget = compactModelKey(normalizedTarget); + const candidates: PublicModelPrice[] = []; + const seen = new Set(); + + for (const price of prices) { + if (normalizeModelKey(price.modelKey) === normalizedTarget) { + addCandidate(candidates, seen, price); + } + } + + for (const price of prices) { + if (compactModelKey(price.modelKey) === compactTarget) { + addCandidate(candidates, seen, price); + } + } + + for (const price of prices) { + const compactPriceKey = compactModelKey(price.modelKey); + if ( + compactPriceKey.includes(compactTarget) || + compactTarget.includes(compactPriceKey) + ) { + addCandidate(candidates, seen, price); + } + } + + for (const price of candidates) { + const rate = modelPriceToTextTokenCreditRate(price); + if (rate) return rate; + } + + return null; +} diff --git a/src/utils/taskLifecycle.test.ts b/src/utils/taskLifecycle.test.ts index 6189d79..ab5c4d5 100644 --- a/src/utils/taskLifecycle.test.ts +++ b/src/utils/taskLifecycle.test.ts @@ -4,12 +4,13 @@ import { TEXT_INPUT_CREDITS_PER_MILLION, TEXT_OUTPUT_CREDITS_PER_MILLION, estimateTextTokenCredits, + formatTextTokenCreditRule, getTaskTimeoutPolicy, isTaskLocallyTimedOut, } from "./taskLifecycle"; describe("taskLifecycle", () => { - it("keeps text token billing at 1 CNY to 100 credits", () => { + it("keeps fallback text token billing at 1 CNY to 100 credits", () => { expect(TEXT_INPUT_CREDITS_PER_MILLION).toBe(200); expect(TEXT_OUTPUT_CREDITS_PER_MILLION).toBe(500); expect( @@ -20,6 +21,23 @@ describe("taskLifecycle", () => { ).toBe(700); }); + it("estimates text billing from dynamic server pricing rates", () => { + expect( + estimateTextTokenCredits( + { + promptTokens: 1_000_000, + completionTokens: 1_000_000, + }, + { + inputCreditsPerMillion: 2_700, + outputCreditsPerMillion: 10_800, + source: "server", + modelKey: "gpt-4o", + }, + ), + ).toBe(13_500); + }); + it("ignores negative token counts when estimating text billing", () => { expect( estimateTextTokenCredits({ @@ -29,6 +47,17 @@ describe("taskLifecycle", () => { ).toBe(250); }); + it("formats text billing rules from the selected rate", () => { + expect( + formatTextTokenCreditRule({ + inputCreditsPerMillion: 2_700, + outputCreditsPerMillion: 10_800, + }), + ).toBe( + "输入 Token 每百万 2,700 积分,输出 Token 每百万 10,800 积分,实际以服务端结算为准。", + ); + }); + it("marks unstarted tasks locally timed out after submit timeout", () => { const policy = getTaskTimeoutPolicy({ kind: "image" }); diff --git a/src/utils/taskLifecycle.ts b/src/utils/taskLifecycle.ts index 757c4b2..6db0311 100644 --- a/src/utils/taskLifecycle.ts +++ b/src/utils/taskLifecycle.ts @@ -32,11 +32,24 @@ export interface TextTokenUsage { totalTokens?: number; } +export interface TextTokenCreditRate { + inputCreditsPerMillion: number; + outputCreditsPerMillion: number; + source?: "server" | "fallback"; + modelKey?: string; +} + const CREDITS_PER_CNY = 100; export const TEXT_INPUT_CREDITS_PER_MILLION = 2 * CREDITS_PER_CNY; export const TEXT_OUTPUT_CREDITS_PER_MILLION = 5 * CREDITS_PER_CNY; +export const FALLBACK_TEXT_TOKEN_CREDIT_RATE: TextTokenCreditRate = { + inputCreditsPerMillion: TEXT_INPUT_CREDITS_PER_MILLION, + outputCreditsPerMillion: TEXT_OUTPUT_CREDITS_PER_MILLION, + source: "fallback", +}; + const IMAGE_TIMEOUT_POLICY: TaskTimeoutPolicy = { submitTimeoutMs: 90_000, noProgressTimeoutMs: 120_000, @@ -145,18 +158,42 @@ export function getRefundHint(status: TaskRefundStatus): string { } } -export function estimateTextTokenCredits(usage: TextTokenUsage): number { - const promptTokens = Math.max(0, Number(usage.promptTokens || 0)); - const completionTokens = Math.max(0, Number(usage.completionTokens || 0)); - return (promptTokens / 1_000_000) * TEXT_INPUT_CREDITS_PER_MILLION + - (completionTokens / 1_000_000) * TEXT_OUTPUT_CREDITS_PER_MILLION; +function sanitizeCreditRate(value: number): number { + return Number.isFinite(value) && value >= 0 ? value : 0; } -export function formatTextTokenUsage(usage?: TextTokenUsage | null): string { - const rule = "文本计费规则:输入 Token 每百万 200 积分,输出 Token 每百万 500 积分,实际以服务端结算为准。"; +function formatCreditRate(value: number): string { + const safeValue = sanitizeCreditRate(value); + if (safeValue >= 100) return Math.round(safeValue).toLocaleString("zh-CN"); + return Number(safeValue.toFixed(4)).toString(); +} + +export function formatTextTokenCreditRule( + rate: TextTokenCreditRate = FALLBACK_TEXT_TOKEN_CREDIT_RATE, +): string { + return `输入 Token 每百万 ${formatCreditRate(rate.inputCreditsPerMillion)} 积分,输出 Token 每百万 ${formatCreditRate(rate.outputCreditsPerMillion)} 积分,实际以服务端结算为准。`; +} + +export function estimateTextTokenCredits( + usage: TextTokenUsage, + rate: TextTokenCreditRate = FALLBACK_TEXT_TOKEN_CREDIT_RATE, +): number { + const promptTokens = Math.max(0, Number(usage.promptTokens || 0)); + const completionTokens = Math.max(0, Number(usage.completionTokens || 0)); + return ( + (promptTokens / 1_000_000) * sanitizeCreditRate(rate.inputCreditsPerMillion) + + (completionTokens / 1_000_000) * sanitizeCreditRate(rate.outputCreditsPerMillion) + ); +} + +export function formatTextTokenUsage( + usage?: TextTokenUsage | null, + rate: TextTokenCreditRate = FALLBACK_TEXT_TOKEN_CREDIT_RATE, +): string { + const rule = `文本计费规则:${formatTextTokenCreditRule(rate)}`; if (!usage) return rule; const promptTokens = Math.max(0, Number(usage.promptTokens || 0)); const completionTokens = Math.max(0, Number(usage.completionTokens || 0)); - const estimatedCredits = estimateTextTokenCredits({ promptTokens, completionTokens }); + const estimatedCredits = estimateTextTokenCredits({ promptTokens, completionTokens }, rate); return `本次 Token:输入 ${promptTokens},输出 ${completionTokens},预估 ${estimatedCredits.toFixed(4)} 积分。\n${rule}`; }