Integrating a New Model Architecture
This guide walks you through the end-to-end process of integrating a new foundational model architecture into InvokeAI. This is required when adding a completely new family of models (e.g., Stable Diffusion 3, FLUX, or Hunyuan), rather than just adding a new checkpoint for an existing architecture.
Architectural Overview
Integrating a new model touches several parts of the InvokeAI stack, from the lowest-level PyTorch inference code up to the React frontend:
- Taxonomy & Configuration (Backend): Declaring the model’s existence and defining how to detect it from its weights on disk.
- Model Loading (Backend): Defining how to load the detected files into PyTorch models in memory.
- Sampling & Denoising (Backend): Implementing the core math for noise generation, scheduling, and the denoising loop.
- Invocations (Backend): Wrapping the PyTorch logic into isolated “nodes” that can be executed by InvokeAI’s graph engine.
- Graph Building (Frontend): Instructing the UI on how to wire these nodes together based on user settings.
- State & UI (Frontend): Adding the necessary UI controls and state management for the new model’s unique parameters.
1. Taxonomy & Defaults
The first step is to declare your model in the system’s taxonomy and provide reasonable default settings.
- Add BaseModelType

  Update the base model taxonomy to include your new model.

  invokeai/backend/model_manager/taxonomy.py

      class BaseModelType(str, Enum):
          # Existing types
          StableDiffusion1 = "sd-1"
          StableDiffusion2 = "sd-2"
          StableDiffusionXL = "sdxl"
          Flux = "flux"
          NewModel = "newmodel"

- Add Variant Type (if needed)

  If your model comes in different structural variants (e.g., different parameter counts or distilled versions like schnell vs dev), define a variant enum.

  invokeai/backend/model_manager/taxonomy.py

      class NewModelVariantType(str, Enum):
          VariantA = "variant_a"
          VariantB = "variant_b"

- Define Default Settings

  Provide default generation parameters (steps, CFG scale, etc.) for the UI to use when this model is selected.

  invokeai/backend/model_manager/configs/main.py

      class MainModelDefaultSettings:
          @staticmethod
          def from_base(base: BaseModelType, variant: AnyVariant | None = None):
              match base:
                  case BaseModelType.NewModel:
                      return MainModelDefaultSettings(steps=20, cfg_scale=7.0)
2. Model Configs & Detection
InvokeAI needs to know how to identify your model from a .safetensors file or a diffusers folder.
- Create Main Model Config

  Define the configuration schemas for your model format(s).

  invokeai/backend/model_manager/configs/main.py

      # Checkpoint Format (Single File)
      @ModelConfigFactory.register
      class Main_Checkpoint_NewModel_Config(Checkpoint_Config_Base):
          type: Literal[ModelType.Main] = ModelType.Main
          base: Literal[BaseModelType.NewModel] = BaseModelType.NewModel
          format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
          variant: NewModelVariantType = NewModelVariantType.VariantA

          @classmethod
          def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict) -> Self:
              if not cls._validate_is_newmodel(mod):
                  raise NotAMatchError("Not a NewModel")
              variant = cls._get_variant_or_raise(mod)
              return cls(..., variant=variant)

      # Diffusers Format (Folder)
      @ModelConfigFactory.register
      class Main_Diffusers_NewModel_Config(Diffusers_Config_Base):
          type: Literal[ModelType.Main] = ModelType.Main
          base: Literal[BaseModelType.NewModel] = BaseModelType.NewModel
          format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

- Implement Detection Logic

  Write helper functions to inspect the state dictionary keys and shapes to uniquely identify your architecture.

  invokeai/backend/model_manager/configs/main.py

      def _is_newmodel(state_dict: dict) -> bool:
          """Detect if state dict belongs to NewModel architecture."""
          # Example: check for a highly specific layer name or shape
          required_keys = ["transformer_blocks.0.attn.to_q.weight"]
          return all(key in state_dict for key in required_keys)

      def _get_newmodel_variant(state_dict: dict) -> NewModelVariantType:
          """Determine variant from state dict."""
          # Example: distinguish variants based on hidden dimension size
          context_dim = state_dict["context_embedder.weight"].shape[1]
          if context_dim == 7680:
              return NewModelVariantType.VariantA
          return NewModelVariantType.VariantB

- Submodels (VAE & Text Encoder)

  If your model uses a novel VAE or Text Encoder not already in InvokeAI, you must repeat this process to create configs for them (e.g., in configs/vae.py and configs/[encoder_type].py).

- Update the Configuration Union

  Register your new configs so the application knows to check them when scanning directories.

  invokeai/backend/model_manager/configs/factory.py

      AnyModelConfig = Annotated[
          # ... existing configs
          Main_Checkpoint_NewModel_Config |
          Main_Diffusers_NewModel_Config,
          Discriminator(...)
      ]
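The detection helpers can be exercised without any weights on disk. The sketch below replaces real tensors with a hypothetical name-to-shape mapping (`ShapeDict`), and the example key sets and dimensions are made up for illustration. It also checks that the key exists before reading its shape, so a non-matching file is rejected instead of raising `KeyError`.

```python
from enum import Enum


class NewModelVariantType(str, Enum):
    VariantA = "variant_a"
    VariantB = "variant_b"


# Maps parameter name -> tensor shape. A real state dict maps names to tensors;
# shapes alone are enough to illustrate the detection logic.
ShapeDict = dict[str, tuple[int, ...]]

# Keys that (by assumption) only the NewModel architecture contains.
REQUIRED_KEYS = (
    "transformer_blocks.0.attn.to_q.weight",
    "context_embedder.weight",
)


def is_newmodel(shapes: ShapeDict) -> bool:
    """True when every architecture-specific key is present."""
    return all(key in shapes for key in REQUIRED_KEYS)


def get_variant(shapes: ShapeDict) -> NewModelVariantType:
    """Pick the variant from the context embedder's input dimension."""
    if not is_newmodel(shapes):
        raise ValueError("not a NewModel state dict")
    context_dim = shapes["context_embedder.weight"][1]
    if context_dim == 7680:
        return NewModelVariantType.VariantA
    return NewModelVariantType.VariantB


# Hypothetical inputs: two NewModel variants and an unrelated UNet-style file.
variant_a = {
    "transformer_blocks.0.attn.to_q.weight": (3072, 3072),
    "context_embedder.weight": (3072, 7680),
}
variant_b = {
    "transformer_blocks.0.attn.to_q.weight": (2048, 2048),
    "context_embedder.weight": (2048, 4096),
}
unrelated = {"model.diffusion_model.input_blocks.0.0.weight": (320, 4, 3, 3)}
```

Choosing keys that no other supported architecture contains matters more than the number of keys checked, since every registered config is tried against the same file.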
3. Model Loaders
Loaders are responsible for converting the files on disk (described by the config) into PyTorch models in memory.
- Create the Model Loader

  invokeai/backend/model_manager/load/model_loaders/[newmodel].py

      @ModelLoaderRegistry.register(base=BaseModelType.NewModel, type=ModelType.Main, format=ModelFormat.Checkpoint)
      class NewModelLoader(ModelLoader):
          def _load_model(self, config: AnyModelConfig, submodel_type: SubModelType | None) -> AnyModel:
              # 1. Load the raw weights from disk
              state_dict = self._load_state_dict(config.path)

              # 2. Convert state dict keys if necessary (e.g. from original repo format to Diffusers)
              if self._is_original_format(state_dict):
                  state_dict = self._convert_to_diffusers_format(state_dict)

              # 3. Instantiate the empty PyTorch model
              # (model_config is a placeholder for the architecture hyperparameters,
              # which this snippet does not define)
              model = NewModelTransformer(config=model_config)

              # 4. Load weights into the model
              model.load_state_dict(state_dict)
              return model

- Custom VAE/Encoder Loaders (If Applicable)

  If you created custom configs for the VAE or Text Encoder, you must also create loaders for them, registering them with the appropriate ModelType.
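Step 2 of the loader (key conversion) is often the fiddliest part, so here is a rough sketch of one way to do it: prefix rewriting over the state dict keys. The prefixes in `KEY_PREFIX_MAP` and both function names are hypothetical. A real conversion depends on the upstream repository's naming and may also need to split or merge tensors.

```python
from typing import Any

# Hypothetical mapping: original-repo prefix -> Diffusers-style prefix.
# Rules are applied in insertion order, so the wrapper prefix is stripped first.
KEY_PREFIX_MAP = {
    "model.diffusion_model.": "",
    "double_blocks.": "transformer_blocks.",
}


def is_original_format(state_dict: dict[str, Any]) -> bool:
    """Heuristic: original checkpoints carry the wrapper prefix on their keys."""
    return any(key.startswith("model.diffusion_model.") for key in state_dict)


def convert_to_diffusers_format(state_dict: dict[str, Any]) -> dict[str, Any]:
    """Return a new dict with renamed keys; the values are left untouched."""
    converted: dict[str, Any] = {}
    for key, value in state_dict.items():
        new_key = key
        for old_prefix, new_prefix in KEY_PREFIX_MAP.items():
            if new_key.startswith(old_prefix):
                new_key = new_prefix + new_key[len(old_prefix):]
        if new_key in converted:
            # Two source keys collapsing into one would silently drop weights.
            raise ValueError(f"key collision after conversion: {new_key}")
        converted[new_key] = value
    return converted


# Placeholder values stand in for tensors.
original = {
    "model.diffusion_model.double_blocks.0.attn.to_q.weight": "W_q",
    "model.diffusion_model.context_embedder.weight": "W_ctx",
}
converted = convert_to_diffusers_format(original)
```

Returning a new dictionary instead of renaming in place keeps the original available for debugging and makes the collision check straightforward.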
4. Sampling and Denoising Core
This is where the actual mathematical implementation of the model lives.
- Sampling Utilities

  Create utility functions specific to how your model handles noise, packing, and scheduling.

  invokeai/backend/[newmodel]/sampling_utils.py

      import torch
      from einops import rearrange

      def get_noise_newmodel(num_samples: int, height: int, width: int, seed: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
          # Models often have different latent channel counts (e.g., SD1.5 has 4, FLUX has 16)
          latent_channels = 32
          latent_h, latent_w = height // 8, width // 8
          generator = torch.Generator(device=device).manual_seed(seed)
          return torch.randn((num_samples, latent_channels, latent_h, latent_w), generator=generator, device=device, dtype=dtype)

      def pack_newmodel(x: torch.Tensor) -> torch.Tensor:
          # Some transformer-based models require packing latents into a sequence
          return rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

- The Denoising Loop

  Implement the core sampling loop. This interacts with schedulers and handles classifier-free guidance (CFG).

  invokeai/backend/[newmodel]/denoise.py

      from typing import Any

      import torch
      from torch import nn

      def denoise(model: nn.Module, img: torch.Tensor, txt: torch.Tensor, timesteps: list[float], cfg_scale: list[float], scheduler: Any = None) -> torch.Tensor:
          """Main denoising loop."""
          total_steps = len(timesteps) - 1
          for step_index in range(total_steps):
              t_curr = timesteps[step_index]

              # Handle CFG (Classifier-Free Guidance)
              if cfg_scale[step_index] > 1.0:
                  # Batch positive and negative prompts if applicable
                  pred_pos = model(img, t_curr, txt)
                  # ...
                  # Placeholder so the sketch runs: replace with the guided prediction
                  # once the negative-prompt branch elided above is filled in.
                  pred = pred_pos
              else:
                  pred = model(img, t_curr, txt)

              # Step the scheduler
              img = scheduler.step(pred, t_curr, img).prev_sample
          return img

- Schedulers

  If your model requires a novel scheduler, add it to the scheduler mapping (e.g., invokeai/backend/[newmodel]/schedulers.py).
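The arithmetic behind the packing step and the guidance branch can be checked with plain Python, with no PyTorch required. The sketch below makes several assumptions that are not in the guide: the model predicts a velocity (flow-matching style), a plain Euler update stands in for `scheduler.step`, latents are flat lists of floats, and the prompt is passed as a string label. All function names are hypothetical.

```python
from typing import Callable, Sequence

Vector = list[float]


def packed_shape(b: int, c: int, h: int, w: int, ph: int = 2, pw: int = 2) -> tuple[int, int, int]:
    """Shape produced by 'b c (h ph) (w pw) -> b (h w) (c ph pw)'."""
    if h % ph or w % pw:
        raise ValueError("latent height/width must be divisible by the patch size")
    return (b, (h // ph) * (w // pw), c * ph * pw)


def cfg_combine(pred_neg: Vector, pred_pos: Vector, scale: float) -> Vector:
    """Classifier-free guidance: move from the negative toward the positive prediction."""
    return [n + scale * (p - n) for n, p in zip(pred_neg, pred_pos)]


def euler_denoise(
    model: Callable[[Vector, float, str], Vector],
    x: Vector,
    timesteps: Sequence[float],
    cfg_scale: float,
) -> Vector:
    """Integrate the predicted velocity from timesteps[0] to timesteps[-1]."""
    for t_curr, t_next in zip(timesteps[:-1], timesteps[1:]):
        pred_pos = model(x, t_curr, "positive")
        if cfg_scale > 1.0:
            # Guidance costs a second forward pass with the negative prompt.
            pred_neg = model(x, t_curr, "negative")
            pred = cfg_combine(pred_neg, pred_pos, cfg_scale)
        else:
            pred = pred_pos
        # Euler step: x_next = x + (t_next - t_curr) * velocity
        x = [xi + (t_next - t_curr) * vi for xi, vi in zip(x, pred)]
    return x


def toy_model(x: Vector, t: float, prompt: str) -> Vector:
    # Toy velocity field that ignores the prompt: v(x) = x.
    return list(x)


# A 1024x1024 image gives 128x128 latents at the 8x downscale used above.
shape = packed_shape(1, 32, 128, 128)
result = euler_denoise(toy_model, [2.0], [1.0, 0.5, 0.0], cfg_scale=1.0)
```

With a scale of 1.0 the guided prediction equals the positive one, which is why a loop can skip the negative pass unless the scale exceeds 1.0.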
5. Invocations
Invocations expose your PyTorch functions as isolated execution nodes in InvokeAI’s graph.
- Model Loader Invocation

  Loads the components (Transformer, VAE, etc.) and provides them to downstream nodes.

  invokeai/app/invocations/[newmodel]_model_loader.py

      @invocation("newmodel_model_loader", title="NewModel Loader", category="model_loader")
      class NewModelModelLoaderInvocation(BaseInvocation):
          model: ModelIdentifierField = InputField(description="Main model")

          def invoke(self, context: InvocationContext) -> NewModelLoaderOutput:
              transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
              vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
              # Expose the text encoder as well, so the text encoder node can connect to it
              encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
              return NewModelLoaderOutput(transformer=transformer, vae=vae, encoder=encoder)

- Text Encoder Invocation

  Tokenizes the prompt and runs the text encoder(s).

  invokeai/app/invocations/[newmodel]_text_encoder.py

      @invocation("newmodel_text_encode", title="NewModel Text Encoder", category="conditioning")
      class NewModelTextEncoderInvocation(BaseInvocation):
          prompt: str = InputField()
          encoder: EncoderField = InputField()

          def invoke(self, context: InvocationContext) -> ConditioningOutput:
              # 1. Tokenize prompt
              # 2. Run encoder to get embeddings
              # 3. Save to context and return
              conditioning_name = context.conditioning.save(ConditioningFieldData(...))
              return ConditioningOutput(conditioning=ConditioningField(conditioning_name=conditioning_name))

- Denoise Invocation

  Wraps the denoise loop you wrote in the previous section.

  invokeai/app/invocations/[newmodel]_denoise.py

      @invocation("newmodel_denoise", title="NewModel Denoise", category="latents")
      class NewModelDenoiseInvocation(BaseInvocation):
          latents: LatentsField | None = InputField(default=None)
          positive_conditioning: ConditioningField = InputField()
          transformer: TransformerField = InputField()
          steps: int = InputField(default=20)
          cfg_scale: float = InputField(default=7.0)

          def invoke(self, context: InvocationContext) -> LatentsOutput:
              # Generate noise, get schedule, and call your denoise() function
              pass

- VAE Encode / Decode Invocations

  Create nodes to transition between pixel space (images) and latent space.
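To build intuition for why each node is addressed by a string type such as "newmodel_denoise", the toy registry below mimics the decorator pattern in a few lines. It is not InvokeAI's implementation: `NODE_REGISTRY`, `run_node`, and `ToyAddInvocation` are hypothetical, and the real `@invocation` decorator also handles schemas, versions, and validation.

```python
from typing import Any, Callable

# Hypothetical registry: node type string -> node class.
NODE_REGISTRY: dict[str, type] = {}


def invocation(node_type: str, title: str, category: str) -> Callable[[type], type]:
    """Toy decorator that records a node class under its type string."""

    def decorator(cls: type) -> type:
        if node_type in NODE_REGISTRY:
            raise ValueError(f"duplicate invocation type: {node_type}")
        cls.node_type = node_type
        cls.title = title
        cls.category = category
        NODE_REGISTRY[node_type] = cls
        return cls

    return decorator


@invocation("toy_add", title="Toy Add", category="math")
class ToyAddInvocation:
    def __init__(self, a: float, b: float) -> None:
        self.a = a
        self.b = b

    def invoke(self) -> float:
        # A node only sees its own inputs, which keeps it isolated from others.
        return self.a + self.b


def run_node(node_type: str, **inputs: Any) -> Any:
    """Look up a node by its type string, build it from inputs, and run it."""
    node_cls = NODE_REGISTRY[node_type]
    return node_cls(**inputs).invoke()


total = run_node("toy_add", a=1.0, b=2.0)
```

The type string is the only contract shared with the frontend, so it has to match the `type` used by the graph builder exactly.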
6. Frontend: Graph Building
The UI doesn’t know about Python functions; it only knows how to build graphs of Invocations.
- Create the Graph Builder

  Write a TypeScript function that constructs the node graph for your model.

  invokeai/frontend/web/src/features/nodes/util/graph/generation/buildNewModelGraph.ts

      export const buildNewModelGraph = async (arg: GraphBuilderArg): Promise<GraphBuilderResult> => {
        const { state, manager } = arg;
        const { model } = state.params;
        const g = new Graph();

        // 1. Add Loader
        const modelLoader = g.addNode({
          id: NEWMODEL_MODEL_LOADER,
          type: 'newmodel_model_loader',
          model: Graph.getModelMetadataField(model),
        });

        // 2. Add Text Encoders
        const positivePrompt = g.addNode({
          id: POSITIVE_CONDITIONING,
          type: 'newmodel_text_encode',
          prompt: state.params.positivePrompt,
        });
        g.addEdge(modelLoader, 'encoder', positivePrompt, 'encoder');

        // 3. Add Denoise
        const denoise = g.addNode({
          id: NEWMODEL_DENOISE,
          type: 'newmodel_denoise',
          steps: state.params.steps,
          cfg_scale: state.params.cfg,
        });
        g.addEdge(modelLoader, 'transformer', denoise, 'transformer');
        g.addEdge(positivePrompt, 'conditioning', denoise, 'positive_conditioning');

        // 4. Add VAE Decode
        const l2i = g.addNode({
          id: NEWMODEL_VAE_DECODE,
          type: 'newmodel_vae_decode',
        });
        g.addEdge(modelLoader, 'vae', l2i, 'vae');
        g.addEdge(denoise, 'latents', l2i, 'latents');

        return { g, denoise, posCond: positivePrompt };
      };

- Register the Graph Builder

  Hook your graph builder into the main routing logic.

  invokeai/frontend/web/src/features/queue/hooks/useEnqueueCanvas.ts

      switch (base) {
        case 'sdxl':
          return buildSDXLGraph(arg);
        case 'flux':
          return buildFLUXGraph(arg);
        case 'newmodel':
          return buildNewModelGraph(arg);
      }

- Update Type Definitions

  Add your new nodes to the strict frontend type unions.

  invokeai/frontend/web/src/features/nodes/util/graph/types.ts

      export type ImageOutputNodes =
        | 'l2i' | 'flux_vae_decode' | 'sd3_l2i' | 'newmodel_vae_decode';

- Generation Modes

  Update invokeai/app/invocations/metadata.py to include your new modes in GENERATION_MODES (e.g., "newmodel_txt2img", "newmodel_img2img").
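It can help to view the graph the builder produces as plain data: a set of nodes plus edges between named fields. The sketch below is written in Python to match the backend examples in this guide. It rebuilds the four-node graph by hand with illustrative node ids and derives a valid execution order using the standard library's `graphlib`. The `execution_order` helper is hypothetical and is not how InvokeAI's graph engine is implemented.

```python
from graphlib import TopologicalSorter

# Node id -> node payload; only the type is needed for this illustration.
nodes = {
    "model_loader": {"type": "newmodel_model_loader"},
    "pos_cond": {"type": "newmodel_text_encode"},
    "denoise": {"type": "newmodel_denoise"},
    "l2i": {"type": "newmodel_vae_decode"},
}

# (source node, source field, destination node, destination field)
edges = [
    ("model_loader", "encoder", "pos_cond", "encoder"),
    ("model_loader", "transformer", "denoise", "transformer"),
    ("pos_cond", "conditioning", "denoise", "positive_conditioning"),
    ("model_loader", "vae", "l2i", "vae"),
    ("denoise", "latents", "l2i", "latents"),
]


def execution_order(
    nodes: dict[str, dict], edges: list[tuple[str, str, str, str]]
) -> list[str]:
    """Return node ids ordered so that every node runs after its inputs."""
    sorter = TopologicalSorter({node_id: set() for node_id in nodes})
    for src, _src_field, dst, _dst_field in edges:
        if src not in nodes or dst not in nodes:
            raise KeyError(f"edge references unknown node: {src} -> {dst}")
        sorter.add(dst, src)  # dst depends on src
    # static_order() raises graphlib.CycleError if the graph contains a cycle.
    return list(sorter.static_order())


order = execution_order(nodes, edges)
```

A missing edge, such as a forgotten loader-to-VAE connection, does not break the ordering. It shows up later as a node with an unfilled input, so it is worth checking each node's required fields against the edge list.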
7. Frontend: State & UI
Finally, add any custom UI controls (like a specific scheduler dropdown) and manage their state.
- Add to Redux State

  Update the parameters slice for your model-specific settings.

  invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts

      interface ParamsState {
        // ...
        newmodelScheduler: 'euler' | 'heun';
      }

      const initialState: ParamsState = {
        // ...
        newmodelScheduler: 'euler',
      };

      // Add reducers and export selectors...

- Parameter Recall

  Ensure users can extract parameters from previously generated images by updating invokeai/frontend/web/src/features/metadata/parsing.tsx.

  invokeai/frontend/web/src/features/metadata/parsing.tsx

      const recallNewmodelScheduler = (metadata: CoreMetadata) => {
        if (metadata.scheduler) {
          dispatch(setNewmodelScheduler(metadata.scheduler));
        }
      };
8. Optional Features
Depending on the model, you may want to support additional features.
ControlNet Support
Requires backend configuration (configs/controlnet.py), a custom invocation ([newmodel]_controlnet.py), and frontend graph integration (addControlNets).
LoRA Support
Requires defining a LoRA config (configs/lora.py), updating the model loader to pass LoRA fields, and wiring addLoRAs in the frontend graph builder.
IP-Adapter
Requires a custom invocation for image prompting ([newmodel]_ip_adapter.py) and frontend integration via addIPAdapters.
9. Starter Models
To allow users to easily download your model from the Model Manager UI, add it to the starter models list.
    newmodel_main = StarterModel(
        name="NewModel Main",
        base=BaseModelType.NewModel,
        source="organization/newmodel-main",  # HuggingFace repo
        description="NewModel main transformer.",
        type=ModelType.Main,
    )

    STARTER_MODELS.append(newmodel_main)

Summary of Integration Files
A complete minimal txt2img integration touches the following areas:
- invokeai/
  - app/invocations/
    - metadata.py
    - [newmodel]_model_loader.py
    - [newmodel]_text_encoder.py
    - [newmodel]_denoise.py
    - [newmodel]_vae_decode.py
  - backend/
    - model_manager/
      - taxonomy.py
      - configs/
        - main.py
        - factory.py
      - load/model_loaders/
        - [newmodel].py
      - starter_models.py
    - [newmodel]/
      - sampling_utils.py
      - denoise.py
  - frontend/web/src/features/
    - nodes/util/graph/
      - generation/buildNewModelGraph.ts
      - types.ts
    - queue/hooks/useEnqueueCanvas.ts
    - controlLayers/store/paramsSlice.ts
    - metadata/parsing.tsx