Kotonia
ログイン今すぐ始める

Kotonia Articles

LTX-2 22B 使用 fp8_cast 将峰值显存降低 40% — optimum-quanto 是个陷阱

记录 LTX-2.3 22B 的量化尝试。optimum-quanto 与 LTX-2 transformer 存在兼容性问题无法运行,改用 LTX-2 原生的 QuantizationPolicy.fp8_cast() 后,峰值显存从 40 GiB 压缩至 24 GiB(cold-start, 768×512)。包含 3 种分辨率的基准测试以及 cold-start / persistent 模式的选择判断。

作者 3分钟阅读
#LTX-2#量化#FP8#扩散模型#GPU
其他语言日语

前言

LTX-2.3 是 Lightricks 推出的带音频视频生成模型,使用 A2V(Audio-to-Video)模式可以从 1 张图片 + 音频 + 提示词 一体生成唇形同步、表情、头部和头发的运动。与 MuseTalk 这类“仅唇形同步”的模型不同,它能驱动整个场景,因此作为“演出路径”非常强大。

不过,基础检查点有 22B 参数、43 GB,以 bf16 常驻时,transformer × 2 stage 会导致 空闲时约 86 GiB。这几乎占满了 RTX PRO 6000 Blackwell 的 96 GiB,导致同机的 TTS / Ditto-TalkingHead / Qwen3-TTS-vLLM 被挤出。

因此尝试了量化,结果 使用 LTX-2 原生的 fp8_cast 将峰值显存从 40 GiB 压缩到 24 GiB(A2V cold-start, 768×512 / 97f)。另一方面,optimum-quanto 的 int8/fp8 与 LTX-2 transformer 存在兼容性问题,无法运行。本文记录了调试和判断过程。


环境

  • GPU: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition (96 GiB)
  • PyTorch: 2.9.1 + CUDA 12.8
  • 模型: LTX-2.3 22B-dev (base) + 22B-distilled-lora-384 (stage_2) + Gemma-3-12B text encoder (bnb 4bit)
  • 运行形态: 使用 scripts/persistent_a2v_server.py --cold-start 运行 A2V。每次请求执行 build → run → free,空闲时 0 GiB。

采用 cold-start 的原因写在另一篇文章中(或另行说明)。简而言之,以对话为主,A2V 偶尔调用,必须与 TTS / Ditto 共存


候选方案有 4 种

查看 LTX-2 的代码库,实际上有两条量化路径:

1. LTX-2 原生:QuantizationPolicy

packages/ltx-core/src/ltx_core/quantization/policy.py:

@dataclass(frozen=True)
class QuantizationPolicy:
    sd_ops: SDOps | None = None              # 加载 state dict 时的权重转换
    module_ops: tuple[ModuleOps, ...] = ()   # 加载后的模块重写

    @classmethod
    def fp8_cast(cls) -> "QuantizationPolicy":
        """以 float8_e4m3fn 加载权重,前向时向上转换为 bf16"""
        return cls(
            sd_ops=TRANSFORMER_LINEAR_DOWNCAST_MAP,
            module_ops=(UPCAST_DURING_INFERENCE,),
        )

    @classmethod
    def fp8_scaled_mm(cls) -> "QuantizationPolicy":
        """FP8 scaled MM(需要 tensorrt_llm)"""

fp8_cast 的实际实现是 Fp8CastLinear

class Fp8CastLinear(torch.nn.Linear):
    def forward(self, input):
        w_up = _upcast_and_round(self.weight, input.dtype, ...)
        b_up = _upcast_and_round(self.bias, input.dtype, ...) if self.bias is not None else None
        return torch.nn.functional.linear(input, w_up, b_up)

这是一种通过 __class__ 重新赋值来改写实例的类型转换模式。权重以 fp8 保存,每次前向时转换为 bf16 再进行 matmul。fp8 → bf16 的转换开销存在,但在 Blackwell 上几乎可以忽略不计。

2. optimum-quanto

LTX-2 trainer 包 (packages/ltx-trainer) 包含使用 optimum-quanto 的通用量化,可以选择 int8-quanto / int4-quanto / fp8-quanto

def quantize_model(model, precision, ...):
    if hasattr(model, "transformer_blocks"):
        _quantize_blockwise(model, ...)   # 逐个 block 加载到 GPU 量化 → freeze → CPU
    else:
        quantize(model, weights=..., exclude=EXCLUDE_PATTERNS)
        freeze(model)
    return model

看起来只要在 _build_transformer() 之后立即插入即可。

4 种候选方案矩阵

模式路径预期
fp8-castLTX-2 原生,通过 sd_ops 加载 float8_e4m3fn~减半,速度几乎不变
fp8-scaled-mmLTX-2 原生,需要 tensorrt_llm更快
int8-quantooptimum-quanto,构建后量化~减半,速度 ±
fp8-quanto同上,fp8 版本可能利用 Blackwell 的原生 FP8

