LTX-2 22B を fp8_cast で peak VRAM 40% 削減した話 — optimum-quanto は罠だった
この記事は zenn.dev
はじめに
LTX-2.3 は Lightricks が出している音声付き動画生成モデルで、A2V(Audio-to-Video)モードを使うと 画像 1 枚 + 音声 + プロンプト からリップシンク・表情・頭部や髪の動きを一体で生成できる。MuseTalk のような「リップシンクだけ」のモデルと違い、シーン全体を動かせるので「演出経路」として強力。
ただし base checkpoint が 22B パラメータ・43 GB あり、bf16 で常駐させると transformer × 2 stage で idle ~86 GiB。RTX PRO 6000 Blackwell の 96 GiB をほぼ食い切ってしまい、同居している TTS / Ditto-TalkingHead / Qwen3-TTS-vLLM が押し出される。
そこで量子化を試した結果、LTX-2 native の fp8_cast で peak VRAM を 40 GiB → 24 GiB に圧縮できた(A2V cold-start, 768×512 / 97f)。一方で optimum-quanto の int8/fp8 は LTX-2 transformer と互換性問題で動かなかった。本記事はそのデバッグと判断の記録。
環境
- GPU: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition (96 GiB)
- PyTorch: 2.9.1 + CUDA 12.8
- モデル: LTX-2.3 22B-dev (base) + 22B-distilled-lora-384 (stage_2) + Gemma-3-12B text encoder (bnb 4bit)
- 運用形態: A2V を
scripts/persistent_a2v_server.py --cold-startで運用。リクエストごとにbuild → run → free、idle 0 GiB。
cold-start を採用してる理由は別記事に書いた(or 別途)。要は 会話メインで A2V は時々呼ぶ、TTS / Ditto と同居必須 のため。
候補は 4 つ
LTX-2 のコードベースを見ると、量子化経路は実は 2 系統ある:
1. LTX-2 native:QuantizationPolicy
packages/ltx-core/src/ltx_core/quantization/policy.py:
@dataclass(frozen=True)
class QuantizationPolicy:
sd_ops: SDOps | None = None # state dict ロード時の重み変換
module_ops: tuple[ModuleOps, ...] = () # ロード後のモジュール書換
@classmethod
def fp8_cast(cls) -> "QuantizationPolicy":
"""weight を float8_e4m3fn でロード、forward 時 bf16 にアップキャスト"""
return cls(
sd_ops=TRANSFORMER_LINEAR_DOWNCAST_MAP,
module_ops=(UPCAST_DURING_INFERENCE,),
)
@classmethod
def fp8_scaled_mm(cls) -> "QuantizationPolicy":
"""FP8 scaled MM(tensorrt_llm 必須)"""
fp8_cast の実体は Fp8CastLinear:
class Fp8CastLinear(torch.nn.Linear):
def forward(self, input):
w_up = _upcast_and_round(self.weight, input.dtype, ...)
b_up = _upcast_and_round(self.bias, input.dtype, ...) if self.bias is not None else None
return torch.nn.functional.linear(input, w_up, b_up)
__class__ 再代入でインスタンスを書き換える型変換パターン。重みは fp8 で保持し、forward の都度 bf16 にキャストして matmul する。fp8 → bf16 のキャストコストは付くが、Blackwell ならほぼノイズレベル。
2. optimum-quanto
LTX-2 trainer package (packages/ltx-trainer) には optimum-quanto を使った汎用量子化があり、int8-quanto / int4-quanto / fp8-quanto を選べる:
def quantize_model(model, precision, ...):
if hasattr(model, "transformer_blocks"):
_quantize_blockwise(model, ...) # 1 ブロックずつ GPU に上げて quantize → freeze → CPU
else:
quantize(model, weights=..., exclude=EXCLUDE_PATTERNS)
freeze(model)
return model
これを _build_transformer() の直後に挟めば良さそうに見える。
4 候補のマトリクス
| mode | 経路 | 期待 |
|---|---|---|
fp8-cast | LTX-2 native、sd_ops で float8_e4m3fn ロード | ~半減・速度ほぼ同等 |
fp8-scaled-mm | LTX-2 native、tensorrt_llm 必須 | より高速 |
int8-quanto | optimum-quanto、post-build | ~半減・速度 ± |
fp8-quanto | 同上、fp8 版 | Blackwell の native FP8 を踏める可能性 |
fp8-scaled-mm は tensorrt_llm が入っていない環境なので skip。残り 3 つを実装した。
まず int8-quanto で踏み抜く
実装は素直:
from ltx_trainer.quantization import quantize_model
transformer_1 = self.pipeline.stage_1._build_transformer()
transformer_1 = quantize_model(transformer_1, "int8-quanto", device=self.device)
self.transformer_stage_1 = _freeze(transformer_1)
サーバー起動は通る。idle VRAM も期待通り:
[load] stage_1 transformer (no distilled LoRA)
[quantize] stage_1 -> int8-quanto
[quantize] stage_1 done in 0.71s
[cuda] after stage_1 transformer: allocated=31.28GiB ...
[load] stage_2 transformer (with distilled LoRA)
[quantize] stage_2 -> int8-quanto
[quantize] stage_2 done in 0.52s
[cuda] after stage_2 transformer: allocated=49.40GiB ...
[server] A2V listening on http://127.0.0.1:8892
resident は 51.7 GiB(bf16 推定の 86 GiB から 40% 削減)。良さそう。
ところが最初の /generate リクエストで:
[timing] prompt_encode=0.75s
[timing] audio_encode=0.39s
0%| | 0/30 [00:00<?, ?it/s]
[http] POST /generate 400
step 0/30 で例外。エラー本文を取り出すと:
{"error": "linear(): argument 'weight' (position 2) must be Tensor, not NoneType"}
torch.nn.functional.linear(input, weight=None, bias=None) を呼んでる箇所がある。つまり quanto の freeze() 後、どこかの Linear の self.weight が None として参照されている。
なぜ weight=None になるのか
雑な仮説 2 つ:
- LTX-2 の Linear は
__class__再代入を前提。Fp8CastLinearがやってるのと同じパターンで、forward を class 単位で書き換えるためにインスタンスの state を保つ前提になっている。quanto はquantize()→freeze()でnn.Linearを独自のQLinearラッパに 置換 するので、その過程でweight属性の参照が壊れた可能性。 EXCLUDE_PATTERNSが blockwise 経路では効いてない。LTX-trainer の_quantize_blockwiseはtransformer_blocksを 1 ブロックずつ取り出してquantize(block, exclude=EXCLUDE_PATTERNS)を呼ぶ。だがEXCLUDE_PATTERNSの中身はpatchify_proj、*adaln*、time_projといったモデル全体パス前提の glob で、block 内部から見た相対パスでは一致しない。本来 excluded にしたい層が quantize されている可能性。
どちらにせよ、これを真面目に直すには quanto のラッパ実装 + LTX-2 transformer 全層の forward 経路を読む必要があり、コストが見合わない。深追いはやめて LTX-2 native の fp8_cast に切る判断をした。
fp8_cast で乗り換える
実装変更は 3 行:
# 量子化ポリシーをパイプライン構築時に渡すだけ
pipeline_quantization = None
if transformer_quantization == "fp8-cast":
from ltx_core.quantization import QuantizationPolicy
pipeline_quantization = QuantizationPolicy.fp8_cast()
self.pipeline = A2VidPipelineTwoStage(
...,
quantization=pipeline_quantization,
...
)
fp8_cast は 重みロードの段階で fp8 に downcast する。sd_ops は state_dict 読み込み時のフックなので、43 GB の safetensors を fp8 化しながらストリーミングロード。bf16 を一度フルメモリに展開してから量子化する quanto と違い、ピーク VRAM が膨らまないのが嬉しい。
起動すると:
[load] A2VidPipelineTwoStage builders (pipeline_quantization=QuantizationPolicy(sd_ops=...fp8_cast...))
...
[cuda] after stage_1 transformer: allocated=31.30GiB reserved=35.18GiB
[cuda] after stage_2 transformer: allocated=49.43GiB reserved=53.64GiB
[server] A2V listening on http://127.0.0.1:8892
resident allocated 51.7 GiB は int8-quanto と同水準だが、reserved が 53.6 GiB と圧倒的に低い(int8-quanto は 70.9 GiB)。reserved が小さいということは活性化用のヘッドルームが広い。
そして肝心の /generate:
{
"elapsed_seconds": 39.367,
"peak_vram_gib": 57.918,
"width": 768, "height": 512, "num_frames": 97
}
動く。これで本筋に乗った。
ベンチマーク
固定条件で persistent + fp8-cast を 3 解像度 × 3 計測:
- 画像: 1024×512 portrait
- 音声: Irodori-TTS で生成した 9.08 秒の日本語サンプル
- プロンプト: "A young woman speaks calmly to the camera in a softly lit room."
- num_frames: 97 (= 4.04s @ 24fps)
- seed: 42 固定
| 解像度 | 平均 elapsed (s) | peak VRAM (GiB) |
|---|---|---|
| 768×512 / 97f | 39.84 | 57.92 |
| 1024×768 / 97f | 66.71 | 59.06 |
| 1280×768 / 97f | 84.02 | 58.30 |
特筆点:
- 3 run の分散ほぼゼロ(seed 固定で出力 mp4 が byte-identical)
- peak VRAM が解像度にほぼ無依存(57.9 ~ 59.1 GiB)。resident 重みが支配的で、活性化メモリは ~7 GiB 程度
- 1280×768 が persistent で安定動作。bf16 persistent (peak ~91 GiB) では実質乗らなかった解像度がここで開く
cold-start でも勝つ
production は cold-start で動かしている(A2V は数分に 1-2 回、TTS と同居必須)。fp8_cast policy は パイプライン構築時の sd_ops で適用されるので、cold-start の per-request build にもそのまま効く。
cold-start + fp8-cast 1 回(768×512 / 97f)の実測:
{
"elapsed_seconds": 88.775,
"peak_vram_gib": 23.901
}
| bf16 cold-start | fp8-cast cold-start | |
|---|---|---|
| 1 リクエスト時間 | ~60-90s | 88.8s(disk I/O 支配的、同水準) |
| peak VRAM | ~40 GiB | 23.9 GiB(~40% 削減) |
| idle | 0 GiB | 0 GiB |
| 同居(TTS+Ditto+Qwen3+MuseTalk) | 可 | 余裕(peak 30 GiB 程度) |
速度はディスク I/O が支配しているので fp8 化しても変わらないが、peak VRAM が 16 GiB 浮くのが効く。Qwen3-TTS-vLLM(7 GiB)や MuseTalk warmup と A2V 生成が同時走行しても OOM しない。
どう使い分けるか
| ユースケース | 推奨モード | 根拠 |
|---|---|---|
| 会話メイン、A2V は時々演出 | cold-start + fp8-cast | idle 0、peak 24 GiB、TTS/Ditto と余裕同居 |
| A2V を連発(バッチ生成、自動演出) | persistent + fp8-cast | resident 52 GiB のコストを払う代わりに 40s/req |
| 1024+ 解像度で品質重視 | persistent + fp8-cast | 1280×768 が安定動作(bf16 persistent では不可能だった解像度) |
| 1 機 GPU で全部ホスト | cold-start + fp8-cast | 持続常駐は 52 GiB 食うので、他サービスとの予算配分次第 |
production の判断は「会話メインなので cold-start + fp8-cast、課金ユーザーが増えて A2V 連発が必要になったら persistent fp8-cast に切替」。idle 52 GiB を払う判断は ROI 次第。
まとめ
- LTX-2 22B の bf16 idle 86 GiB は単機 GPU でほぼ寡占。量子化は必須に近い
optimum-quantoは LTX-2 transformer と非互換。F.linear(weight=None)で死ぬ。__class__再代入 /EXCLUDE_PATTERNSの効き方どちらかが原因と推定(深追いせず)- LTX-2 native の
QuantizationPolicy.fp8_cast()が正解。ロード時に fp8 化、forward 時 bf16 upcast。実装は 3 行 - cold-start + fp8-cast で peak 40 → 24 GiB、persistent + fp8-cast で 1280×768 が新たに使えるようになる
- 同じく LTX-2 で
fp8_scaled_mm(tensorrt_llm 必須) があるので、TRT を入れる気があるならそちらも試す価値あり
おまけ:起動コマンドと再現条件
production の cold-start + fp8-cast 起動:
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True nohup uv run python scripts/persistent_a2v_server.py \
--port 8892 \
--checkpoint-path models/LTX-2.3/ltx-2.3-22b-dev.safetensors \
--distilled-lora-path models/loras/ltx-2.3-22b-distilled-lora-384-1.1.safetensors \
--spatial-upsampler-path models/LTX-2.3/ltx-2.3-spatial-upscaler-x2-1.1.safetensors \
--gemma-root models/gemma-3-12b-it-qat-q4_0-unquantized \
--output-dir outputs/a2v_server \
--transformer-quantization fp8-cast \
--cold-start \
> /tmp/ltx_a2v_server.log 2>&1 &
persistent_a2v_server.py 自体は LTX-2 リポの公式スクリプトを A2V 向けに拡張したもの。本記事のフラグ --transformer-quantization fp8-cast は実装パッチを当てて追加。
実装パッチ(要点だけ):
# scripts/persistent_a2v_server.py
pipeline_quantization = None
if transformer_quantization in ("fp8-cast", "fp8-scaled-mm"):
from ltx_core.quantization import QuantizationPolicy # late import: 循環参照回避
pipeline_quantization = (
QuantizationPolicy.fp8_cast()
if transformer_quantization == "fp8-cast"
else QuantizationPolicy.fp8_scaled_mm()
)
self.pipeline = A2VidPipelineTwoStage(
...,
quantization=pipeline_quantization,
...,
)
from ltx_core.quantization import QuantizationPolicy を top-level で書くと ltx_core.loader との循環参照で死ぬので、late import 必須。