feat: add per-model FP8 layerwise casting for VRAM reduction (#8945)
* feat: add per-model FP8 layerwise casting for VRAM reduction

Add fp8_storage option to model default settings that enables
diffusers' enable_layerwise_casting() to store weights in FP8
(float8_e4m3fn) while casting to fp16/bf16 during inference. This
reduces VRAM usage by ~50% per model with minimal quality loss.

Supported: SD1/SD2/SDXL/SD3, Flux, Flux2, CogView4, Z-Image,
VAE (diffusers-based), ControlNet, T2IAdapter.
Not applicable: Text Encoders, LoRA, GGUF, BnB, custom classes.

* feat: add FP8 storage option to Model Manager UI

Add per-model FP8 storage toggle in Model Manager default settings for
both main models and control adapter models. When enabled, model
weights are stored in FP8 format in VRAM (~50% savings) and cast
layer-by-layer to compute precision during inference via diffusers'
enable_layerwise_casting().

Backend: add fp8_storage field to MainModelDefaultSettings and
ControlAdapterDefaultSettings, apply FP8 layerwise casting in all
relevant model loaders (SD, SDXL, FLUX, CogView4, Z-Image, ControlNet,
T2IAdapter, VAE). Gracefully skips non-ModelMixin models (custom
checkpoint loaders, GGUF, BnB).

Frontend: add FP8 Storage switch to model default settings panels with
InformationalPopover, translation keys, and proper form handling.

* ruff format

* fix: enable FP8 layerwise casting for checkpoint Flux models

FluxCheckpointModel and Flux2CheckpointModel were missing the
_apply_fp8_layerwise_casting call. Additionally, the FP8 casting only
worked for diffusers ModelMixin models. Add manual layerwise casting
via forward hooks for plain nn.Module (custom Flux class).

Also simplify FP8 UI toggle from dual-slider to single switch, matching
the CPU-only toggle pattern per review feedback on #8945.

* fix: exclude Z-Image from FP8 due to diffusers layerwise casting bug

Z-Image's transformer has dtype mismatches with diffusers'
enable_layerwise_casting: skipped modules (t_embedder, cap_embedder)
stay in bf16 while hooked modules cast to fp16, causing crashes in
attention layers. Also hide the FP8 toggle in the UI for Z-Image
models.

* fix: detect model dtype for FP8 compute instead of using global dtype

Models like Flux are loaded in bf16 but the global torch dtype is fp16,
causing dtype mismatches during FP8 layerwise casting. Detect the
model's actual parameter dtype and use it as compute_dtype for both
diffusers ModelMixin and plain nn.Module models.

* Remove call for _should_use_fp8 in z-image

* Merge branch 'main' + exclude VAEs from FP8 layerwise casting

Resolve merge conflict in vae.py by keeping upstream's Anima/QwenImage
VAE loader paths and dropping the FP8 call from the AutoencoderKL
checkpoint path.

Exclude VAEs from FP8 layerwise casting in _should_use_fp8 (both
standalone ModelType.VAE and the VAE/VAEDecoder/VAEEncoder submodel
types of Main models). FP8 storage causes noticeable quality
degradation on VAE decode.

* fix(fp8): invalidate cache on settings change, exception-safe nn.Module fallback, hide ControlLoRA toggle

- Add ModelCache.drop_model() and call it from update_model_record when
  fp8_storage or cpu_only change. These settings are baked into the
  loaded nn.Module at load time, so toggling them was silently a no-op
  until the cache entry was evicted by other means.
- Replace the pre-hook/post-hook pair in _apply_fp8_to_nn_module with a
  forward wrapper using try/finally. register_forward_hook only fires
  on successful forward, so an exception left params in compute dtype
  and defeated the FP8 storage savings.
- Hide the FP8 toggle in the UI for ControlLoRA and exclude
  LoRA/ControlLoRA in _should_use_fp8. LoRAs are patched into base
  models rather than run as a standalone forward pass, so
  layerwise-casting hooks would never fire.
- Add tests for drop_model, the exception-safe FP8 wrapper, the
  ControlLoRA/LoRA exclusion, and the _load_settings_changed predicate.

* fix(fp8): honor class swap for LoRA patches, evict stale locked entries, skip precision-sensitive layers

- _wrap_forward_with_fp8_cast now dispatches via type(module).forward
  at call time instead of capturing the bound method. ModelCache.put()
  swaps nn.Linear.__class__ to CustomLinear (sharing __dict__), which
  would otherwise leave our instance forward shadowing
  CustomLinear.forward and silently bypass LoRA/ControlLoRA patch
  dispatch on FP8 checkpoints.
- drop_model() now marks locked entries is_stale instead of skipping
  them silently; unlock() evicts stale entries once the last lock
  releases. Without this, a setting toggled during an in-flight
  generation survived on the locked entry and the next generation
  reused the pre-change module.
- _apply_fp8_to_nn_module mirrors diffusers' apply_layerwise_casting:
  only the supported layer classes (Linear/Conv*/Embedding) get cast,
  and module paths matching norm/pos_embed/patch_embed/proj_in/proj_out
  are skipped. FLUX RMSNorm.scale and similar precision-sensitive
  scalars are no longer crushed to FP8.
- drop_model() and the unlock-stale path now update stats.cleared and
  fire on_cache_models_cleared callbacks, matching _make_room_internal
  so the UI stats panel and observers don't miss invalidations.
- Add 14 tests: class-swap dispatch, norm/pos_embed/proj_in_out skip,
  unsupported-type skip, stale-marking, multi-lock release, stats and
  callback firing for both paths, no-op silence.

* fix(fp8): switch nn.Module FP8 wrapper to hooks so CustomLinear dispatch survives apply_custom_layers_to_model

Previous fix was wrong. `apply_custom_layers_to_model` does not do
`module.__class__ = CustomLinear`. Instead, `wrap_custom_layer`
constructs a NEW CustomLinear via __new__ and shares the original
Linear's __dict__, then setattr installs the new object on the parent.
The new object has type() == CustomLinear, but our wrapped forward
closed over the original Linear instance, so
`type(module).forward(module, ...)` resolved to Linear.forward on the
captured old object and silently bypassed CustomLinear.forward,
breaking LoRA/ControlLoRA patch dispatch for FP8 checkpoint models.
Reproduced on a fresh worktree.

Replace the instance-forward override with register_forward_pre_hook +
register_forward_hook(always_call=True). Hooks are dispatched by
nn.Module._call_impl with the actual called instance, so they fire on
the new CustomLinear and self.forward resolves normally via class
lookup, reaching CustomLinear.forward and its patch-aware branch.
always_call=True keeps the exception-safety guarantee (post-hook fires
even when forward raises).

Replace the simulated __class__-swap test with one that runs real
apply_custom_layers_to_model, attaches a sentinel _patches_and_weights,
and asserts the patch-aware branch in CustomLinear.forward is reached.
Verified the test fails under the old instance-forward implementation
with the reviewer-described symptom and passes under the hook fix.

* Add docs for fp8

---------

Co-authored-by: Jonathan <34005131+JPPhoto@users.noreply.github.com>
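
The dispatch problem described in the last fix can be reproduced without PyTorch. The sketch below is a toy model only, not the project's code: `Module` is a simplified stand-in for `nn.Module` call dispatch, `wrap_custom_layer` is a hypothetical stand-in that only copies the "new object via __new__, shared __dict__" behaviour described above, and the string `dtype` attribute stands in for real FP8/bf16 tensor casts.

```python
class Module:
    """Toy stand-in for torch.nn.Module; hook lists live in the instance __dict__."""

    def __init__(self):
        self.pre_hooks = []
        self.post_hooks = []

    def __call__(self, x):
        # Hooks receive the instance that was actually called.
        for hook in self.pre_hooks:
            hook(self)
        try:
            return self.forward(x)  # ordinary attribute lookup on that instance
        finally:
            # Mirrors always_call=True: runs even when forward raises.
            for hook in self.post_hooks:
                hook(self)


class Linear(Module):
    def __init__(self):
        super().__init__()
        self.dtype = "fp8"  # storage dtype (a string here, a tensor dtype in reality)

    def forward(self, x):
        return f"Linear.forward in {self.dtype}"


class CustomLinear(Linear):
    def forward(self, x):
        if x is None:
            raise ValueError("bad input")
        return f"CustomLinear.forward in {self.dtype}"


def wrap_custom_layer(orig):
    """Hypothetical stand-in: build a NEW object that shares orig's __dict__."""
    new = CustomLinear.__new__(CustomLinear)
    new.__dict__ = orig.__dict__
    return new


def wrap_forward(module):
    """Old approach: instance-level forward that closes over `module`."""
    def wrapped(x):
        module.dtype = "bf16"
        try:
            return type(module).forward(module, x)  # `module` is still the old Linear
        finally:
            module.dtype = "fp8"
    module.forward = wrapped  # lands in the __dict__ the new object will share


def add_cast_hooks(module):
    """New approach: cast up in a pre-hook, back down in a post-hook."""
    module.pre_hooks.append(lambda m: setattr(m, "dtype", "bf16"))
    module.post_hooks.append(lambda m: setattr(m, "dtype", "fp8"))


# Old approach: the shared instance attribute shadows CustomLinear.forward.
old = Linear()
wrap_forward(old)
old_result = wrap_custom_layer(old)(1)

# Hook approach: forward resolves by class lookup on the new object.
new = Linear()
add_cast_hooks(new)
custom = wrap_custom_layer(new)
new_result = custom(1)

# Exception safety: the post-hook restores the storage dtype after a failure.
try:
    custom(None)
except ValueError:
    pass
dtype_after_error = custom.dtype
```

In this toy model `old_result` comes from `Linear.forward`, while `new_result` comes from `CustomLinear.forward`, and `dtype_after_error` is back to the storage value even though forward raised.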
Alexander Eichhorn committed
6f42ad0ee78f61216bf8289faaab37ad15ae9934
Parent: 0f937ce
Committed by GitHub <noreply@github.com>
on 5/12/2026, 2:43:38 AM