fp8-scaled-mm 因环境未安装 tensorrt_llm 而跳过。实现了其余 3 种。


首先尝试 int8-quanto 并踩坑

实现很直接:

from ltx_trainer.quantization import quantize_model

transformer_1 = self.pipeline.stage_1._build_transformer()
transformer_1 = quantize_model(transformer_1, "int8-quanto", device=self.device)
self.transformer_stage_1 = _freeze(transformer_1)

服务器启动通过。空闲显存也符合预期:

[load] stage_1 transformer (no distilled LoRA)
[quantize] stage_1 -> int8-quanto
[quantize] stage_1 done in 0.71s
[cuda] after stage_1 transformer: allocated=31.28GiB ...
[load] stage_2 transformer (with distilled LoRA)
[quantize] stage_2 -> int8-quanto
[quantize] stage_2 done in 0.52s
[cuda] after stage_2 transformer: allocated=49.40GiB ...
[server] A2V listening on http://127.0.0.1:8892

常驻 51.7 GiB(相比 bf16 估计的 86 GiB 减少了 40%)。看起来不错。

然而第一次 /generate 请求时:

[timing] prompt_encode=0.75s
[timing] audio_encode=0.39s
  0%|          | 0/30 [00:00<?, ?it/s]
[http] POST /generate 400

在 step 0/30 处抛出异常。提取错误信息:

{"error": "linear(): argument 'weight' (position 2) must be Tensor, not NoneType"}

说明某处调用了 torch.nn.functional.linear(input, weight=None, bias=None)。也就是说,quanto 的 freeze() 之后,某个 Linear 的 self.weight 被当作 None 引用了

为什么 weight=None

两个粗略的假设:

  1. LTX-2 的 Linear 依赖于 __class__ 重新赋值。与 Fp8CastLinear 的做法相同,为了通过 class 级别重写 forward,需要保持实例的状态。quanto 通过 quantize()freeze()nn.Linear 替换为自定义的 QLinear 包装器,在这个过程中可能破坏了 weight 属性的引用。
  2. EXCLUDE_PATTERNS 在 blockwise 路径中未生效。LTX-trainer 的 _quantize_blockwise 逐个取出 transformer_blocks 并调用 quantize(block, exclude=EXCLUDE_PATTERNS)。但 EXCLUDE_PATTERNS 的内容是 patchify_proj*adaln*time_proj 等基于模型全局路径的 glob,从 block 内部看到的相对路径不匹配。本应排除的层可能被量化了。

无论哪种情况,要彻底修复都需要阅读 quanto 的包装器实现 + LTX-2 transformer 所有层的 forward 路径,成本不划算。决定不深究,改用 LTX-2 原生的 fp8_cast


改用 fp8_cast

实现修改仅 3 行:

# 在构建 pipeline 时传入量化策略即可
pipeline_quantization = None
if transformer_quantization == "fp8-cast":
    from ltx_core.quantization import QuantizationPolicy
    pipeline_quantization = QuantizationPolicy.fp8_cast()

self.pipeline = A2VidPipelineTwoStage(
    ...,
    quantization=pipeline_quantization,
    ...
)

fp8_cast加载权重阶段就向下转换为 fp8sd_ops 是 state_dict 读取时的钩子,因此 43 GB 的 safetensors 在流式加载过程中就被转换为 fp8。与 quanto 先将 bf16 完整加载到内存再量化不同,峰值显存不会膨胀,这一点很令人满意。

启动后:

[load] A2VidPipelineTwoStage builders (pipeline_quantization=QuantizationPolicy(sd_ops=...fp8_cast...))
...
[cuda] after stage_1 transformer: allocated=31.30GiB reserved=35.18GiB
[cuda] after stage_2 transformer: allocated=49.43GiB reserved=53.64GiB
[server] A2V listening on http://127.0.0.1:8892

常驻 allocated 51.7 GiB 与 int8-quanto 相当,但 reserved 仅为 53.6 GiB,远低于 int8-quanto 的 70.9 GiB。reserved 小意味着激活内存的余量更大。

最关键的是 /generate

{
  "elapsed_seconds": 39.367,
  "peak_vram_gib": 57.918,
  "width": 768, "height": 512, "num_frames": 97
}

正常运行。至此走上了正轨。


基准测试

固定条件下对 persistent + fp8-cast 进行 3 种分辨率 × 3 次测量:

  • 图片: 1024×512 人像
  • 音频: 使用 Irodori-TTS 生成的 9.08 秒日语样本
  • 提示词: "A young woman speaks calmly to the camera in a softly lit room."
  • num_frames: 97 (= 4.04s @ 24fps)
  • seed: 固定为 42
分辨率平均耗时 (s)峰值显存 (GiB)
768×512 / 97f39.8457.92
1024×768 / 97f66.7159.06
1280×768 / 97f84.0258.30

