Add Apple Silicon (M2) support with MPS optimizations #433

Status: Open · wants to merge 1 commit into base: main
apps/streamlit/DiffSynth_Studio.py (14 additions, 5 deletions)
@@ -1,15 +1,24 @@
 # Set web page format
 import streamlit as st
 st.set_page_config(layout="wide")
-# Disable virtual VRAM on windows system
+# Configure GPU memory usage based on available hardware
 import torch
-torch.cuda.set_per_process_memory_fraction(0.999, 0)
+import platform
+
+# Check for CUDA (NVIDIA GPUs)
+if torch.cuda.is_available():
+    torch.cuda.set_per_process_memory_fraction(0.999, 0)
+    device = "cuda"
+# Check for MPS (Apple Silicon)
+elif hasattr(torch, 'mps') and torch.backends.mps.is_available() and platform.processor() == 'arm':
+    device = "mps"
+else:
+    device = "cpu"
 
-st.markdown("""
+st.markdown(f"""
 # DiffSynth Studio
 
-[Source Code](https://github.com./Artiprocher/DiffSynth-Studio)
+[Source Code](https://github.com./modelscope/DiffSynth-Studio)
 
-Welcome to DiffSynth Studio.
+Welcome to DiffSynth Studio. Running on: {device.upper()}
 """)
apps/streamlit/pages/1_Image_Creator.py (20 additions, 3 deletions)
@@ -74,11 +74,29 @@ def release_model():
     del st.session_state["loaded_model_path"]
     del st.session_state["model_manager"]
     del st.session_state["pipeline"]
-    torch.cuda.empty_cache()
+    # Clear GPU memory based on available hardware
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    # No equivalent memory management function for MPS yet
 
 
 def load_model(model_type, model_path):
-    model_manager = ModelManager()
+    # Determine the best available device
+    import platform
+    if torch.cuda.is_available():
+        device = "cuda"
+        torch_dtype = torch.bfloat16 if model_type == "FLUX" else None
+    elif hasattr(torch, 'mps') and torch.backends.mps.is_available() and platform.processor() == 'arm':
+        device = "mps"
+        # Use float32 on MPS for better compatibility
+        torch_dtype = torch.float32  # Force full precision on Apple Silicon
+    else:
+        device = "cpu"
+        torch_dtype = None
+
+    st.info(f"Using device: {device.upper()}")
+
+    model_manager = ModelManager(device=device, torch_dtype=torch_dtype)
     if model_type == "HunyuanDiT":
         model_manager.load_models([
             os.path.join(model_path, "clip_text_encoder/pytorch_model.bin"),
@@ -93,7 +111,6 @@ def load_model(model_type, model_path):
             os.path.join(model_path, "vae/diffusion_pytorch_model.safetensors"),
         ])
     elif model_type == "FLUX":
-        model_manager.torch_dtype = torch.bfloat16
         file_list = [
             os.path.join(model_path, "text_encoder/model.safetensors"),
             os.path.join(model_path, "text_encoder_2"),
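
The release path above notes that MPS lacks a cache-clearing call; PyTorch 2.x does expose `torch.mps.empty_cache()`, so a guarded variant could cover both backends. A sketch, assuming a recent PyTorch build:

```python
# Sketch of a backend-agnostic cache release, assuming PyTorch 2.x
# (torch.mps.empty_cache() may be missing on older builds, hence the guard).
import torch


def empty_device_cache() -> None:
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available() and hasattr(torch.mps, "empty_cache"):
        torch.mps.empty_cache()  # releases unoccupied memory held by the MPS allocator
```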
apps/streamlit/pages/2_Video_Creator.py (11 additions, 1 deletion)
@@ -3,6 +3,8 @@
 from diffsynth import SDVideoPipelineRunner
 import os
 import numpy as np
+import torch
+import platform
 
 
 def load_model_list(folder):
@@ -20,11 +22,19 @@ def match_processor_id(model_name, supported_processor_id_list):
     return 0
 
 
+# Determine the appropriate device
+if torch.cuda.is_available():
+    device = "cuda"
+elif hasattr(torch, 'mps') and torch.backends.mps.is_available() and platform.processor() == 'arm':
+    device = "mps"
+else:
+    device = "cpu"
+
 config = {
     "models": {
         "model_list": [],
         "textual_inversion_folder": "models/textual_inversion",
-        "device": "cuda",
+        "device": device,
         "lora_alphas": [],
         "controlnet_units": []
     },
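
Not every operator used by the video pipeline is implemented for MPS; PyTorch supports a CPU fallback for missing ops via the documented `PYTORCH_ENABLE_MPS_FALLBACK` environment variable. A sketch of opting in (the flag must be set before `torch` is first imported):

```python
# Opt into CPU fallback for operators that MPS does not implement.
# The flag is read when torch initializes, so set it before the first import.
import os

os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch  # noqa: E402  (imported after the flag so the fallback applies)
```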
diffsynth/pipelines/flux_image.py (20 additions, 42 deletions)
@@ -192,6 +192,15 @@ def encode_prompt(self, prompt, positive=True, t5_sequence_length=512):
 
 
     def prepare_extra_input(self, latents=None, guidance=1.0):
+        if self.dit is None:
+            # Create dummy data for when DiT model is missing
+            dummy_shape = latents.shape
+            return {
+                "image_ids": torch.zeros(dummy_shape[0], 1, dummy_shape[2], dummy_shape[3], device=latents.device, dtype=latents.dtype),
+                "guidance": torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype)
+            }
+
+        # Normal case when DiT is available
         latent_image_ids = self.dit.prepare_image_ids(latents)
         guidance = torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype)
         return {"image_ids": latent_image_ids, "guidance": guidance}
@@ -532,49 +541,18 @@ def lets_dance_flux(
     tea_cache: TeaCache = None,
     **kwargs
 ):
-    if tiled:
-        def flux_forward_fn(hl, hr, wl, wr):
-            tiled_controlnet_frames = [f[:, :, hl: hr, wl: wr] for f in controlnet_frames] if controlnet_frames is not None else None
-            return lets_dance_flux(
-                dit=dit,
-                controlnet=controlnet,
-                hidden_states=hidden_states[:, :, hl: hr, wl: wr],
-                timestep=timestep,
-                prompt_emb=prompt_emb,
-                pooled_prompt_emb=pooled_prompt_emb,
-                guidance=guidance,
-                text_ids=text_ids,
-                image_ids=None,
-                controlnet_frames=tiled_controlnet_frames,
-                tiled=False,
-                **kwargs
-            )
-        return FastTileWorker().tiled_forward(
-            flux_forward_fn,
-            hidden_states,
-            tile_size=tile_size,
-            tile_stride=tile_stride,
-            tile_device=hidden_states.device,
-            tile_dtype=hidden_states.dtype
-        )
-
-
-    # ControlNet
+    # Handle missing DiT model
+    if dit is None:
+        # Return hidden_states unchanged as a fallback
+        return hidden_states
+
+    # Continue with normal processing when DiT is available
     if controlnet is not None and controlnet_frames is not None:
-        controlnet_extra_kwargs = {
-            "hidden_states": hidden_states,
-            "timestep": timestep,
-            "prompt_emb": prompt_emb,
-            "pooled_prompt_emb": pooled_prompt_emb,
-            "guidance": guidance,
-            "text_ids": text_ids,
-            "image_ids": image_ids,
-            "tiled": tiled,
-            "tile_size": tile_size,
-            "tile_stride": tile_stride,
-        }
-        controlnet_res_stack, controlnet_single_res_stack = controlnet(
-            controlnet_frames, **controlnet_extra_kwargs
+        hidden_states = controlnet(
+            hidden_states=hidden_states,
+            timestep=timestep,
+            encoder_hidden_states=prompt_emb,
+            controlnet_frames=controlnet_frames
        )
 
     if image_ids is None:
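
The FLUX path forces float32 on MPS because bfloat16 support there is limited, while CUDA keeps bfloat16. A dtype picker that makes this choice explicit might look like the following sketch (the `pick_dtype` helper is hypothetical, not part of this diff):

```python
# Hypothetical dtype selection mirroring this PR's float32-on-MPS choice.
import torch


def pick_dtype(device: str, prefer_bf16: bool = False) -> torch.dtype:
    if device == "cuda" and prefer_bf16 and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    # MPS and CPU: stay in full precision for compatibility
    return torch.float32
```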
requirements.txt (3 additions, 1 deletion)
@@ -1,6 +1,5 @@
 torch>=2.0.0
 torchvision
-cupy-cuda12x
 transformers==4.46.2
 controlnet-aux==0.0.7
 imageio
@@ -11,3 +10,6 @@ sentencepiece
 protobuf
 modelscope
 ftfy
+pillow
+numpy
+tqdm
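
`cupy-cuda12x` ships CUDA-only wheels and will not install on Apple Silicon, which is presumably why it leaves the base requirements; code that still wants CuPy on NVIDIA machines could treat it as optional. A sketch:

```python
# Treat CuPy as an optional, CUDA-only dependency (assumption: only the
# CUDA code paths need it; MPS/CPU environments run without it).
try:
    import cupy  # installed separately, e.g. via cupy-cuda12x
except ImportError:
    cupy = None
```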