Focal loss (intuition)
Standard cross-entropy is dominated by easy negatives (background). Focal loss multiplies CE by a modulating factor (1 − p_t)^γ with focusing parameter γ ≥ 0, where p_t is the model's predicted probability for the true class. When the model is confident (p_t near 1), the loss shrinks toward zero; hard examples keep larger gradients. An optional weight α balances the positive/negative contribution.
```python
# Conceptual focal modulator on top of CE (illustrative)
def focal_weight(pt, gamma=2.0):
    # pt: predicted probability of the true class; gamma: focusing parameter
    return (1.0 - pt) ** gamma
```
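The full per-example loss combines this modulator with CE and the optional α weight. A minimal pure-Python sketch (the epsilon clamp is an assumption for numerical safety, not part of the formula):

```python
import math

def focal_loss(pt, gamma=2.0, alpha=0.25, eps=1e-12):
    """Alpha-balanced focal loss for true-class probability pt.

    FL(pt) = -alpha * (1 - pt)**gamma * log(pt)
    """
    pt = min(max(pt, eps), 1.0 - eps)  # clamp to avoid log(0)
    return -alpha * (1.0 - pt) ** gamma * math.log(pt)

# An easy example (pt = 0.9) is down-weighted far more than a hard one (pt = 0.1)
easy = focal_loss(0.9)
hard = focal_loss(0.1)
```

With gamma = 0 this reduces to α-weighted cross-entropy; raising gamma shrinks the contribution of well-classified examples.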
FPN and detection heads
RetinaNet attaches two subnetworks (classification and box regression) at each pyramid level. Anchors span multiple scales and aspect ratios per level. Predictions are decoded to image-space boxes and filtered with score thresholding and NMS, the same post-processing family as other anchor-based detectors.
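The NMS step can be sketched in plain Python as a greedy IoU-suppression loop (the [x1, y1, x2, y2] box format is an assumption for the sketch):

```python
def iou(a, b):
    # Intersection-over-union of two [x1, y1, x2, y2] boxes
    x1, y1 = max(a[0], b[0]), max(a[1], b[1])
    x2, y2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)

def nms(boxes, scores, iou_thresh=0.5):
    # Greedy NMS: keep the highest-scoring box, drop overlapping neighbours
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    while order:
        i = order.pop(0)
        keep.append(i)
        order = [j for j in order if iou(boxes[i], boxes[j]) < iou_thresh]
    return keep
```

In practice you would use `torchvision.ops.nms`, which does the same thing on tensors.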
vs YOLO
Both are one-stage detectors; RetinaNet's focal loss specifically targets the CE imbalance problem, while YOLO families use different anchor-assignment and loss formulations.
vs Faster R-CNN
No RPN stage; RetinaNet classifies a much denser set of anchor candidates directly. It is often slower than small YOLO variants but offers competitive accuracy on COCO-style data.
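The density claim can be made concrete with a little arithmetic. The sketch below counts anchors across FPN levels; the 9 anchors per location (3 scales × 3 aspect ratios) and strides 8 through 128 follow the standard RetinaNet setup, while the 800×800 input size is an assumption for illustration:

```python
def total_anchors(img_h, img_w, strides=(8, 16, 32, 64, 128), anchors_per_loc=9):
    # Each FPN level contributes one anchor set per spatial location
    # of its feature map (feature-map size = image size // stride)
    total = 0
    for s in strides:
        fh, fw = img_h // s, img_w // s
        total += fh * fw * anchors_per_loc
    return total

n = total_anchors(800, 800)  # well over 100k candidates for an 800x800 image
```

Focal loss exists precisely because the vast majority of those candidates are easy background.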
Inference: torchvision
```python
import torch
import torchvision.transforms as T
from torchvision.models.detection import retinanet_resnet50_fpn
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = retinanet_resnet50_fpn(weights="DEFAULT").to(device).eval()

img = Image.open("street.jpg").convert("RGB")
x = T.functional.to_tensor(img).to(device)

with torch.no_grad():
    r = model([x])[0]  # dict with "boxes", "labels", "scores"

for i in range(len(r["scores"])):
    if r["scores"][i] < 0.5:  # confidence threshold
        continue
    box = r["boxes"][i].tolist()
    lbl = int(r["labels"][i])
    sc = float(r["scores"][i])
    print(f"label={lbl} score={sc:.2f} box={box}")
```
Custom classes: replace heads
```python
from torchvision.models.detection import retinanet_resnet50_fpn
from torchvision.models.detection.retinanet import RetinaNetClassificationHead

num_classes = 3  # e.g. background + 2 object classes
model = retinanet_resnet50_fpn(weights="DEFAULT")
# Typical pattern: rebuild the classification head with num_classes,
# keeping num_anchors and in_channels from the existing head/backbone.
# See the torchvision RetinaNet source for the RetinaNetHead constructor
# args in your installed version.
```
torchvision's internal head API has shifted across releases; copy the official "Training on a custom dataset" snippet for your installed version.
Takeaways
- Focal loss fights anchor imbalance in dense detectors.
- FPN gives multi-scale representation for small and large objects.
- Strong baseline when you want one-stage accuracy without YOLO-specific tooling.