
Generative Adversarial Networks (GANs)

GANs pit two networks against each other: a Generator that creates fake data and a Discriminator that detects fakes. Through this competition both improve, eventually producing highly realistic images, audio, and more.

Quick facts
  • Latent dim: 64–512
  • Nash equilibrium: the training goal
  • Mode collapse: the key challenge
  • StyleGAN: 1024x1024 output resolution

The Adversarial Game

GANs are a minimax game between two players: Generator (G) and Discriminator (D). G tries to fool D by generating realistic samples. D tries to distinguish real from fake.

Data flow: z ~ N(0,1) → G(z) = x_fake → D(x) → real/fake

D(x) = probability that x is real. In practice G maximizes log D(G(z)) (the non-saturating loss) rather than minimizing log(1 - D(G(z))), which gives stronger gradients early in training.

Minimax objective: min_G max_D V(D,G) = E_x[log D(x)] + E_z[log(1 - D(G(z)))]

Intuition: G is like a counterfeiter, D is the detective. Competition drives both to perfection.
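The value function can be checked at the theoretical optimum: when p_g = p_data, the optimal discriminator outputs D*(x) = 0.5 everywhere and V reaches its global minimum of -log 4. A stdlib-only numeric sketch:

```python
import math

# At the optimum, D*(x) = p_data(x) / (p_data(x) + p_g(x)).
# When the generator matches the data (p_g = p_data), D* = 0.5 everywhere.
def value(d_real, d_fake):
    # V(D, G) with one real sample scored d_real and one fake scored d_fake
    return math.log(d_real) + math.log(1 - d_fake)

v_star = value(0.5, 0.5)
print(round(v_star, 4))  # → -1.3863, i.e. -log 4
```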

Vanilla GAN – The Original

Generator

Maps latent vector z to data space. Typically MLP with ReLU + sigmoid/tanh for images.

import tensorflow as tf

def build_generator(latent_dim=100):
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(latent_dim,)),
        tf.keras.layers.Dense(128),
        tf.keras.layers.ReLU(),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dense(784, activation='tanh')  # flattened 28x28 MNIST
    ])
    return model
Discriminator

Binary classifier. Outputs probability of real image.

def build_discriminator():
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(784,)),
        tf.keras.layers.Dense(128),
        tf.keras.layers.LeakyReLU(0.2),
        tf.keras.layers.Dropout(0.3),
        tf.keras.layers.Dense(1, activation='sigmoid')  # P(x is real)
    ])
    return model
PyTorch Vanilla GAN Training Loop
# Training loop (alternating updates)
# Assumes G, D, dataloader, optimizer_G, optimizer_D, latent_dim are defined
for epoch in range(epochs):
    for real_imgs, _ in dataloader:
        batch_size = real_imgs.size(0)
        z = torch.randn(batch_size, latent_dim)

        # Train Discriminator: push D(real) -> 1, D(fake) -> 0
        optimizer_D.zero_grad()
        fake_imgs = G(z)
        real_pred = D(real_imgs)
        fake_pred = D(fake_imgs.detach())  # detach: no generator gradients here
        d_loss = -torch.mean(torch.log(real_pred) + torch.log(1 - fake_pred))
        d_loss.backward()
        optimizer_D.step()

        # Train Generator: non-saturating loss, push D(fake) -> 1
        optimizer_G.zero_grad()
        z = torch.randn(batch_size, latent_dim)
        fake_imgs = G(z)
        fake_pred = D(fake_imgs)
        g_loss = -torch.mean(torch.log(fake_pred))
        g_loss.backward()
        optimizer_G.step()

DCGAN – Convolutional GAN

DCGAN brought CNNs to GANs with key architectural guidelines that stabilized training.

Guidelines
  • Replace pooling with strided conv (D) / fractionally-strided (transposed) conv (G)
  • BatchNorm in both G and D (except G's output and D's input layers)
  • Remove fully connected hidden layers
  • ReLU in G (tanh at the output)
  • LeakyReLU in D
# Generator (DCGAN)
import torch
import torch.nn as nn

class DCGenerator(nn.Module):
    def __init__(self, latent_dim=100):
        super().__init__()
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),  # 1x1 -> 4x4
            nn.BatchNorm2d(512), nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),  # 4x4 -> 8x8
            nn.BatchNorm2d(256), nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),  # 8x8 -> 16x16
            nn.BatchNorm2d(128), nn.ReLU(True),
            nn.ConvTranspose2d(128, 3, 4, 2, 1, bias=False),  # 16x16 -> 32x32
            nn.Tanh()
        )

    def forward(self, z):
        # reshape (N, latent_dim) -> (N, latent_dim, 1, 1) for the conv stack
        return self.deconv(z.view(z.size(0), z.size(1), 1, 1))
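As a sanity check on the stack above, each ConvTranspose2d output size follows out = (in - 1)*stride - 2*padding + kernel. This stdlib-only sketch traces the four layers from a 1x1 latent map:

```python
# Spatial output size of a ConvTranspose2d layer
def deconv_out(size, kernel, stride, padding):
    return (size - 1) * stride - 2 * padding + kernel

size = 1  # latent vector reshaped to a 1x1 spatial map
for (k, s, p) in [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1)]:
    size = deconv_out(size, k, s, p)
print(size)  # → 32, i.e. the generator emits 3x32x32 images
```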

Training Challenges & Stabilization

⚠️ Mode Collapse: Generator produces only a few modes of the data distribution; G and D settle into a degenerate equilibrium instead of covering all modes.

Solution: WGAN, minibatch discrimination, unrolled GANs.

⚠️ Vanishing Gradients: D becomes too good → G gradient ~0.

Solution: Use WGAN (Earth Mover distance), label smoothing, instance noise.

✅ Label Smoothing: Use 0.9/0.1 instead of 1/0. Prevents overconfidence.
✅ BatchNorm & Spectral Norm: Spectral normalization for D stabilizes training.
✅ Gradient Penalty: WGAN-GP enforces Lipschitz constraint via gradient norm.
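The label-smoothing trick above can be seen directly in the binary cross-entropy: with a smoothed target of 0.9, the loss penalizes the discriminator for pushing its output all the way to 1. A minimal pure-Python sketch:

```python
import math

def bce(target, pred):
    # binary cross-entropy for a single prediction
    return -(target * math.log(pred) + (1 - target) * math.log(1 - pred))

# With target 0.9, the loss minimum sits at D(x) = 0.9: growing more
# confident than that (e.g. D(x) = 0.99) actually increases the loss.
print(bce(0.9, 0.90) < bce(0.9, 0.99))  # → True
```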

Wasserstein GAN (WGAN)

WGAN replaces the binary discriminator with a critic that scores realness on an unbounded scale. Minimizing the Earth Mover (Wasserstein-1) distance gives smoother gradients and more stable training.

WGAN Loss

D_loss = E[D(fake)] - E[D(real)]
G_loss = -E[D(fake)]

Critic weights clipped to [-c, c] (WGAN) or gradient penalty (WGAN-GP).
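The original WGAN clipping step is simple enough to sketch in a few lines (c = 0.01 is the paper's default; the weight values are illustrative):

```python
# Clip every critic weight into [-c, c] after each update (original WGAN).
# WGAN-GP replaces this crude Lipschitz enforcement with a gradient penalty.
c = 0.01
weights = [0.5, -0.02, 0.003, -1.2]
clipped = [max(-c, min(c, w)) for w in weights]
print(clipped)  # → [0.01, -0.01, 0.003, -0.01]
```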

WGAN-GP (Gradient Penalty)
def gradient_penalty(critic, real, fake, device):
    batch_size, c, h, w = real.shape
    epsilon = torch.rand(batch_size, 1, 1, 1).repeat(1, c, h, w).to(device)
    interpolated = epsilon * real + (1 - epsilon) * fake
    interpolated.requires_grad_(True)  # required so autograd can differentiate w.r.t. it
    mixed_score = critic(interpolated)

    gradient = torch.autograd.grad(
        inputs=interpolated,
        outputs=mixed_score,
        grad_outputs=torch.ones_like(mixed_score),
        create_graph=True,
        retain_graph=True
    )[0]
    gradient = gradient.view(batch_size, -1)
    gradient_norm = gradient.norm(2, dim=1)
    gp = torch.mean((gradient_norm - 1) ** 2)  # push ||grad|| toward 1 (Lipschitz)
    return gp
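Putting the pieces together, the critic loss combines the Wasserstein terms with the penalty, typically weighted by lambda = 10. A toy-number sketch (the scores and gp value are made up for illustration):

```python
# critic_loss = E[D(fake)] - E[D(real)] + lambda_gp * gp
lambda_gp = 10                   # WGAN-GP default penalty weight
real_scores = [2.0, 1.5, 1.8]    # illustrative critic outputs on real samples
fake_scores = [-0.5, 0.1, -0.2]  # illustrative critic outputs on fakes
gp = 0.03                        # illustrative gradient penalty value

mean = lambda xs: sum(xs) / len(xs)
critic_loss = mean(fake_scores) - mean(real_scores) + lambda_gp * gp
print(round(critic_loss, 4))  # → -1.6667 (critic wants this as low as possible)
```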

Conditional GAN (cGAN)

Both generator and discriminator receive additional condition (class label, text, image). Enables controlled generation.

Architecture

Concatenate condition y to z (G) and to x (D).

# Generator: condition by concatenating the label to the latent vector
z = torch.randn(batch_size, latent_dim)
y = one_hot(labels)  # condition, shape (batch, num_classes)
gen_input = torch.cat([z, y], dim=1)
fake = G(gen_input)

# Discriminator (MLP on flattened images): concatenate the label to the input.
# For a convolutional D, broadcast y spatially and append it as extra channels.
dis_input = torch.cat([image, y], dim=1)
score = D(dis_input)
Applications
  • Pix2Pix: Image-to-image translation (edges→photo)
  • CycleGAN: Unpaired translation (horse→zebra)
  • Text-to-Image: Generate images from descriptions
  • SRGAN: Super-resolution

StyleGAN & Progressive GANs

Progressive GAN

Start with low resolution (4x4), add layers as training progresses. Stabilizes high-res generation.

1024x1024 faces, cats, cars.

StyleGAN

Mapping network + AdaIN (adaptive instance normalization). Style mixing enables controllable synthesis (pose, identity, lighting).

Key idea: Noise injects stochastic variation (freckles, hair).

StyleGAN formula: w = MappingNetwork(z) → AdaIN(conv, w) → stochastic variation via noise. Separates high-level attributes from stochastic details.

Evaluating GANs – FID & Inception Score

Inception Score (IS)

Uses ImageNet-pretrained Inception. Measures:

  • High confidence predictions (realistic)
  • Diversity across samples

Criticism: Ignores the real data distribution and misses intra-class mode collapse (many samples of the same class can still look identical).
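The definition IS = exp(E_x[KL(p(y|x) || p(y))]) can be sketched directly from class posteriors (toy 3-class probabilities, no Inception network involved):

```python
import math

# IS = exp( mean over samples of KL( p(y|x) || p(y) ) )
def inception_score(posteriors):
    n_classes = len(posteriors[0])
    # marginal p(y): average posterior over all samples
    marginal = [sum(p[c] for p in posteriors) / len(posteriors) for c in range(n_classes)]
    kl = [sum(p[c] * math.log(p[c] / marginal[c]) for c in range(n_classes) if p[c] > 0)
          for p in posteriors]
    return math.exp(sum(kl) / len(kl))

# Confident AND diverse predictions -> high IS
print(round(inception_score([[0.9, 0.05, 0.05], [0.05, 0.9, 0.05], [0.05, 0.05, 0.9]]), 3))  # → 2.022
# Collapsed onto a single class -> IS of exactly 1
print(inception_score([[0.9, 0.05, 0.05]] * 3))  # → 1.0
```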

FID (Fréchet Inception Distance)

Compares statistics of real vs fake in Inception feature space.

FID = ||μ_r - μ_f||² + Tr(Σ_r + Σ_f - 2(Σ_rΣ_f)^½)

Lower is better. FID is the standard metric today.

# FID using torchmetrics (expects uint8 image tensors; pass normalize=True for floats in [0, 1])
from torchmetrics.image.fid import FrechetInceptionDistance

fid = FrechetInceptionDistance(feature=2048)
fid.update(real_images, real=True)
fid.update(fake_images, real=False)
print(f"FID: {fid.compute():.2f}")

Production-Ready GAN Implementations

PyTorch Lightning GAN
import pytorch_lightning as pl
import torch

class GAN(pl.LightningModule):
    def __init__(self, latent_dim=100):
        super().__init__()
        self.latent_dim = latent_dim
        self.generator = Generator(latent_dim)
        self.discriminator = Discriminator()

    # optimizer_idx signature: Lightning < 2.0; in 2.x, use manual optimization
    def training_step(self, batch, batch_idx, optimizer_idx):
        real_imgs, _ = batch
        z = torch.randn(real_imgs.size(0), self.latent_dim, device=self.device)

        if optimizer_idx == 0:  # train D
            fake_imgs = self.generator(z)
            d_loss = self.discriminator_loss(real_imgs, fake_imgs)
            return d_loss
        else:  # train G
            fake_imgs = self.generator(z)
            g_loss = self.generator_loss(fake_imgs)
            return g_loss
TensorFlow Custom Training Loop
# TensorFlow GAN with custom training
@tf.function
def train_step(real_images):
    z = tf.random.normal([BATCH_SIZE, latent_dim])
    
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fake_images = generator(z, training=True)
        
        real_output = discriminator(real_images, training=True)
        fake_output = discriminator(fake_images, training=True)
        
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)
    
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

GAN Family Comparison

Model       | Key Idea                       | Stability | Quality | Use Case
Vanilla GAN | Minimax BCE                    | ❌        | Low     | Educational
DCGAN       | Convolutional guidelines       | ⭐⭐       | ⭐⭐      | Small images
WGAN-GP     | Wasserstein + gradient penalty | ⭐⭐⭐⭐     | ⭐⭐⭐     | Default stable choice
cGAN        | Conditional generation         | ⭐⭐⭐      | ⭐⭐      | Labeled synthesis
StyleGAN    | Style modulation + noise       | ⭐⭐⭐      | ⭐⭐⭐⭐⭐   | High-res faces
CycleGAN    | Cycle consistency (unpaired)   | ⭐⭐⭐      | ⭐⭐⭐     | Unpaired translation

GANs & Responsible AI

⚠️ Deepfakes: GANs can generate realistic fake faces/videos. Critical to develop detection methods and watermarking.
⚠️ Bias amplification: GANs trained on biased datasets amplify stereotypes (e.g., gender, race). Use balanced, diverse data.
✅ Positive impact: Medical imaging (synthetic MRI), drug discovery, data augmentation, creative tools.
✅ Detection: GAN-generated image detection is active research. Use forensic tools.

GAN Cheatsheet

G: generator
D: discriminator
z: latent vector
WGAN: Wasserstein GAN
GP: gradient penalty
cGAN: conditional GAN
FID: evaluation metric
Mode collapse: failure mode