Step 1: Define Model & Imports
Import libraries and set up your neural network architecture, the foundation for the upcoming optimizations.
# Step 1: Conceptual Setup and Simple Model Definition
# --- Imports (Conceptual) ---
import torch
import time
# A small, widely used CNN (ResNet-18) to demonstrate the optimizations
from torchvision.models import resnet18
# --- Configuration ---
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {DEVICE}")
# --- Load Model ---
# Load the ResNet-18 architecture; weights=None gives a randomly initialized
# network, which is enough to demonstrate the optimizations below
model = resnet18(weights=None)
model.eval()  # Set to evaluation mode (disables dropout and batch-norm updates)
# ----------------------------------------------------------------------
# NOTE: The subsequent steps will modify this model or the way data is passed.
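If you need meaningful predictions rather than a randomly initialized network, torchvision can also load pretrained ImageNet weights. A minimal sketch, assuming torchvision 0.13 or newer (which introduced the weights enums):
# Optional: load pretrained ImageNet weights instead of random initialization
from torchvision.models import resnet18, ResNet18_Weights

pretrained_model = resnet18(weights=ResNet18_Weights.DEFAULT)  # downloads weights on first use
pretrained_model.eval()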
Step 2: GPU Acceleration (Moving the Model and Data)
This step moves the model and a single input tensor to the GPU for faster computation, and records the single-sample timing baseline used in the later steps.
# Step 2: GPU Acceleration
# --- Move Model to GPU ---
model.to(DEVICE)
# --- Create Single Input (Conceptual) ---
# Assuming an image tensor (e.g., 3 channels, 224x224 size)
BATCH_SIZE_UNBATCHED = 1
INPUT_TENSOR_SINGLE = torch.randn(BATCH_SIZE_UNBATCHED, 3, 224, 224)
# --- Move Data to GPU ---
INPUT_TENSOR_SINGLE = INPUT_TENSOR_SINGLE.to(DEVICE)
# --- Inference (Baseline Speed) ---
with torch.no_grad():
    start_time = time.time()
    output_single = model(INPUT_TENSOR_SINGLE)
    if DEVICE == 'cuda':
        torch.cuda.synchronize()  # CUDA kernels launch asynchronously; wait for the GPU before stopping the clock
    end_time = time.time()
time_single = end_time - start_time
print(f"\nInference Time (Single, GPU): {time_single:.4f} seconds")
# ----------------------------------------------------------------------
# KEY CONCEPT: By moving computation to the 'cuda' device, we leverage the
# massive parallelism of the GPU.
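Because CUDA kernels execute asynchronously, time.time() alone can misreport GPU work; torch.cuda.Event gives device-side timings. A minimal sketch reusing model and INPUT_TENSOR_SINGLE from above (the warm-up pass keeps one-time CUDA initialization out of the measurement):
# Optional: more precise GPU timing with CUDA events (only meaningful when DEVICE == 'cuda')
if DEVICE == 'cuda':
    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
    with torch.no_grad():
        _ = model(INPUT_TENSOR_SINGLE)  # warm-up run
        torch.cuda.synchronize()
        starter.record()
        _ = model(INPUT_TENSOR_SINGLE)
        ender.record()
        torch.cuda.synchronize()
    print(f"Inference Time (Single, CUDA events): {starter.elapsed_time(ender) / 1000:.4f} seconds")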
Step 3: Batching (Processing Multiple Inputs Simultaneously)
Batching significantly improves throughput (samples per second) by letting the GPU process many inputs in parallel. While the total time for a batch is usually longer than for a single input, the time per sample drops drastically.
# Step 3: Batching
# --- Create Batched Input ---
BATCH_SIZE = 32 # Common batch size for inference
INPUT_TENSOR_BATCHED = torch.randn(BATCH_SIZE, 3, 224, 224)
# --- Move Data to GPU ---
INPUT_TENSOR_BATCHED = INPUT_TENSOR_BATCHED.to(DEVICE)
# --- Batched Inference ---
with torch.no_grad():
    start_time = time.time()
    # The GPU processes all 32 samples in parallel
    output_batched = model(INPUT_TENSOR_BATCHED)
    if DEVICE == 'cuda':
        torch.cuda.synchronize()  # wait for the batched kernels to finish before timing
    end_time = time.time()
time_batched = end_time - start_time
# Calculate effective time per sample
time_per_sample_batched = time_batched / BATCH_SIZE
time_per_sample_single = time_single / BATCH_SIZE_UNBATCHED  # from Step 2 (batch size 1)
print(f"\nInference Time (Batch={BATCH_SIZE}, GPU): {time_batched:.4f} seconds")
print(f"Time per Sample (Batch): {time_per_sample_batched:.5f} seconds")
print(f"Time per Sample (Single): {time_per_sample_single:.5f} seconds")
# ----------------------------------------------------------------------
# KEY CONCEPT: Batching exploits GPU parallelization; it reduces overhead
# and keeps the GPU cores busy, boosting throughput.
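Throughput typically grows with batch size until the GPU is saturated. A minimal sketch that sweeps a few arbitrary batch sizes, reusing the model and DEVICE defined earlier:
# Optional: measure throughput (samples/sec) across a few batch sizes
for bs in (1, 8, 32, 64):
    batch = torch.randn(bs, 3, 224, 224).to(DEVICE)
    with torch.no_grad():
        _ = model(batch)  # warm-up
        if DEVICE == 'cuda':
            torch.cuda.synchronize()
        start = time.time()
        _ = model(batch)
        if DEVICE == 'cuda':
            torch.cuda.synchronize()
        elapsed = time.time() - start
    print(f"Batch size {bs:3d}: {bs / elapsed:8.1f} samples/sec")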
Step 4: Pruning (Sparsity and Reduced Computation)
Pruning permanently removes "unimportant" connections (weights) from the model, yielding a smaller, sparser network that needs fewer computations. In practice it typically relies on a specialized library and a fine-tuning (retraining) step to recover accuracy.
# Step 4: Pruning (Conceptual Pseudocode)
# --- Conceptual Pruning Utility ---
# In a real framework, this function would apply a mask to weights
def conceptual_prune_layer(layer, amount=0.5):
    """
    Conceptual function to zero out (prune) a fraction of a layer's weights.
    In reality, you would use a pruning utility such as torch.nn.utils.prune
    (see the sketch at the end of this step).
    """
    if hasattr(layer, 'weight'):
        weights = layer.weight.data
        # Magnitude threshold below which weights are pruned
        threshold = torch.quantile(torch.abs(weights), amount)
        # Mask: True for weights to keep, False for weights to prune (set to 0)
        mask = torch.abs(weights) > threshold
        # Apply the mask in place
        weights *= mask.float()
        print(f"Pruned {amount*100:.0f}% of weights in layer.")
        return True
    return False
# --- Apply Pruning ---
PRUNING_AMOUNT = 0.60 # Prune 60% of connections
pruned_model = resnet18(weights=None)
pruned_model.eval()
pruned_model.to(DEVICE)
# Iterate through layers and apply conceptual pruning
for name, module in pruned_model.named_modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        conceptual_prune_layer(module, amount=PRUNING_AMOUNT)
# --- Inference with Pruned Model (Conceptual Speedup) ---
# NOTE: Zeroed-out weights are still stored and multiplied by PyTorch's dense
# kernels, so this standard inference will usually NOT show a wall-clock speedup.
# The real gains appear when the pruned model is converted to a sparse format
# or optimized by a deployment engine (e.g., ONNX Runtime, TensorRT) that can
# actually skip the removed weights.
with torch.no_grad():
    start_time_pruned = time.time()
    # Use the same batched input as in Step 3
    output_pruned = pruned_model(INPUT_TENSOR_BATCHED)
    if DEVICE == 'cuda':
        torch.cuda.synchronize()  # wait for GPU work to finish before timing
    end_time_pruned = time.time()
time_pruned = end_time_pruned - start_time_pruned
print(f"\nInference Time (Pruned, Batch={BATCH_SIZE}): {time_pruned:.4f} seconds")
# ----------------------------------------------------------------------
# KEY CONCEPT: Pruning reduces the computational complexity of the model
# by removing weights that contribute little to its predictions.
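The hand-rolled masking above mirrors what the built-in torch.nn.utils.prune utilities do. A minimal sketch of that API, applying L1-magnitude unstructured pruning at the same 60% as PRUNING_AMOUNT; a real workflow would fine-tune afterwards to recover accuracy:
# Optional: the same idea with PyTorch's built-in pruning utilities
import torch.nn.utils.prune as prune

api_pruned_model = resnet18(weights=None).eval().to(DEVICE)
for module in api_pruned_model.modules():
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        # Zero out the 60% of weights with the smallest L1 magnitude
        prune.l1_unstructured(module, name='weight', amount=0.60)
        # Make the pruning permanent (drops the mask and the saved original weights)
        prune.remove(module, 'weight')

# Verify the resulting sparsity of the pruned Conv/Linear weights
conv_lin = [m for m in api_pruned_model.modules()
            if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear))]
total = sum(m.weight.numel() for m in conv_lin)
zeros = sum((m.weight == 0).sum().item() for m in conv_lin)
print(f"Conv/Linear weight sparsity: {zeros / total:.1%}")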
