LTX-2 22B を fp8_cast で peak VRAM 40% 削減した話 — optimum-quanto は罠だった

#ltx2#quantization#fp8#diffusion#gpu

この記事は zenn.dev

はじめに

LTX-2.3 は Lightricks が出している音声付き動画生成モデルで、A2V(Audio-to-Video)モードを使うと 画像 1 枚 + 音声 + プロンプト からリップシンク・表情・頭部や髪の動きを一体で生成できる。MuseTalk のような「リップシンクだけ」のモデルと違い、シーン全体を動かせるので「演出経路」として強力。

ただし base checkpoint が 22B パラメータ・43 GB あり、bf16 で常駐させると transformer × 2 stageidle ~86 GiB。RTX PRO 6000 Blackwell の 96 GiB をほぼ食い切ってしまい、同居している TTS / Ditto-TalkingHead / Qwen3-TTS-vLLM が押し出される。

そこで量子化を試した結果、LTX-2 native の fp8_cast で peak VRAM を 40 GiB → 24 GiB に圧縮できた(A2V cold-start, 768×512 / 97f)。一方で optimum-quanto の int8/fp8 は LTX-2 transformer と互換性問題で動かなかった。本記事はそのデバッグと判断の記録。


環境

  • GPU: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition (96 GiB)
  • PyTorch: 2.9.1 + CUDA 12.8
  • モデル: LTX-2.3 22B-dev (base) + 22B-distilled-lora-384 (stage_2) + Gemma-3-12B text encoder (bnb 4bit)
  • 運用形態: A2V を scripts/persistent_a2v_server.py --cold-start で運用。リクエストごとに build → run → free、idle 0 GiB。

cold-start を採用してる理由は別記事に書いた(or 別途)。要は 会話メインで A2V は時々呼ぶ、TTS / Ditto と同居必須 のため。


候補は 4 つ

LTX-2 のコードベースを見ると、量子化経路は実は 2 系統ある:

1. LTX-2 native:QuantizationPolicy

packages/ltx-core/src/ltx_core/quantization/policy.py:

@dataclass(frozen=True)
class QuantizationPolicy:
    sd_ops: SDOps | None = None              # state dict ロード時の重み変換
    module_ops: tuple[ModuleOps, ...] = ()   # ロード後のモジュール書換

    @classmethod
    def fp8_cast(cls) -> "QuantizationPolicy":
        """weight を float8_e4m3fn でロード、forward 時 bf16 にアップキャスト"""
        return cls(
            sd_ops=TRANSFORMER_LINEAR_DOWNCAST_MAP,
            module_ops=(UPCAST_DURING_INFERENCE,),
        )

    @classmethod
    def fp8_scaled_mm(cls) -> "QuantizationPolicy":
        """FP8 scaled MM(tensorrt_llm 必須)"""

fp8_cast の実体は Fp8CastLinear

class Fp8CastLinear(torch.nn.Linear):
    def forward(self, input):
        w_up = _upcast_and_round(self.weight, input.dtype, ...)
        b_up = _upcast_and_round(self.bias, input.dtype, ...) if self.bias is not None else None
        return torch.nn.functional.linear(input, w_up, b_up)

__class__ 再代入でインスタンスを書き換える型変換パターン。重みは fp8 で保持し、forward の都度 bf16 にキャストして matmul する。fp8 → bf16 のキャストコストは付くが、Blackwell ならほぼノイズレベル。

2. optimum-quanto

LTX-2 trainer package (packages/ltx-trainer) には optimum-quanto を使った汎用量子化があり、int8-quanto / int4-quanto / fp8-quanto を選べる:

def quantize_model(model, precision, ...):
    if hasattr(model, "transformer_blocks"):
        _quantize_blockwise(model, ...)   # 1 ブロックずつ GPU に上げて quantize → freeze → CPU
    else:
        quantize(model, weights=..., exclude=EXCLUDE_PATTERNS)
        freeze(model)
    return model

これを _build_transformer() の直後に挟めば良さそうに見える。

4 候補のマトリクス

mode経路期待
fp8-castLTX-2 native、sd_ops で float8_e4m3fn ロード~半減・速度ほぼ同等
fp8-scaled-mmLTX-2 native、tensorrt_llm 必須より高速
int8-quantooptimum-quanto、post-build~半減・速度 ±
fp8-quanto同上、fp8 版Blackwell の native FP8 を踏める可能性

fp8-scaled-mm は tensorrt_llm が入っていない環境なので skip。残り 3 つを実装した。


まず int8-quanto で踏み抜く

実装は素直:

from ltx_trainer.quantization import quantize_model

