ai · ml · mlx · apple-silicon

Porting ERNIE-Image (8B) to Apple’s MLX

What happens when you rebuild Baidu’s 8B text-to-image model for Apple Silicon. A deep look at what’s inside a modern diffusion pipeline.

by Ritesh Khanna | @treadon

Porting a language model to MLX is straightforward. mlx-lm handles most of it. Image generation models are a different story. They have multiple components with different architectures, custom operations, and pipelines that don’t map cleanly to existing MLX tooling.

ERNIE-Image-Turbo is Baidu’s 8B parameter text-to-image model. It’s competitive with FLUX, runs in 8 denoising steps, and it’s Apache 2.0 licensed. It also has four separate components (a text encoder, a diffusion transformer, a VAE, and a scheduler), each requiring its own porting approach. This is what porting a real-world model to Apple Silicon actually looks like.

What’s Inside an Image Model

ERNIE-Image isn’t one model. It’s four models in a pipeline. Each has a different job, a different architecture, and different porting challenges.


The Text Encoder (Mistral-3, 3.8B) is the most surprising component. Most image models use CLIP (400M) or T5 (11B) for text encoding. Baidu used a full instruction-tuned LLM, the same Mistral architecture that powers chat assistants. But they don’t use it to generate text. They run a forward pass and extract the hidden states from the second-to-last layer: 3072-dimensional vectors that encode the meaning of each token. The model’s output layer (the part that predicts next tokens) is dead weight in this pipeline.

The DiT (Diffusion Transformer, 8B) is where all the compute goes. 36 transformer layers, 4096 hidden size, 32 attention heads. It takes noisy latents and text embeddings, then predicts the “velocity,” the direction to move in latent space to get a cleaner image. This runs 8 times (once per denoising step). Each step processes 4,136 tokens (4,096 image patches + 40 text tokens) through all 36 layers. That’s ~14 trillion FLOPs per step.
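The denoising loop around that velocity prediction is short. Here is a minimal flow-matching sketch with a toy callable standing in for the 8B DiT (the shape and the toy velocity are illustrative, not the real pipeline):

```python
import numpy as np

# Minimal flow-matching denoising loop (illustrative, not the exact pipeline).
# `dit` stands in for the 8B transformer; here it is a toy callable.
def denoise(dit, text_emb, shape=(4096, 32), steps=8, seed=0):
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(shape).astype(np.float32)  # pure-noise latent
    ts = np.linspace(1.0, 0.0, steps + 1)              # t=1 (noise) -> t=0 (clean)
    for i in range(steps):
        t, t_next = ts[i], ts[i + 1]
        v = dit(x, t, text_emb)      # predicted velocity dx/dt
        x = x + (t_next - t) * v     # Euler step toward the clean latent
    return x

# Toy "DiT": for a clean target of all zeros, x_t = t * noise, so v = x / t.
toy = lambda x, t, emb: x / t
latent = denoise(toy, text_emb=None)
print(latent.shape)  # (4096, 32)
```

With the toy velocity the last step (t_next = 0) lands exactly on the clean target; the real model predicts v from the noisy latent and text embeddings instead.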

The VAE Decoder (84M) converts the 32-channel latent at 64×64 into a 1024×1024 RGB image. Baidu didn’t build this. They borrowed it directly from FLUX.2 by Black Forest Labs. Same architecture, same weights format.
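The shape arithmetic for that stage, using the figures above, shows how much the VAE compresses:

```python
# Shape bookkeeping for the VAE stage (values from the article).
latent_h = latent_w = 64
latent_ch = 32
img_h = img_w = 1024

spatial_scale = img_h // latent_h                   # 16x upsampling per side
elements_img = img_h * img_w * 3                    # RGB output elements
elements_lat = latent_h * latent_w * latent_ch      # latent elements
print(spatial_scale, elements_img / elements_lat)   # 16 24.0
```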

The Scheduler isn’t a neural network. It’s ~40 lines of math that controls the denoising process: how much noise to add, how big each step should be, when to stop.
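To give a feel for those ~40 lines, here is a sketch of a flow-matching sigma schedule plus Euler step. The `shift` value is illustrative; real pipelines tune it per resolution:

```python
import numpy as np

# Sketch of a flow-matching Euler scheduler (illustrative, not Baidu's exact code).
def make_sigmas(steps: int = 8, shift: float = 3.0) -> np.ndarray:
    # Uniform timesteps in (0, 1], then a "shift" warp that spends
    # more of the step budget at high noise levels.
    t = np.linspace(1.0, 1.0 / steps, steps)
    sigmas = shift * t / (1.0 + (shift - 1.0) * t)
    return np.append(sigmas, 0.0)  # final sigma is 0 (clean image)

def euler_step(x, velocity, sigma, sigma_next):
    # One denoising update: move along the predicted velocity.
    return x + (sigma_next - sigma) * velocity

sig = make_sigmas()
print(sig.round(3))
```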

The Port: Component by Component

The approach: port each component independently, verify weights load (every tensor name must match exactly), test the forward pass, then wire them together. If weights don’t load, the architecture is wrong.
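A minimal version of that verification step might look like this (the tensor name below is illustrative):

```python
import numpy as np

# Verification sketch: every checkpoint tensor must match a parameter in the
# reimplemented model by exact name and shape.
def check_weights(checkpoint: dict, params: dict) -> None:
    missing = sorted(set(checkpoint) - set(params))
    extra = sorted(set(params) - set(checkpoint))
    if missing or extra:
        raise ValueError(f"name mismatch: missing={missing}, extra={extra}")
    for name, w in checkpoint.items():
        if w.shape != params[name].shape:
            raise ValueError(f"shape mismatch at {name}: "
                             f"{w.shape} vs {params[name].shape}")

ckpt = {"blocks.0.attn.q_proj.weight": np.zeros((4096, 4096))}
ours = {"blocks.0.attn.q_proj.weight": np.zeros((4096, 4096))}
check_weights(ckpt, ours)  # silent on success, raises on any mismatch
```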

Weight Loading Results

  • DiT (8B): 409/409 tensors matched
  • VAE (84M): 140/140 tensors matched

Both the DiT and VAE loaded on the first attempt: 549 out of 549 weight tensors matched exactly. The only conversion needed was transposing Conv2d weights from PyTorch’s channels-first format to MLX’s channels-last format. This was a good sign that naming our MLX layers to mirror PyTorch’s naming convention was the right approach.
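That transposition is one line; a NumPy sketch of the conversion:

```python
import numpy as np

# PyTorch stores Conv2d weights channels-first: (out_ch, in_ch, kH, kW).
# MLX expects channels-last: (out_ch, kH, kW, in_ch). One transpose converts.
def conv_weight_torch_to_mlx(w: np.ndarray) -> np.ndarray:
    return np.ascontiguousarray(w.transpose(0, 2, 3, 1))

w_pt = np.random.default_rng(0).standard_normal((128, 32, 3, 3)).astype(np.float32)
w_mlx = conv_weight_torch_to_mlx(w_pt)
print(w_mlx.shape)  # (128, 3, 3, 32)
```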

What Didn’t Work

Pure MLX Text Encoder

Reimplementing Mistral-3 from scratch in MLX seemed straightforward. Weights loaded fine, forward pass ran, but the output was wrong (cosine similarity of 0.03 with PyTorch, essentially random). The issue: Mistral uses YaRN RoPE with complex scaling parameters that a naive RoPE implementation doesn’t replicate. Since text encoding is 0.1s out of 134s total, the practical fix is a hybrid approach: PyTorch for text, MLX for everything else. No meaningful speed loss.
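The parity check behind that 0.03 number is just cosine similarity over flattened activations; a self-contained sketch with random stand-ins for the two implementations' outputs:

```python
import numpy as np

# Numerical parity check between two implementations' embeddings (sketch).
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    a, b = a.ravel().astype(np.float64), b.ravel().astype(np.float64)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

ref = np.random.default_rng(0).standard_normal((40, 3072))   # "PyTorch" output
noise = np.random.default_rng(1).standard_normal(ref.shape)  # broken reimpl
print(cosine_similarity(ref, ref))    # identical -> ~1.0
print(cosine_similarity(ref, noise))  # unrelated -> ~0.0
```

A correct port should score well above 0.99; anything near zero, like the 0.03 here, means the two models are computing different functions.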

Quantization

4-bit and 8-bit quantization made things slower, not faster. The DiT’s matmuls are large ([4136, 4096] × [4096, 12288]) and they’re compute-bound, not memory-bound. Dequantization overhead outweighed bandwidth savings. This is the opposite of LLMs, where single-token matmuls are tiny and memory-bound.
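A back-of-envelope arithmetic-intensity calculation (FLOPs per byte moved, assuming bf16 at 2 bytes per element) makes the contrast concrete:

```python
# Arithmetic intensity of a matmul: FLOPs per byte of operands + output.
# Illustrative estimate, not a benchmark.
def intensity(m, k, n, bytes_per_el=2):
    flops = 2 * m * k * n                                # multiply + accumulate
    bytes_moved = (m * k + k * n + m * n) * bytes_per_el # A, B, and C tensors
    return flops / bytes_moved

diffusion = intensity(4136, 4096, 12288)  # one DiT FFN matmul
llm_decode = intensity(1, 4096, 12288)    # single-token LLM decode matmul
print(f"{diffusion:.0f} vs {llm_decode:.1f} FLOPs/byte")
```

Roughly three orders of magnitude apart: the DiT matmul does over a thousand FLOPs per byte fetched, while the single-token matmul does about one, which is why shrinking the weights helps LLM decoding but not diffusion.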

mx.compile()

MLX’s JIT compiler made things 6% slower. The computation graph is too large (36 layers × many ops), and the individual operations (SDPA, large matmuls) are already fused or optimally dispatched. No fusion opportunities left.

What Did Work

Two optimizations gave us a modest but real speedup over PyTorch/MPS:

  • Fused SDPA: mx.fast.scaled_dot_product_attention instead of manual matmul + softmax + matmul
  • Lazy evaluation: no mx.eval() between denoising steps, letting MLX build and optimize the full computation graph before executing

Where the Time Goes

Profiling a single DiT forward pass at 1024×1024 reveals there’s nowhere left to optimize:

Single DiT Step: 15.3 seconds

FFN (49.8%) and Attention (40.9%) account for 90.7% of compute. Everything else is under 12ms combined. There’s no overhead to eliminate.

The FFN is actually the bigger bottleneck, not attention. Each layer’s FFN does a [4136, 4096] × [4096, 12288] matmul (gate + up projections) then a [4136, 12288] × [12288, 4096] matmul (down projection). That’s ~400 billion FLOPs per layer just for the FFN. Times 36 layers, times 8 steps. The GPU is doing actual math, not waiting for memory.

The Results

1024×1024, 8 Steps, M4 Pro 64GB

MLX BF16 is 6% faster than PyTorch/MPS. Quantized versions are slower because this workload is compute-bound, not memory-bound.

A 6% speedup is honest but modest. The real value of a port like this isn’t speed. It’s understanding. Rebuilding each component exposes exactly how a diffusion image model works from the inside: how text becomes vectors, how noise becomes structure, how a VAE reconstructs pixels, and why certain optimizations help while others don’t.

What I Learned

  • MLX and MPS are closer than you’d think for large compute-bound workloads. Both compile to Metal shaders on the same hardware. The gap widens for memory-bound workloads (LLMs) where MLX’s quantization shines.
  • Weight name matching is the real porting work. If your MLX layer names mirror PyTorch exactly, weights load with zero mapping code. The architecture is already defined by the weight shapes.
  • Image diffusion is compute-bound, not memory-bound. The opposite of LLM inference. Quantization helps LLMs because each token only does one matmul. Diffusion does thousands of large matmuls per image.
  • The text encoder matters more than you’d expect. Baidu’s decision to use Mistral-3 instead of CLIP is why the model follows complex instructions. A 262K context window vs 77 tokens is a qualitative difference.
  • Most of the model is borrowed. Baidu built the DiT and fine-tuned the prompt enhancer. The VAE is FLUX.2’s, the text encoder is Mistral-3. Knowing which pieces to assemble is the real skill.

ERNIE-Image in Action

For what it’s worth, the model itself is genuinely impressive, especially at structured layouts, multi-panel compositions, and text rendering. Here are some outputs from the detailed prompts we tested:

  • 4-panel manga comic
  • Retro sci-fi movie poster
  • Film photography (Kodak Portra 400 style)