Depthwise-separable convolution (from-scratch sketch)
import torch.nn as nn
class DepthwiseSeparable(nn.Module):
    """Depthwise-separable convolution block (MobileNetV1-style).

    Factorizes a standard 3x3 conv into a per-channel (depthwise) 3x3 conv
    followed by a 1x1 (pointwise) conv, which is what makes the MobileNet
    family cheap in FLOPs and parameters.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        stride: stride of the depthwise conv (controls spatial downsampling).
    """

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        # groups=in_ch: each 3x3 filter sees exactly one input channel.
        self.dw = nn.Conv2d(in_ch, in_ch, 3, stride, 1, groups=in_ch, bias=False)
        # BN + activation after the depthwise conv as well: MobileNetV1 puts
        # BN/ReLU after BOTH convs — without a nonlinearity between dw and pw
        # the pair collapses toward a single linear map.
        self.bn_dw = nn.BatchNorm2d(in_ch)
        self.pw = nn.Conv2d(in_ch, out_ch, 1, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.act = nn.ReLU6(inplace=True)

    def forward(self, x):
        """Apply dw-conv -> BN -> ReLU6 -> pw-conv -> BN -> ReLU6."""
        x = self.act(self.bn_dw(self.dw(x)))
        return self.act(self.bn(self.pw(x)))
Real MobileNet blocks add expansion ratios, residuals (V2), and SE/h-swish (V3)—use torchvision.models for faithful implementations.
torchvision MobileNetV2 / V3
from torchvision.models import (
mobilenet_v2, mobilenet_v3_small, mobilenet_v3_large,
MobileNet_V2_Weights, MobileNet_V3_Small_Weights, MobileNet_V3_Large_Weights,
)
# Pick the weights enum once, then derive both the model and its matching
# preprocessing transforms from it so they cannot drift out of sync.
w2 = MobileNet_V2_Weights.IMAGENET1K_V1
w3s = MobileNet_V3_Small_Weights.IMAGENET1K_V1
m2 = mobilenet_v2(weights=w2).eval()
m3s = mobilenet_v3_small(weights=w3s).eval()
tf2 = w2.transforms()
tf3 = w3s.transforms()
Classification
from PIL import Image
import torch

# Preprocess with the transforms that match the checkpoint; add a batch dim.
pil_img = Image.open("cat.jpg").convert("RGB")
img = tf2(pil_img).unsqueeze(0)
# Inference only — no autograd bookkeeping needed.
with torch.no_grad():
    logits = m2(img)
idx = logits.argmax(dim=1).item()
# Map the predicted class index to its human-readable ImageNet label.
print(MobileNet_V2_Weights.IMAGENET1K_V1.meta["categories"][idx])
Feature vector (before classifier)
# MobileNetV2: features end before classifier
feat_net = nn.Sequential(m2.features, nn.AdaptiveAvgPool2d(1), nn.Flatten(1))
with torch.no_grad():
emb = feat_net(img)
print(emb.shape)
V3 small: same idea
# Same recipe for MobileNetV3-small, which already carries its own avgpool.
pil3 = Image.open("cat.jpg").convert("RGB")
img3 = tf3(pil3).unsqueeze(0)
feat3 = nn.Sequential(m3s.features, m3s.avgpool, nn.Flatten(1))
with torch.no_grad():
    e3 = feat3(img3)
Replace classifier head
# Fine-tuning: swap the final Linear for one sized to the new label set,
# reading the input width off the head being replaced.
num_classes = 10
old_head = m2.classifier[1]
in_f = old_head.in_features
m2.classifier[1] = nn.Linear(in_f, num_classes)
Width multiplier & resolution
Papers scale channel width and input resolution for accuracy–latency tradeoffs. In torchvision, pick a different weights enum or instantiate without pretrained weights and pass width_mult where the API exposes it (API varies by version).
Takeaways
- Depthwise + pointwise ≈ fewer FLOPs than one full conv.
- V2: inverted residual + linear bottleneck.
- V3: tuned for mobile with h-swish / SE-style blocks.