transformer_1 = self.pipeline.stage_1._build_transformer()
transformer_1 = quantize_model(transformer_1, "int8-quanto", device=self.device)
self.transformer_stage_1 = _freeze(transformer_1)

サーバー起動は通る。idle VRAM も期待通り:

[load] stage_1 transformer (no distilled LoRA)
[quantize] stage_1 -> int8-quanto
[quantize] stage_1 done in 0.71s
[cuda] after stage_1 transformer: allocated=31.28GiB ...
[load] stage_2 transformer (with distilled LoRA)
[quantize] stage_2 -> int8-quanto
[quantize] stage_2 done in 0.52s
[cuda] after stage_2 transformer: allocated=49.40GiB ...
[server] A2V listening on http://127.0.0.1:8892

resident は 51.7 GiB(bf16 推定の 86 GiB から 40% 削減)。良さそう。

ところが最初の /generate リクエストで:

[timing] prompt_encode=0.75s
[timing] audio_encode=0.39s
  0%|          | 0/30 [00:00<?, ?it/s]
[http] POST /generate 400

step 0/30 で例外。エラー本文を取り出すと:

{"error": "linear(): argument 'weight' (position 2) must be Tensor, not NoneType"}

torch.nn.functional.linear(input, weight=None, bias=None) を呼んでる箇所がある。つまり quanto の freeze() 後、どこかの Linear の self.weight が None として参照されている

なぜ weight=None になるのか

雑な仮説 2 つ:

  1. LTX-2 の Linear は __class__ 再代入を前提Fp8CastLinear がやってるのと同じパターンで、forward を class 単位で書き換えるためにインスタンスの state を保つ前提になっている。quanto は quantize()freeze()nn.Linear を独自の QLinear ラッパに 置換 するので、その過程で weight 属性の参照が壊れた可能性。
  2. EXCLUDE_PATTERNS が blockwise 経路では効いてない。LTX-trainer の _quantize_blockwisetransformer_blocks を 1 ブロックずつ取り出して quantize(block, exclude=EXCLUDE_PATTERNS) を呼ぶ。だが EXCLUDE_PATTERNS の中身は patchify_proj*adaln*time_proj といったモデル全体パス前提の glob で、block 内部から見た相対パスでは一致しない。本来 excluded にしたい層が quantize されている可能性。

どちらにせよ、これを真面目に直すには quanto のラッパ実装 + LTX-2 transformer 全層の forward 経路を読む必要があり、コストが見合わない。深追いはやめて LTX-2 native の fp8_cast に切る判断をした。


fp8_cast で乗り換える

実装変更は 3 行:

# 量子化ポリシーをパイプライン構築時に渡すだけ
pipeline_quantization = None
if transformer_quantization == "fp8-cast":
    from ltx_core.quantization import QuantizationPolicy
    pipeline_quantization = QuantizationPolicy.fp8_cast()

self.pipeline = A2VidPipelineTwoStage(
    ...,
    quantization=pipeline_quantization,
    ...
)

fp8_cast重みロードの段階で fp8 に downcast する。sd_ops は state_dict 読み込み時のフックなので、43 GB の safetensors を fp8 化しながらストリーミングロード。bf16 を一度フルメモリに展開してから量子化する quanto と違い、ピーク VRAM が膨らまないのが嬉しい。

起動すると:

[load] A2VidPipelineTwoStage builders (pipeline_quantization=QuantizationPolicy(sd_ops=...fp8_cast...))
...
[cuda] after stage_1 transformer: allocated=31.30GiB reserved=35.18GiB
[cuda] after stage_2 transformer: allocated=49.43GiB reserved=53.64GiB
[server] A2V listening on http://127.0.0.1:8892

resident allocated 51.7 GiB は int8-quanto と同水準だが、reserved が 53.6 GiB と圧倒的に低い(int8-quanto は 70.9 GiB)。reserved が小さいということは活性化用のヘッドルームが広い。

そして肝心の /generate

{
  "elapsed_seconds": 39.367,
  "peak_vram_gib": 57.918,
  "width": 768, "height": 512, "num_frames": 97
}

動く。これで本筋に乗った。


ベンチマーク

固定条件で persistent + fp8-cast を 3 解像度 × 3 計測:

  • 画像: 1024×512 portrait
  • 音声: Irodori-TTS で生成した 9.08 秒の日本語サンプル
  • プロンプト: "A young woman speaks calmly to the camera in a softly lit room."
  • num_frames: 97 (= 4.04s @ 24fps)
  • seed: 42 固定
解像度平均 elapsed (s)peak VRAM (GiB)
768×512 / 97f39.8457.92
1024×768 / 97f66.7159.06
1280×768 / 97f84.0258.30