值得注意的点:

  • 3 次运行的方差几乎为零(seed 固定,输出 mp4 字节一致)
  • 峰值显存几乎不随分辨率变化(57.9 ~ 59.1 GiB)。常驻权重占主导,激活内存仅约 7 GiB
  • 1280×768 在 persistent 模式下稳定运行。bf16 persistent(峰值约 91 GiB)下实际上无法运行的分辨率,现在可以了

cold-start 模式下同样胜出

生产环境使用 cold-start 运行(A2V 每分钟调用 1-2 次,必须与 TTS 共存)。fp8_cast 策略在构建 pipeline 时通过 sd_ops 应用,因此对 cold-start 的每次请求构建也同样有效。

cold-start + fp8-cast 单次(768×512 / 97f)的实际测量:

{
  "elapsed_seconds": 88.775,
  "peak_vram_gib": 23.901
}
bf16 cold-startfp8-cast cold-start
单次请求时间~60-90s88.8s(磁盘 I/O 主导,水平相当)
峰值显存~40 GiB23.9 GiB(减少约 40%)
空闲0 GiB0 GiB
共存(TTS+Ditto+Qwen3+MuseTalk)可行绰绰有余(峰值约 30 GiB)

速度受磁盘 I/O 主导,因此 fp8 化后没有变化,但峰值显存节省了 16 GiB 非常有效。即使 Qwen3-TTS-vLLM(7 GiB)和 MuseTalk warmup 与 A2V 生成同时运行,也不会 OOM。


如何选择使用模式

使用场景推荐模式依据
以对话为主,A2V 偶尔用于演出cold-start + fp8-cast空闲 0、峰值 24 GiB、与 TTS/Ditto 轻松共存
频繁调用 A2V(批量生成、自动演出)persistent + fp8-cast付出常驻 52 GiB 的代价,换取 40s/req
1024+ 分辨率且注重质量persistent + fp8-cast1280×768 稳定运行(bf16 persistent 无法实现的分辨率)
单 GPU 托管所有服务cold-start + fp8-cast持续常驻会占用 52 GiB,需根据与其他服务的预算分配决定

生产环境的判断是:“以对话为主,因此使用 cold-start + fp8-cast;如果付费用户增加,需要频繁调用 A2V,则切换到 persistent fp8-cast”。是否值得付出空闲 52 GiB 的代价,取决于 ROI。


总结

  • LTX-2 22B 的 bf16 空闲 86 GiB 几乎独占单 GPU。量化几乎是必须的
  • optimum-quanto 与 LTX-2 transformer 不兼容。在 F.linear(weight=None) 处崩溃。推测原因为 __class__ 重新赋值 / EXCLUDE_PATTERNS 的生效方式之一(未深究)
  • LTX-2 原生的 QuantizationPolicy.fp8_cast() 是正确的选择。加载时转换为 fp8,前向时 upcast 为 bf16。实现仅 3 行
  • cold-start + fp8-cast 使峰值从 40 降至 24 GiB,persistent + fp8-cast 使 1280×768 分辨率变得可用
  • LTX-2 还提供了 fp8_scaled_mm(需要 tensorrt_llm),如果有意使用 TRT,也值得尝试

附录:启动命令与复现条件

生产环境的 cold-start + fp8-cast 启动:

PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True nohup uv run python scripts/persistent_a2v_server.py \
  --port 8892 \
  --checkpoint-path models/LTX-2.3/ltx-2.3-22b-dev.safetensors \
  --distilled-lora-path models/loras/ltx-2.3-22b-distilled-lora-384-1.1.safetensors \
  --spatial-upsampler-path models/LTX-2.3/ltx-2.3-spatial-upscaler-x2-1.1.safetensors \
  --gemma-root models/gemma-3-12b-it-qat-q4_0-unquantized \
  --output-dir outputs/a2v_server \
  --transformer-quantization fp8-cast \
  --cold-start \
  > /tmp/ltx_a2v_server.log 2>&1 &

persistent_a2v_server.py 本身是 LTX-2 仓库官方脚本的 A2V 扩展版本。本文的 --transformer-quantization fp8-cast 标志是通过实现补丁添加的。

实现补丁(要点):

# scripts/persistent_a2v_server.py
pipeline_quantization = None
if transformer_quantization in ("fp8-cast", "fp8-scaled-mm"):
    from ltx_core.quantization import QuantizationPolicy  # 延迟导入:避免循环引用
    pipeline_quantization = (
        QuantizationPolicy.fp8_cast()
        if transformer_quantization == "fp8-cast"
        else QuantizationPolicy.fp8_scaled_mm()
    )

self.pipeline = A2VidPipelineTwoStage(
    ...,
    quantization=pipeline_quantization,
    ...,
)

如果在顶层编写 from ltx_core.quantization import QuantizationPolicy,会与 ltx_core.loader 产生循环引用导致崩溃,因此必须延迟导入。

Kotonia 将语音 AI、AI 聊天、图像生成和团队协作整合到一个 AI 工作区中。

试用 Kotonia