SIGN IN SIGN UP

feat: add per-model FP8 layerwise casting for VRAM reduction (#8945)

* feat: add per-model FP8 layerwise casting for VRAM reduction

Add fp8_storage option to model default settings that enables
diffusers' enable_layerwise_casting() to store weights in FP8
(float8_e4m3fn) while casting to fp16/bf16 during inference.
This reduces VRAM usage by ~50% per model with minimal quality loss.

Supported: SD1/SD2/SDXL/SD3, Flux, Flux2, CogView4, Z-Image,
VAE (diffusers-based), ControlNet, T2IAdapter.
Not applicable: Text Encoders, LoRA, GGUF, BnB, custom classes

* feat: add FP8 storage option to Model Manager UI

Add per-model FP8 storage toggle in Model Manager default settings for
both main models and control adapter models. When enabled, model weights
are stored in FP8 format in VRAM (~50% savings) and cast layer-by-layer
to compute precision during inference via diffusers' enable_layerwise_casting().

Backend: add fp8_storage field to MainModelDefaultSettings and
ControlAdapterDefaultSettings, apply FP8 layerwise casting in all
relevant model loaders (SD, SDXL, FLUX, CogView4, Z-Image, ControlNet,
T2IAdapter, VAE). Gracefully skips non-ModelMixin models (custom
checkpoint loaders, GGUF, BnB).

Frontend: add FP8 Storage switch to model default settings panels with
InformationalPopover, translation keys, and proper form handling.

* ruff format

* fix: enable FP8 layerwise casting for checkpoint Flux models

FluxCheckpointModel and Flux2CheckpointModel were missing the
_apply_fp8_layerwise_casting call. Additionally, the FP8 casting
only worked for diffusers ModelMixin models. Add manual layerwise
casting via forward hooks for plain nn.Module (custom Flux class).

Also simplify FP8 UI toggle from dual-slider to single switch,
matching the CPU-only toggle pattern per review feedback on #8945.

* fix: exclude Z-Image from FP8 due to diffusers layerwise casting bug

Z-Image's transformer has dtype mismatches with diffusers'
enable_layerwise_casting: skipped modules (t_embedder, cap_embedder)
stay in bf16 while hooked modules cast to fp16, causing crashes in
attention layers. Also hide the FP8 toggle in the UI for Z-Image models.

* fix: detect model dtype for FP8 compute instead of using global dtype

Models like Flux are loaded in bf16 but the global torch dtype is fp16,
causing dtype mismatches during FP8 layerwise casting. Detect the
model's actual parameter dtype and use it as compute_dtype for both
diffusers ModelMixin and plain nn.Module models.

* Remove call for _should_use_fp8 in z-image

* Merge branch 'main' + exclude VAEs from FP8 layerwise casting

Resolve merge conflict in vae.py by keeping upstream's Anima/QwenImage VAE
loader paths and dropping the FP8 call from the AutoencoderKL checkpoint
path.

Exclude VAEs from FP8 layerwise casting in _should_use_fp8 (both standalone
ModelType.VAE and the VAE/VAEDecoder/VAEEncoder submodel types of Main
models). FP8 storage causes noticeable quality degradation on VAE decode.

* fix(fp8): invalidate cache on settings change, exception-safe nn.Module fallback, hide ControlLoRA toggle

- Add ModelCache.drop_model() and call it from update_model_record when
  fp8_storage or cpu_only change. These settings are baked into the loaded
  nn.Module at load time, so toggling them was silently a no-op until the
  cache entry was evicted by other means.
- Replace the pre-hook/post-hook pair in _apply_fp8_to_nn_module with a
  forward wrapper using try/finally. register_forward_hook only fires on
  successful forward, so an exception left params in compute dtype and
  defeated the FP8 storage savings.
- Hide the FP8 toggle in the UI for ControlLoRA and exclude LoRA/ControlLoRA
  in _should_use_fp8. LoRAs are patched into base models rather than run as
  a standalone forward pass, so layerwise-casting hooks would never fire.
- Add tests for drop_model, the exception-safe FP8 wrapper, the
  ControlLoRA/LoRA exclusion, and the _load_settings_changed predicate.

* fix(fp8): honor class swap for LoRA patches, evict stale locked entries, skip precision-sensitive layers

- _wrap_forward_with_fp8_cast now dispatches via type(module).forward at
  call time instead of capturing the bound method. ModelCache.put() swaps
  nn.Linear.__class__ to CustomLinear (sharing __dict__), which would
  otherwise leave our instance forward shadowing CustomLinear.forward and
  silently bypass LoRA/ControlLoRA patch dispatch on FP8 checkpoints.
- drop_model() now marks locked entries is_stale instead of skipping them
  silently; unlock() evicts stale entries once the last lock releases.
  Without this, a setting toggled during an in-flight generation survived
  on the locked entry and the next generation reused the pre-change module.
- _apply_fp8_to_nn_module mirrors diffusers' apply_layerwise_casting: only
  the supported layer classes (Linear/Conv*/Embedding) get cast, and module
  paths matching norm/pos_embed/patch_embed/proj_in/proj_out are skipped.
  FLUX RMSNorm.scale and similar precision-sensitive scalars are no longer
  crushed to FP8.
- drop_model() and the unlock-stale path now update stats.cleared and fire
  on_cache_models_cleared callbacks, matching _make_room_internal so the
  UI stats panel and observers don't miss invalidations.
- Add 14 tests: class-swap dispatch, norm/pos_embed/proj_in_out skip,
  unsupported-type skip, stale-marking, multi-lock release, stats and
  callback firing for both paths, no-op silence.

* fix(fp8): switch nn.Module FP8 wrapper to hooks so CustomLinear dispatch survives apply_custom_layers_to_model

Previous fix was wrong. `apply_custom_layers_to_model` does not do
`module.__class__ = CustomLinear` — `wrap_custom_layer` constructs a NEW
CustomLinear via __new__ and shares the original Linear's __dict__, then
setattr installs the new object on the parent. The new object has
type() == CustomLinear, but our wrapped forward closed over the original
Linear instance, so `type(module).forward(module, ...)` resolved to
Linear.forward on the captured old object and silently bypassed
CustomLinear.forward — breaking LoRA/ControlLoRA patch dispatch for
FP8 checkpoint models. Reproduced on a fresh worktree.

Replace the instance-forward override with register_forward_pre_hook +
register_forward_hook(always_call=True). Hooks are dispatched by
nn.Module._call_impl with the actual called instance, so they fire on
the new CustomLinear and self.forward resolves normally via class lookup
— reaching CustomLinear.forward and its patch-aware branch.
always_call=True keeps the exception-safety guarantee (post-hook fires
even when forward raises).

Replace the simulated __class__-swap test with one that runs real
apply_custom_layers_to_model, attaches a sentinel _patches_and_weights,
and asserts the patch-aware branch in CustomLinear.forward is reached.
Verified the test fails under the old instance-forward implementation
with the reviewer-described symptom and passes under the hook fix.

* Add docs for fp8

---------

Co-authored-by: Jonathan <34005131+JPPhoto@users.noreply.github.com>
A
Alexander Eichhorn committed
6f42ad0ee78f61216bf8289faaab37ad15ae9934
Parent: 0f937ce
Committed by GitHub <noreply@github.com> on 5/12/2026, 2:43:38 AM