特筆点:

  • 3 run の分散ほぼゼロ(seed 固定で出力 mp4 が byte-identical)
  • peak VRAM が解像度にほぼ無依存(57.9 ~ 59.1 GiB)。resident 重みが支配的で、活性化メモリは ~7 GiB 程度
  • 1280×768 が persistent で安定動作。bf16 persistent (peak ~91 GiB) では実質乗らなかった解像度がここで開く

cold-start でも勝つ

production は cold-start で動かしている(A2V は数分に 1-2 回、TTS と同居必須)。fp8_cast policy は パイプライン構築時の sd_ops で適用されるので、cold-start の per-request build にもそのまま効く。

cold-start + fp8-cast 1 回(768×512 / 97f)の実測:

{
  "elapsed_seconds": 88.775,
  "peak_vram_gib": 23.901
}
bf16 cold-startfp8-cast cold-start
1 リクエスト時間~60-90s88.8s(disk I/O 支配的、同水準)
peak VRAM~40 GiB23.9 GiB(~40% 削減)
idle0 GiB0 GiB
同居(TTS+Ditto+Qwen3+MuseTalk)余裕(peak 30 GiB 程度)

速度はディスク I/O が支配しているので fp8 化しても変わらないが、peak VRAM が 16 GiB 浮くのが効く。Qwen3-TTS-vLLM(7 GiB)や MuseTalk warmup と A2V 生成が同時走行しても OOM しない。


どう使い分けるか

ユースケース推奨モード根拠
会話メイン、A2V は時々演出cold-start + fp8-castidle 0、peak 24 GiB、TTS/Ditto と余裕同居
A2V を連発(バッチ生成、自動演出)persistent + fp8-castresident 52 GiB のコストを払う代わりに 40s/req
1024+ 解像度で品質重視persistent + fp8-cast1280×768 が安定動作(bf16 persistent では不可能だった解像度)
1 機 GPU で全部ホストcold-start + fp8-cast持続常駐は 52 GiB 食うので、他サービスとの予算配分次第

production の判断は「会話メインなので cold-start + fp8-cast、課金ユーザーが増えて A2V 連発が必要になったら persistent fp8-cast に切替」。idle 52 GiB を払う判断は ROI 次第。


まとめ

  • LTX-2 22B の bf16 idle 86 GiB は単機 GPU でほぼ寡占。量子化は必須に近い
  • optimum-quanto は LTX-2 transformer と非互換F.linear(weight=None) で死ぬ。__class__ 再代入 / EXCLUDE_PATTERNS の効き方どちらかが原因と推定(深追いせず)
  • LTX-2 native の QuantizationPolicy.fp8_cast() が正解。ロード時に fp8 化、forward 時 bf16 upcast。実装は 3 行
  • cold-start + fp8-cast で peak 40 → 24 GiB、persistent + fp8-cast で 1280×768 が新たに使えるようになる
  • 同じく LTX-2 で fp8_scaled_mm (tensorrt_llm 必須) があるので、TRT を入れる気があるならそちらも試す価値あり

おまけ:起動コマンドと再現条件

production の cold-start + fp8-cast 起動:

PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True nohup uv run python scripts/persistent_a2v_server.py \
  --port 8892 \
  --checkpoint-path models/LTX-2.3/ltx-2.3-22b-dev.safetensors \
  --distilled-lora-path models/loras/ltx-2.3-22b-distilled-lora-384-1.1.safetensors \
  --spatial-upsampler-path models/LTX-2.3/ltx-2.3-spatial-upscaler-x2-1.1.safetensors \
  --gemma-root models/gemma-3-12b-it-qat-q4_0-unquantized \
  --output-dir outputs/a2v_server \
  --transformer-quantization fp8-cast \
  --cold-start \
  > /tmp/ltx_a2v_server.log 2>&1 &

persistent_a2v_server.py 自体は LTX-2 リポの公式スクリプトを A2V 向けに拡張したもの。本記事のフラグ --transformer-quantization fp8-cast は実装パッチを当てて追加。

実装パッチ(要点だけ):

# scripts/persistent_a2v_server.py
pipeline_quantization = None
if transformer_quantization in ("fp8-cast", "fp8-scaled-mm"):
    from ltx_core.quantization import QuantizationPolicy  # late import: 循環参照回避
    pipeline_quantization = (
        QuantizationPolicy.fp8_cast()
        if transformer_quantization == "fp8-cast"
        else QuantizationPolicy.fp8_scaled_mm()
    )

self.pipeline = A2VidPipelineTwoStage(
    ...,
    quantization=pipeline_quantization,
    ...,
)

from ltx_core.quantization import QuantizationPolicy を top-level で書くと ltx_core.loader との循環参照で死ぬので、late import 必須。