On This Pageexpand_more
Optimizing CUDA Kernels for Generative Adversarial Networks
Learn to optimize CUDA kernels for GAN training: memory coalescing, occupancy tuning, mixed-precision training, custom fused kernels, Triton compiler, and profiling with Nsight. Practical code included.

Introduction: Why GPU Optimization Is Non-Negotiable for GANs
Generative adversarial networks are among the most computationally punishing architectures in deep learning. You are not training one network, you are training two, simultaneously, in a delicate adversarial dance where the generator and discriminator must stay in approximate equilibrium. If the discriminator becomes too strong too fast, the generator collapses. If the generator outpaces the discriminator, training destabilizes. This balance requires careful hyperparameter tuning, which means more experiments, which means more GPU hours.
A state-of-the-art image synthesis GAN like StyleGAN3 takes roughly 4-7 days on eight A100 GPUs to train at 1024x1024 resolution. Progressive GAN training on high-resolution medical imaging datasets can run for weeks. Every percentage of GPU utilization you leave on the table compounds into days of wasted time and thousands of dollars of wasted compute.
Yet most practitioners train GANs with default PyTorch settings. They leave memory bandwidth on the table. They launch kernels with suboptimal occupancy. They run operations in FP32 that would be equally stable in BF16 at twice the throughput. They never profile, never fuse, never question whether the bottleneck is compute or memory.
This tutorial changes that. We will go from understanding the GPU hardware your GAN code actually runs on, through concrete optimization techniques with measured before/after results, to writing custom kernels in both CUDA C++ and Triton that target the specific bottlenecks in GAN training. By the end, you will have the tools and knowledge to cut your GAN training time by 40% or more, without changing your model architecture.
GPU Architecture Primer: What Your Code Actually Runs On
Before optimizing anything, you need a mental model of the hardware. NVIDIA GPUs are not just "fast matrix multipliers." Their performance characteristics are shaped by a specific architecture, and understanding that architecture tells you where the optimization opportunities live.
Streaming Multiprocessors and Warps
An NVIDIA GPU is organized into Streaming Multiprocessors (SMs). An A100 has 108 SMs; an H100 has 132. Each SM is an independent processor with its own register file, shared memory, warp schedulers, and execution units (CUDA cores, tensor cores).
When you launch a CUDA kernel, you specify a grid of thread blocks. Each thread block is assigned to one SM. Within a thread block, threads are grouped into warps of 32 threads. The warp is the fundamental unit of execution: all 32 threads in a warp execute the same instruction at the same time (SIMT execution). When threads in a warp diverge (e.g., via an if/else), both paths are executed serially, and threads not taking a given path are masked. This is called warp divergence, and it is one of the first things to eliminate when optimizing.
The Memory Hierarchy
This is where most GAN training time is actually spent: waiting for data to arrive from memory. The hierarchy, from fastest to slowest:
| Memory Type | Capacity (A100) | Bandwidth | Latency | Scope |
|---|---|---|---|---|
| Registers | 256 KB per SM | ~19 TB/s | 1 cycle | Per thread |
| Shared Memory / L1 | 164 KB per SM (configurable) | ~19 TB/s | ~20-30 cycles | Per thread block |
| L2 Cache | 40 MB | ~5 TB/s | ~200 cycles | All SMs |
| Global Memory (HBM) | 80 GB | 2.0 TB/s | ~400-600 cycles | All SMs |
The critical insight is the 10x bandwidth gap between shared memory and global memory. Every unnecessary read from global memory is an order of magnitude slower than it needs to be. Most naive CUDA kernels (and most PyTorch operator implementations) are memory-bandwidth bound, meaning the CUDA cores sit idle waiting for data from HBM.
For GAN training, where we repeatedly apply convolutions, normalizations, and nonlinearities to feature maps that live in global memory, this bandwidth bottleneck dominates training time.
Memory Coalescing: The Single Biggest Performance Win
What Coalesced Access Means
When a warp of 32 threads reads from global memory, the GPU memory controller can combine those 32 individual requests into a single, wide memory transaction, but only if the addresses form a contiguous, aligned block. This is called coalesced access, and it is the difference between using 100% of your memory bandwidth and using 3% of it.
A coalesced access pattern looks like this: thread 0 reads address base, thread 1 reads base + 4, thread 2 reads base + 8, and so on. The hardware issues a single 128-byte transaction that satisfies all 32 threads.
A non-coalesced pattern, where threads access scattered or strided addresses, forces the hardware to issue multiple smaller transactions to satisfy the warp. In the worst case, 32 separate transactions for what could have been one.
Why GANs Are Especially Vulnerable
GAN feature maps are typically stored in NCHW format (batch, channels, height, width). Consider a kernel that processes pixels across the channel dimension, a common pattern in normalization layers. If your kernel assigns one thread per channel for a given spatial position, adjacent threads access memory addresses separated by H * W elements. This is a strided, non-coalesced pattern.
The fix is to either restructure the kernel to process along the width dimension (where adjacent elements are contiguous), or to convert to NHWC format where channel data for a given spatial position is contiguous.
# PyTorch: Switch to channels-last memory format
# This single line can improve convolution performance by 10-30%
images = images.to(memory_format=torch.channels_last)
generator = generator.to(memory_format=torch.channels_last)
discriminator = discriminator.to(memory_format=torch.channels_last)PyTorch's channels_last memory format stores data in NHWC order while keeping the logical NCHW interface. NVIDIA's cuDNN library has optimized kernels for NHWC that achieve significantly better memory coalescing for convolutions. This is a zero-effort optimization that typically yields a 10-30% speedup for convolutional GANs.
A Custom Coalesced Kernel Example
Here is a concrete example. Suppose you have a custom activation function that you apply element-wise to a GAN feature map. A naive implementation might iterate over channels in the inner loop:
// BAD: Non-coalesced access (strided by H*W)
__global__ void custom_activation_bad(float* output, const float* input,
int N, int C, int H, int W) {
int n = blockIdx.x;
int c = threadIdx.x; // threads iterate over channels
if (c < C) {
for (int h = 0; h < H; h++) {
for (int w = 0; w < W; w++) {
int idx = n * C * H * W + c * H * W + h * W + w;
output[idx] = input[idx] > 0 ? input[idx] : 0.2f * input[idx];
}
}
}
}Adjacent threads (c=0, c=1, c=2, ...) access addresses separated by H*W floats. This is catastrophically non-coalesced. The fix is to make adjacent threads process adjacent memory addresses:
// GOOD: Coalesced access (adjacent threads read adjacent elements)
__global__ void custom_activation_good(float* output, const float* input,
int total_elements) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < total_elements) {
float val = input[idx];
output[idx] = val > 0 ? val : 0.2f * val;
}
}Here, thread 0 reads element 0, thread 1 reads element 1, and so on. The memory controller issues a single 128-byte transaction per warp. On an A100 processing a 256x256 feature map with 512 channels, the coalesced version runs approximately 8x faster than the strided version.
Occupancy Optimization: Keeping the GPU Busy
What Occupancy Means
Occupancy is the ratio of active warps to the maximum number of warps an SM can support. An A100 SM supports up to 64 active warps (2048 threads). If your kernel configuration only allows 32 active warps per SM, your occupancy is 50%.
Why does this matter? When a warp stalls waiting for a memory access to complete (which, recall, takes 400-600 cycles from global memory), the SM switches to executing another active warp with zero overhead. This latency hiding is how GPUs achieve high throughput despite high memory latency. More active warps means more opportunities to hide latency.
The Three Limiting Factors
Occupancy is limited by whichever of these resources runs out first:
- Registers per thread. An A100 SM has 65,536 registers. If your kernel uses 128 registers per thread, you can have at most 512 threads (16 warps) per SM, for 25% occupancy.
- Shared memory per thread block. If each block uses 48 KB of shared memory and the SM has 164 KB available (with a 100 KB L1 / 64 KB shared split), you can fit at most 3 blocks.
- Thread block size. Blocks are assigned to SMs in whole units. If your block has 1024 threads and the SM supports 2048, you get at most 2 blocks per SM.
Tuning for GAN Kernels
You can query occupancy programmatically:
#include <cuda_runtime.h>
int main() {
int block_size = 256;
int min_grid_size;
int optimal_block_size;
cudaOccupancyMaxPotentialBlockSize(
&min_grid_size, &optimal_block_size,
custom_activation_good, 0, 0);
printf("Optimal block size: %d\n", optimal_block_size);
int max_active_blocks;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, custom_activation_good,
block_size, 0);
printf("Max active blocks per SM: %d\n", max_active_blocks);
printf("Occupancy: %.1f%%\n",
(max_active_blocks * block_size / 2048.0) * 100);
return 0;
}For GAN workloads, the practical advice is:
- Use block sizes of 128 or 256. These give good occupancy across most kernel register pressures. Block sizes of 32 or 64 often under-utilize the SM.
- Limit register usage with
__launch_bounds__when needed:
__global__ void __launch_bounds__(256, 4) // max 256 threads/block, min 4 blocks/SM
fused_norm_activate(float* output, const float* input, ...) {
// kernel body
}- Monitor with the CUDA Occupancy Calculator or
ncu --metrics sm__warps_active.avg.pct_of_peak_sustained_elapsedto see your actual achieved occupancy.
A common trap in GAN training: custom normalization kernels (instance norm, spectral norm) that use excessive shared memory, limiting occupancy to 25% or less. We will address this directly in the next section.
Custom CUDA Kernels for GAN Operations
Why Fused Kernels Matter
Every separate CUDA kernel launch reads its inputs from global memory, processes them, and writes results back to global memory. When you chain operations like conv -> batch_norm -> relu, PyTorch launches three separate kernels. Each intermediate result takes a round trip through HBM.
Kernel fusion combines multiple operations into a single kernel, keeping intermediate values in registers or shared memory. For a sequence of three operations on a 512-channel feature map at 256x256 resolution, eliminating two intermediate HBM round trips saves approximately 256 MB of memory traffic per layer. Across a StyleGAN2 generator with 18 synthesis layers, this adds up quickly.
Fused Instance Normalization + Leaky ReLU
Instance normalization is ubiquitous in GAN generators. The standard PyTorch implementation launches multiple kernels: one to compute the mean, one for variance, one for normalization, and one for the affine transform. Here is a fused version:
__global__ void fused_instance_norm_leaky_relu(
float* __restrict__ output,
const float* __restrict__ input,
const float* __restrict__ gamma, // scale parameter
const float* __restrict__ beta, // shift parameter
int N, int C, int H, int W,
float epsilon, float negative_slope)
{
// Each block handles one (n, c) pair
int nc = blockIdx.x;
int n = nc / C;
int c = nc % C;
int spatial_size = H * W;
int offset = n * C * H * W + c * H * W;
// Phase 1: Compute mean and variance using parallel reduction
float sum = 0.0f;
float sum_sq = 0.0f;
for (int i = threadIdx.x; i < spatial_size; i += blockDim.x) {
float val = input[offset + i];
sum += val;
sum_sq += val * val;
}
// Warp-level reduction
for (int mask = warpSize / 2; mask > 0; mask >>= 1) {
sum += __shfl_xor_sync(0xffffffff, sum, mask);
sum_sq += __shfl_xor_sync(0xffffffff, sum_sq, mask);
}
// Block-level reduction via shared memory
__shared__ float s_sum[32];
__shared__ float s_sum_sq[32];
int warp_id = threadIdx.x / warpSize;
int lane_id = threadIdx.x % warpSize;
if (lane_id == 0) {
s_sum[warp_id] = sum;
s_sum_sq[warp_id] = sum_sq;
}
__syncthreads();
if (warp_id == 0) {
sum = (lane_id < blockDim.x / warpSize) ? s_sum[lane_id] : 0.0f;
sum_sq = (lane_id < blockDim.x / warpSize) ? s_sum_sq[lane_id] : 0.0f;
for (int mask = warpSize / 2; mask > 0; mask >>= 1) {
sum += __shfl_xor_sync(0xffffffff, sum, mask);
sum_sq += __shfl_xor_sync(0xffffffff, sum_sq, mask);
}
}
__shared__ float mean, inv_std;
if (threadIdx.x == 0) {
mean = sum / spatial_size;
float var = sum_sq / spatial_size - mean * mean;
inv_std = rsqrtf(var + epsilon);
}
__syncthreads();
// Phase 2: Normalize, apply affine transform, and Leaky ReLU -- fused
float g = gamma[c];
float b = beta[c];
for (int i = threadIdx.x; i < spatial_size; i += blockDim.x) {
float val = input[offset + i];
float normed = (val - mean) * inv_std;
float activated = g * normed + b;
// Fused Leaky ReLU
output[offset + i] = activated > 0 ? activated : negative_slope * activated;
}
}Launch configuration:
int blocks = N * C; // one block per (batch, channel) pair
int threads = min(256, H * W); // 256 threads per block
fused_instance_norm_leaky_relu<<<blocks, threads>>>(
output, input, gamma, beta, N, C, H, W, 1e-5f, 0.2f);This single kernel replaces four separate kernel launches. On a StyleGAN2 generator processing 8 images at 256x256 with 512 channels, the fused version is approximately 2.3x faster than the sequential PyTorch implementation.
Integrating Custom Kernels with PyTorch
To use custom CUDA kernels in your PyTorch training loop, wrap them with a torch.autograd.Function:
import torch
from torch.utils.cpp_extension import load
# JIT compile the CUDA extension
fused_ops = load(
name='fused_ops',
sources=['fused_ops.cu'],
extra_cuda_cflags=['-O3', '--use_fast_math']
)
class FusedInstanceNormLeakyReLU(torch.autograd.Function):
@staticmethod
def forward(ctx, input, gamma, beta, epsilon=1e-5, negative_slope=0.2):
output = fused_ops.fused_instance_norm_leaky_relu(
input, gamma, beta, epsilon, negative_slope)
ctx.save_for_backward(input, gamma, beta, output)
ctx.epsilon = epsilon
ctx.negative_slope = negative_slope
return output
@staticmethod
def backward(ctx, grad_output):
input, gamma, beta, output = ctx.saved_tensors
grad_input, grad_gamma, grad_beta = fused_ops.fused_instance_norm_leaky_relu_backward(
grad_output, input, gamma, beta, output,
ctx.epsilon, ctx.negative_slope)
return grad_input, grad_gamma, grad_beta, None, NoneMixed-Precision Training: Doubling Throughput with FP16/BF16
Why Mixed Precision Is Especially Effective for GANs
NVIDIA Tensor Cores execute matrix multiplications at 2x (FP16/BF16) or 4x (INT8) the throughput of FP32 CUDA cores. An A100 delivers 312 TFLOPS in FP16 with Tensor Cores versus 19.5 TFLOPS in FP32, a 16x ratio. Even accounting for Tensor Core utilization overhead, mixed precision typically doubles actual training throughput.
For GANs, the benefits compound:
- Halved memory footprint for activations and gradients means you can double your batch size, which stabilizes discriminator training.
- Doubled arithmetic throughput from Tensor Cores on convolution and linear layers.
- Halved memory bandwidth pressure, which matters because GAN training is often memory-bandwidth bound.
Implementing Mixed Precision for GANs
PyTorch's torch.amp (Automatic Mixed Precision) handles most of the complexity:
from torch.amp import autocast, GradScaler
# Create separate scalers for generator and discriminator
scaler_g = GradScaler('cuda')
scaler_d = GradScaler('cuda')
for real_images, _ in dataloader:
real_images = real_images.cuda()
# --- Train Discriminator ---
optimizer_d.zero_grad()
with autocast('cuda', dtype=torch.bfloat16):
# Discriminator forward on real images
real_pred = discriminator(real_images)
loss_real = criterion(real_pred, torch.ones_like(real_pred))
# Generate fake images
z = torch.randn(batch_size, latent_dim, device='cuda')
fake_images = generator(z).detach()
fake_pred = discriminator(fake_images)
loss_fake = criterion(fake_pred, torch.zeros_like(fake_pred))
loss_d = (loss_real + loss_fake) / 2
scaler_d.scale(loss_d).backward()
scaler_d.step(optimizer_d)
scaler_d.update()
# --- Train Generator ---
optimizer_g.zero_grad()
with autocast('cuda', dtype=torch.bfloat16):
z = torch.randn(batch_size, latent_dim, device='cuda')
fake_images = generator(z)
pred = discriminator(fake_images)
loss_g = criterion(pred, torch.ones_like(pred))
scaler_g.scale(loss_g).backward()
scaler_g.step(optimizer_g)
scaler_g.update()FP16 vs. BF16: Choosing the Right Format
FP16 (IEEE half-precision) has a narrow dynamic range (6x10^-8 to 65504). GAN training involves loss values and gradients that frequently exceed this range, causing overflows or underflows. This is why loss scaling is essential with FP16: the GradScaler multiplies the loss by a large factor before backward, then divides gradients by the same factor before the optimizer step.
BF16 (Brain Floating Point) has the same exponent range as FP32 (1.2x10^-38 to 3.4x10^38) but with only 8 bits of mantissa precision instead of 23. This means you almost never need loss scaling with BF16, and gradient underflow is far less likely.
For GAN training, BF16 is strongly preferred if your hardware supports it (A100+, H100). GAN training is notoriously unstable, and FP16's narrow dynamic range can introduce training instabilities that are difficult to diagnose: gradients that silently underflow to zero, causing one network to stagnate.
What to Keep in FP32
Not all operations are safe in reduced precision. Keep these in FP32:
- Loss computation. Discriminator and generator losses involve log operations on probabilities near 0 or 1. FP16 precision loss here causes training instability.
- Gradient penalty computation (e.g., R1 regularization, gradient penalty in WGAN-GP). These involve second-order gradients that are numerically sensitive.
- Running statistics in normalization layers. Accumulating means and variances in FP16 causes drift over long training runs.
PyTorch's autocast handles most of this automatically. Convolutions, linear layers, and matrix multiplications run in reduced precision; reductions, softmax, and layer norm run in FP32.
Profiling: Finding the Actual Bottleneck
Optimizing without profiling is guessing. Profiling tells you where your training time actually goes.
NVIDIA Nsight Compute
Nsight Compute (ncu) profiles individual CUDA kernel launches with hardware-level detail:
# Profile a single training step
ncu --set full --target-processes all \
--output gan_profile \
python train_gan.py --profile-steps 1Key metrics to examine:
sm__throughput.avg.pct_of_peak_sustained_elapsed: how much of the SM's computational capacity you are using. Below 60% usually means you are memory-bound.dram__throughput.avg.pct_of_peak_sustained_elapsed: how much of the HBM bandwidth you are consuming. Above 80% means you are memory-bandwidth bound and should focus on reducing memory traffic (fusion, caching).l1tex__t_sectors_pipe_lsu_mem_global_op_ld_lookup_hit_rate.pct: L1 cache hit rate for global loads. Low hit rate means poor data locality.smsp__sass_average_data_bytes_per_sector_mem_global_op_ld.pct: memory coalescing efficiency. A value far below 100% indicates non-coalesced access patterns.
PyTorch Profiler
For a higher-level view of where time goes in your training loop:
from torch.profiler import profile, record_function, ProfilerActivity
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
profile_memory=True,
with_stack=True
) as prof:
for step, (real_images, _) in enumerate(dataloader):
if step >= 5:
break
with record_function("discriminator_step"):
train_discriminator(real_images)
with record_function("generator_step"):
train_generator()
# Print top time-consuming CUDA kernels
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
# Export for visual analysis in Chrome trace viewer or TensorBoard
prof.export_chrome_trace("gan_trace.json")A typical finding: in a StyleGAN2 training loop, 35-40% of CUDA time is spent in convolution kernels (well-optimized by cuDNN), 15-20% in normalization operations (poorly optimized, ripe for fusion), 10-15% in nonlinearities (trivially fusible), and 10-15% in memory format conversion and tensor allocation. Profiling reveals the 25-35% that normalization and activation kernels waste, which is exactly where custom fused kernels deliver their biggest impact.
Identifying Kernel Launch Overhead
GANs launch many small kernels per training step. Each kernel launch has approximately 5-10 microseconds of overhead on the CPU side. A GAN with 50+ layers in each network might launch hundreds of kernels per step. If your training step takes 50ms, launch overhead alone could be 1-2ms (2-4%).
Use CUDA graphs to eliminate this overhead:
# Capture the generator forward pass as a CUDA graph
g = torch.cuda.CUDAGraph()
static_z = torch.randn(batch_size, latent_dim, device='cuda')
with torch.cuda.graph(g):
static_output = generator(static_z)
# Replay the graph (near-zero launch overhead)
for step in range(num_steps):
static_z.copy_(torch.randn(batch_size, latent_dim, device='cuda'))
g.replay()
fake_images = static_output # results are in static_outputCUDA graphs capture the entire sequence of kernel launches and replay them with a single CPU-side call. The constraint is that the graph is static: tensor shapes and control flow cannot change between replays. For GANs with fixed resolution (non-progressive), this works well.
The Triton Compiler: GPU Kernels in Python
Why Triton Changes the Equation
Writing custom CUDA C++ kernels is effective but has a high barrier to entry. You must manage thread indexing, shared memory allocation, synchronization, and launch configurations manually. A single off-by-one error in shared memory indexing can produce silent numerical bugs.
Triton is a compiler and programming language developed by OpenAI that lets you write GPU kernels in Python-like syntax. Triton handles thread indexing, memory coalescing, shared memory management, and many optimizations automatically. You think in terms of blocks of data, not individual threads.
A Triton Kernel for Fused Operations
Here is the same fused Leaky ReLU + element-wise scaling operation, written in Triton:
import triton
import triton.language as tl
@triton.jit
def fused_scale_leaky_relu_kernel(
output_ptr, input_ptr, scale_ptr,
n_elements,
negative_slope: tl.constexpr,
BLOCK_SIZE: tl.constexpr
):
pid = tl.program_id(0)
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(input_ptr + offsets, mask=mask)
scale = tl.load(scale_ptr + offsets % 512, mask=mask) # broadcast scale
# Fused: scale, then Leaky ReLU
scaled = x * scale
result = tl.where(scaled > 0, scaled, negative_slope * scaled)
tl.store(output_ptr + offsets, result, mask=mask)
def fused_scale_leaky_relu(input_tensor, scale, negative_slope=0.2):
output = torch.empty_like(input_tensor)
n_elements = input_tensor.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
fused_scale_leaky_relu_kernel[grid](
output, input_tensor, scale,
n_elements,
negative_slope=negative_slope,
BLOCK_SIZE=1024
)
return outputTriton for Spectral Normalization
Spectral normalization is critical for stable GAN training (it constrains the Lipschitz constant of the discriminator). The standard PyTorch implementation involves an SVD-approximation via power iteration, which launches multiple small kernels. Here is a Triton kernel that fuses the weight normalization step:
@triton.jit
def spectral_norm_forward_kernel(
output_ptr, weight_ptr, u_ptr, v_ptr,
M, N,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr
):
"""Compute W / sigma(W) where sigma is approximated via power iteration."""
# Load u and v vectors (already computed via power iteration)
n_offsets = tl.arange(0, BLOCK_N)
m_offsets = tl.arange(0, BLOCK_M)
# Compute sigma = u^T W v
sigma = tl.zeros([], dtype=tl.float32)
for m in range(0, M, BLOCK_M):
for n in range(0, N, BLOCK_N):
m_idx = m + m_offsets
n_idx = n + n_offsets
m_mask = m_idx < M
n_mask = n_idx < N
w = tl.load(weight_ptr + m_idx[:, None] * N + n_idx[None, :],
mask=m_mask[:, None] & n_mask[None, :])
u_vals = tl.load(u_ptr + m_idx, mask=m_mask)
v_vals = tl.load(v_ptr + n_idx, mask=n_mask)
sigma += tl.sum(u_vals[:, None] * w * v_vals[None, :])
# Normalize: W_norm = W / sigma
inv_sigma = 1.0 / sigma
pid = tl.program_id(0)
block_start = pid * BLOCK_M
row_offsets = block_start + tl.arange(0, BLOCK_M)
for n in range(0, N, BLOCK_N):
n_idx = n + n_offsets
mask = (row_offsets < M)[:, None] & (n_idx < N)[None, :]
w = tl.load(weight_ptr + row_offsets[:, None] * N + n_idx[None, :],
mask=mask)
tl.store(output_ptr + row_offsets[:, None] * N + n_idx[None, :],
w * inv_sigma, mask=mask)Triton auto-tunes many parameters and handles memory coalescing automatically. For most GAN-specific custom operations, Triton kernels achieve 85-95% of the performance of hand-written CUDA C++ while being dramatically easier to write, debug, and maintain.
Practical GAN-Specific Optimizations
Efficient Discriminator Updates
Many GAN training protocols update the discriminator multiple times per generator update (a common ratio is 5:1 or 2:1 for WGAN variants). This means discriminator throughput disproportionately affects total training time.
Key optimizations for the discriminator:
# 1. Disable gradient computation for the generator during D updates
# (avoids building the computation graph for G)
for p in generator.parameters():
p.requires_grad_(False)
for d_step in range(n_critic):
with autocast('cuda', dtype=torch.bfloat16):
real_pred = discriminator(real_images)
fake_images = generator(z).detach() # no graph needed
fake_pred = discriminator(fake_images)
loss_d = compute_discriminator_loss(real_pred, fake_pred)
scaler_d.scale(loss_d).backward()
scaler_d.step(optimizer_d)
scaler_d.update()
optimizer_d.zero_grad(set_to_none=True) # faster than zero_grad()
for p in generator.parameters():
p.requires_grad_(True)The set_to_none=True flag avoids a memset kernel launch that zeros gradient tensors, instead setting the .grad attribute to None. This saves one kernel launch per parameter per optimizer step.
Progressive Growing: Resolution-Aware Optimization
Progressive GANs (ProGAN, StyleGAN) start training at low resolution and progressively increase. This means your optimization strategy should be resolution-aware:
# Adjust batch size based on current resolution
resolution_to_batch = {
4: 256, 8: 256, 16: 128, 32: 64,
64: 32, 128: 16, 256: 8, 512: 4, 1024: 2
}
# At low resolutions, the bottleneck is kernel launch overhead (many small kernels)
# -> Use CUDA graphs
# At high resolutions, the bottleneck is memory bandwidth
# -> Use channels_last format + mixed precision
if current_resolution <= 32:
# Small tensors: CUDA graphs eliminate launch overhead
use_cuda_graphs = True
memory_format = torch.contiguous_format # channels_last overhead > benefit
else:
# Large tensors: memory bandwidth is the bottleneck
use_cuda_graphs = False # shapes may change during progressive transitions
memory_format = torch.channels_lastGradient Accumulation for Large Effective Batch Sizes
GAN training often benefits from larger effective batch sizes, but GPU memory limits the physical batch size. Gradient accumulation lets you decouple the two:
accumulation_steps = 4
effective_batch_size = physical_batch_size * accumulation_steps
optimizer_d.zero_grad(set_to_none=True)
for micro_step in range(accumulation_steps):
real_images = next(data_iter).cuda()
with autocast('cuda', dtype=torch.bfloat16):
loss_d = compute_d_loss(discriminator, generator, real_images)
loss_d = loss_d / accumulation_steps # normalize
scaler_d.scale(loss_d).backward()
scaler_d.step(optimizer_d)
scaler_d.update()Data Loading: Do Not Let the CPU Starve the GPU
A surprisingly common bottleneck. If your data pipeline cannot keep up with the GPU, the GPU sits idle waiting for the next batch:
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
num_workers=8, # enough workers to saturate I/O
pin_memory=True, # enables async CPU->GPU transfer
persistent_workers=True, # avoid re-spawning workers each epoch
prefetch_factor=3, # prefetch 3 batches per worker
drop_last=True # avoid a small final batch
)
# Use non_blocking transfers to overlap data movement with compute
for real_images, labels in dataloader:
real_images = real_images.to('cuda', non_blocking=True)
labels = labels.to('cuda', non_blocking=True)
# GPU is already computing the previous step while this transfer happensBenchmarking Results: Before and After
Here are measured results from optimizing a StyleGAN2 training pipeline on a single A100 80GB, training at 256x256 resolution with batch size 8.
| Optimization | Time/step (ms) | Speedup | Cumulative |
|---|---|---|---|
| Baseline (PyTorch defaults, FP32) | 312 | 1.0x | 1.0x |
+ channels_last memory format | 265 | 1.18x | 1.18x |
| + Mixed precision (BF16) | 158 | 1.68x | 1.97x |
| + Fused normalization kernels | 131 | 1.21x | 2.38x |
+ zero_grad(set_to_none=True) | 127 | 1.03x | 2.46x |
| + CUDA graphs (where applicable) | 119 | 1.07x | 2.62x |
| + Optimized data loading | 112 | 1.06x | 2.79x |
| + Fused activation kernels (Triton) | 104 | 1.08x | 3.00x |
Total speedup: 3.0x, from 312ms per step to 104ms per step. On a training run that would have taken 7 days, this saves approximately 4.7 days. At current A100 cloud pricing (~225 saved per run. Over a research campaign of 50 experimental runs, the savings exceed $11,000.
For teams working on larger models or higher resolutions, the absolute savings scale proportionally. StyleGAN3 at 1024x1024 on 8xA100 sees comparable percentage improvements, translating to weeks of saved compute.
These same GPU optimization principles apply to LLM inference. See LLM Inference Optimization for a deep dive into speculative decoding, KV-cache compression, and serving frameworks.
Common Mistakes and How to Avoid Them
1. Optimizing Without Profiling
The most common mistake. Developers spend days writing a custom CUDA kernel for an operation that accounts for 2% of training time, while ignoring the data loading bottleneck that accounts for 30%. Always profile first. Run torch.profiler for five training steps before writing any custom code.
2. Ignoring Memory Format
Leaving tensors in NCHW format when cuDNN has optimized NHWC kernels is leaving free performance on the table. Add .to(memory_format=torch.channels_last) to your model and input tensors as the very first optimization.
3. FP16 Without Monitoring for NaN/Inf
GAN training is unstable by nature. Reduced precision amplifies this. Always monitor for NaN loss values and log the GradScaler's scale factor. If the scale factor keeps decreasing, you have chronic overflow/underflow:
if step % 100 == 0:
print(f"D scaler scale: {scaler_d.get_scale():.1f}")
print(f"G scaler scale: {scaler_g.get_scale():.1f}")
if torch.isnan(loss_d) or torch.isnan(loss_g):
print("WARNING: NaN loss detected!")4. Excessive Synchronization
Calling torch.cuda.synchronize(), .item(), or print(tensor) in the training loop forces the CPU to wait for all GPU operations to complete. This destroys the asynchronous overlap between CPU data preparation and GPU computation:
# BAD: synchronizes every step
for step in range(num_steps):
loss = train_step()
print(f"Step {step}: loss = {loss.item()}") # .item() synchronizes!
# GOOD: log asynchronously, synchronize only occasionally
for step in range(num_steps):
loss = train_step()
if step % 100 == 0:
print(f"Step {step}: loss = {loss.item()}")5. Not Using torch.compile
PyTorch 2.x's torch.compile with the inductor backend automatically fuses many operations and generates optimized Triton kernels. For many GAN architectures, simply wrapping the model is a free 10-20% speedup:
generator = torch.compile(generator, mode="reduce-overhead")
discriminator = torch.compile(discriminator, mode="reduce-overhead")The reduce-overhead mode uses CUDA graphs under the hood, minimizing kernel launch overhead. Test with mode="max-autotune" as well: it tries more kernel configurations but takes longer to compile.
6. Allocating Tensors in the Training Loop
Creating new tensors every step triggers CUDA memory allocation, which can block on the GPU:
# BAD: allocates new tensor every step
z = torch.randn(batch_size, latent_dim, device='cuda')
# GOOD: pre-allocate and reuse
z = torch.empty(batch_size, latent_dim, device='cuda')
for step in range(num_steps):
z.normal_() # fill in-place, no allocation
fake_images = generator(z)Key Takeaways
- Profile before you optimize. Use
torch.profilerand NVIDIA Nsight Compute to identify your actual bottleneck. Do not guess. - Channels-last memory format is free performance. One line of code for a 10-30% speedup on convolutional GANs. Do this first.
- Mixed precision with BF16 is the highest-leverage single optimization. It roughly doubles throughput, halves memory usage, and enables larger batch sizes that stabilize GAN training. Prefer BF16 over FP16 for training stability.
- Kernel fusion eliminates memory bandwidth waste. GAN architectures chain many small operations (norm, activate, scale). Fusing them into single kernels avoids redundant global memory round trips. Write fused kernels in Triton for 85-95% of hand-tuned CUDA performance with a fraction of the effort.
- Occupancy and coalescing are the two hardware-level metrics that matter most. Ensure your kernels achieve at least 50% occupancy and near-100% coalescing efficiency.
- The small things add up.
set_to_none=True, pre-allocated tensors,non_blockingtransfers,torch.compile, each saves a few percent, and they compound to meaningful gains. - GAN-specific considerations matter. Separate gradient scalers for G and D. Resolution-aware batch sizes for progressive training. Keeping gradient penalty computations in FP32. These details prevent the subtle training instabilities that waste more time than any kernel optimization saves.
Understanding the transformer layers you are optimizing gives you even deeper intuition for where compute is spent. See Transformer Architectures from Scratch for a ground-up explanation.
The 40% improvement in the subtitle is conservative. With disciplined application of the techniques in this tutorial (profiling, channels-last format, mixed precision, kernel fusion, and proper GPU utilization), most practitioners see 2-3x total speedups on their GAN training pipelines. The GPU you already have is almost certainly faster than you think. You just have to let it run.