The training loop is where PyTorch gives you full control. You decide how batches move to the device, how loss is computed, when gradients are cleared, whether gradients are clipped, how metrics are tracked, and when checkpoints are saved.
A clean loop separates training and validation. Training uses `model.train()` and gradients. Validation uses `model.eval()` and `torch.no_grad()`. Mixing those modes is a common source of unreliable metrics, because layers such as dropout and batch normalization behave differently in each mode.
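One way to see why the mode matters is to run `nn.Dropout` under `train()` and then under `eval()`. This sketch uses a single dropout layer as a stand-in for a full model. Batch normalization layers switch in a similar way, from per-batch statistics to running statistics.

```python
import torch
from torch import nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

dropout.train()        # training mode: zeroes random elements and scales the rest by 1 / (1 - p)
train_out = dropout(x)

dropout.eval()         # evaluation mode: dropout is disabled and the input passes through unchanged
eval_out = dropout(x)

print("train:", train_out)  # a mix of 0.0 and 2.0 values
print("eval: ", eval_out)   # all ones
```

If validation ran in training mode, each evaluation would see a different random mask, so the metrics would be noisy.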
Read the concept first, then trace the example line by line. The important habit is to connect the rule to visible behavior instead of memorizing only the name.
A training loop repeatedly predicts, measures error, backpropagates gradients, and updates parameters, while validation measures generalization without updates.
Each epoch processes every training batch, then evaluates on validation data. Metrics should be averaged over the number of examples, not the number of batches, because batch sizes can vary. The last batch is often smaller than the others.
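A small arithmetic check shows why the choice of denominator matters. This sketch uses three hypothetical batches, where the last one is smaller and has a higher loss.

```python
# Hypothetical per-batch results: (mean loss of the batch, number of examples in it).
# The last batch is smaller, as happens when the dataset size is not a multiple of batch_size.
batches = [(0.5, 32), (0.5, 32), (2.0, 4)]

# Naive: every batch counts equally, so the 4-example batch is over-weighted.
naive_mean = sum(loss for loss, _ in batches) / len(batches)

# Weighted: multiply each batch mean by its size, then divide by the total number of examples.
total_examples = sum(n for _, n in batches)
weighted_mean = sum(loss * n for loss, n in batches) / total_examples

print(f"naive:    {naive_mean:.4f}")     # naive:    1.0000
print(f"weighted: {weighted_mean:.4f}")  # weighted: 0.5882
```

The weighted value, 40 / 68, is the true mean loss per example. The naive value lets 4 examples pull the result up to 1.0.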
PyTorch becomes much easier when you separate the concept from the tool syntax. First identify the problem being solved, then identify the data or resource being changed, and finally identify the proof that the change worked.
In PyTorch, this topic should be studied through tensor shape, dtype, device, gradient flow, loss movement, and reproducibility. Those points explain not only how to use the feature, but also why it fails when the wrong assumption is made.
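For reproducibility, the simplest check is to seed the random number generator before creating a model and then compare the weights. This is a CPU sketch with arbitrary seed values. Fully deterministic GPU runs can need additional settings that are not shown here.

```python
import torch
from torch import nn


def make_model(seed):
    # Seeding right before construction makes the random weight initialization repeatable.
    torch.manual_seed(seed)
    return nn.Linear(3, 2)


a = make_model(42)
b = make_model(42)
c = make_model(7)

print(torch.equal(a.weight, b.weight))  # True: same seed, same initial weights
print(torch.equal(a.weight, c.weight))  # False: a different seed gives different weights
```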
A good way to learn this page is to read the normal path once, run or trace the example, then intentionally change one input to observe the different result. That one change teaches more than memorizing several definitions.
Start with a tiny scenario, such as one batch of a few examples passing through a small model. Keep the scenario small enough that every step can be explained without skipping details.
Next, describe the movement of information. Where does the input start? Which rule or component handles it? What result should appear? If the result is wrong, where would you inspect first?
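Tracing the movement of information can be as simple as printing the shape at each step for one batch. The batch below is hypothetical: 4 examples with 3 features each, classified into 2 classes by a single linear layer.

```python
import torch
from torch import nn

torch.manual_seed(0)
features = torch.randn(4, 3)          # input starts here: 4 examples x 3 features
labels = torch.tensor([0, 1, 1, 0])   # one int64 class index per example

model = nn.Linear(3, 2)
loss_fn = nn.CrossEntropyLoss()

logits = model(features)              # component that handles it: (4, 3) -> (4, 2)
loss = loss_fn(logits, labels)        # rule applied: one scalar loss for the whole batch
preds = logits.argmax(dim=1)          # visible result: one predicted class per example

print("features:", tuple(features.shape))  # (4, 3)
print("logits:  ", tuple(logits.shape))    # (4, 2)
print("loss:    ", tuple(loss.shape))      # () : a 0-dim scalar tensor
print("preds:   ", tuple(preds.shape))     # (4,)
```

If one of these shapes differs from what you expected, start inspecting at that step.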
Finally, compare two outcomes. The correct outcome proves that you understand the main rule. The incorrect outcome teaches the symptom, which is what you will recognize later during debugging or interviews.
The training and validation functions below are easy to test and extend with schedulers, mixed precision, early stopping, and logging.
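```python
import torch


def train_one_epoch(model, loader, loss_fn, optimizer, device):
    model.train()  # training behavior: dropout active, batch norm uses batch statistics
    total_loss = 0.0
    total_correct = 0
    total_examples = 0
    for features, labels in loader:
        # Move each batch to the same device as the model.
        features = features.to(device)
        labels = labels.to(device)

        logits = model(features)
        loss = loss_fn(logits, labels)  # assumes the default reduction="mean"

        optimizer.zero_grad()  # clear gradients left over from the previous batch
        loss.backward()        # compute fresh gradients
        # Rescale gradients so their combined norm is at most 1.0, which limits exploding updates.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()       # update the parameters

        # Weight by batch size so a smaller final batch is averaged correctly.
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        total_correct += (logits.argmax(dim=1) == labels).sum().item()
        total_examples += batch_size
    return {
        "loss": total_loss / total_examples,
        "accuracy": total_correct / total_examples,
    }


@torch.no_grad()  # no autograd graph during validation: less memory, faster evaluation
def validate(model, loader, loss_fn, device):
    model.eval()  # evaluation behavior: dropout off, batch norm uses running statistics
    total_loss = 0.0
    total_correct = 0
    total_examples = 0
    for features, labels in loader:
        features = features.to(device)
        labels = labels.to(device)
        logits = model(features)
        loss = loss_fn(logits, labels)
        # Same example-weighted averaging as in training; no backward pass or optimizer step.
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        total_correct += (logits.argmax(dim=1) == labels).sum().item()
        total_examples += batch_size
    return {
        "loss": total_loss / total_examples,
        "accuracy": total_correct / total_examples,
    }
```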
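The loop can be extended with a learning-rate scheduler, early stopping, and a best-model checkpoint. Here is a compact sketch of that idea. To run on its own, it re-implements small train and validation passes instead of calling `train_one_epoch` and `validate`. The synthetic data, the `StepLR` schedule, the patience of 3, and the in-memory checkpoint are all illustrative assumptions.

```python
import copy

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical synthetic data: 200 examples, 4 features, label = whether x0 + x1 > 0.
x = torch.randn(200, 4)
y = (x[:, 0] + x[:, 1] > 0).long()
train_loader = DataLoader(TensorDataset(x[:160], y[:160]), batch_size=32, shuffle=True)
val_loader = DataLoader(TensorDataset(x[160:], y[160:]), batch_size=32)

model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Assumed schedule: halve the learning rate every 5 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)


def train_epoch():
    model.train()
    total, count = 0.0, 0
    for features, labels in train_loader:
        loss = loss_fn(model(features), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item() * labels.size(0)
        count += labels.size(0)
    return total / count


@torch.no_grad()
def val_epoch():
    model.eval()
    total, count = 0.0, 0
    for features, labels in val_loader:
        total += loss_fn(model(features), labels).item() * labels.size(0)
        count += labels.size(0)
    return total / count


patience = 3        # assumed: stop after 3 epochs without validation improvement
best_val = float("inf")
best_state = None   # in-memory checkpoint; torch.save(best_state, path) would write it to disk
bad_epochs = 0
history = []

for epoch in range(20):
    train_loss = train_epoch()
    val_loss = val_epoch()
    scheduler.step()  # once per epoch, after this epoch's optimizer steps
    history.append(val_loss)
    print(f"epoch {epoch}: train {train_loss:.4f}  val {val_loss:.4f}")

    if val_loss < best_val:
        best_val = val_loss
        best_state = copy.deepcopy(model.state_dict())
        bad_epochs = 0
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            print(f"early stop at epoch {epoch}")
            break

# Restore the weights from the epoch with the lowest validation loss.
model.load_state_dict(best_state)
```

The checkpoint is chosen by validation loss rather than training loss. Training loss usually keeps falling even after the model stops generalizing better.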
```python
import torch

x = torch.randn(4, 3)
print('topic:', 'PyTorch')
print('shape:', x.shape)
print('dtype:', x.dtype)
print('device:', x.device)
# Shape, dtype, and device checks catch many PyTorch mistakes early.
```
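You can turn these checks into a guard that runs before each batch reaches the model. `check_batch` below is a hypothetical helper, not a PyTorch API. It assumes a classification batch with float32 features and int64 class-index labels.

```python
import torch


def check_batch(features, labels, num_features):
    """Hypothetical helper: raise a readable error before a bad batch reaches the model."""
    if features.dim() != 2:
        raise ValueError(f"expected features of shape (batch, {num_features}), got {tuple(features.shape)}")
    if features.size(1) != num_features:
        raise ValueError(f"expected {num_features} features, got {features.size(1)}")
    if features.dtype != torch.float32:
        raise ValueError(f"features should be float32, got {features.dtype}")
    # nn.CrossEntropyLoss expects int64 class indices here; float targets are
    # treated as class probabilities, which is a different target format.
    if labels.dtype != torch.long:
        raise ValueError(f"labels should be int64 class indices, got {labels.dtype}")
    if features.size(0) != labels.size(0):
        raise ValueError(f"batch size mismatch: {features.size(0)} features vs {labels.size(0)} labels")
    if features.device != labels.device:
        raise ValueError(f"device mismatch: {features.device} vs {labels.device}")


check_batch(torch.randn(4, 3), torch.tensor([0, 1, 1, 0]), num_features=3)  # passes silently

error_message = None
try:
    check_batch(torch.randn(4, 3), torch.tensor([0.0, 1.0, 1.0, 0.0]), num_features=3)
except ValueError as err:
    error_message = str(err)
print(error_message)  # labels should be int64 class indices, got torch.float32
```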
```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.MSELoss()
x = torch.randn(8, 3)
y = torch.randn(8, 1)
loss = loss_fn(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(float(loss))
```
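A single optimizer step cannot show loss movement. Repeating the step on one fixed batch is a common sanity check. If the loop cannot drive the loss down on 8 examples, something in it is probably broken. The seed, `lr=0.01`, and 200 steps in this sketch are arbitrary choices.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()
x = torch.randn(8, 3)
y = torch.randn(8, 1)

losses = []
for step in range(200):
    loss = loss_fn(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())  # record the loss before this step's update

print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```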
| Mistake | Fix |
| --- | --- |
| Validating without calling `model.eval()` | Call `model.eval()` before validation. |
| Averaging `loss.item()` per batch only | Weight each batch loss by its batch size when averaging. |
| Learning PyTorch only as a term | Learn it through a working example, a boundary case, and a failure case. |
| Skipping verification | Always check printed output, tensor shapes, logs, loss values, or metrics. |
| Changing many things at once while debugging | Change one setting, input, or line, then inspect the result. |
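PyTorch accumulates gradients: each `loss.backward()` call adds to the `.grad` already stored on every parameter. `optimizer.zero_grad()` clears the previous gradients before the next batch's gradients are computed.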
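Accumulation is easy to observe on a single parameter. This sketch uses a hypothetical loss `3 * w`, whose gradient is always 3. It calls `backward()` twice without clearing, then clears with `optimizer.zero_grad()` and calls it once more.

```python
import torch

w = torch.tensor([1.0], requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)


def backward_once():
    loss = (3 * w).sum()  # d(loss)/dw = 3
    loss.backward()       # adds 3 to whatever w.grad already holds


backward_once()
first = w.grad.clone()        # tensor([3.])
backward_once()               # no zero_grad() in between
accumulated = w.grad.clone()  # tensor([6.]): old and new gradients were summed

optimizer.zero_grad()         # clears w.grad (sets it to None by default in PyTorch 2.x)
backward_once()
fresh = w.grad.clone()        # tensor([3.]) again

print(first, accumulated, fresh)
```

In a training loop, the doubled gradient would silently act like a larger, stale update.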
Validation does not update weights, so `torch.no_grad()` stops autograd from recording the computation graph. This reduces memory use and speeds up evaluation.
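The difference is visible on the output tensor. Outside `torch.no_grad()`, autograd attaches a `grad_fn` that keeps the graph alive. Inside it, no graph is recorded. The tiny `nn.Linear` model here is just a placeholder.

```python
import torch
from torch import nn

model = nn.Linear(3, 2)
x = torch.randn(4, 3)

out_train = model(x)
print(out_train.requires_grad, out_train.grad_fn is not None)  # True True: a graph was recorded

with torch.no_grad():
    out_eval = model(x)
print(out_eval.requires_grad, out_eval.grad_fn)  # False None: no graph is kept
```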
Start with one tiny example, trace every step, then compare it with a broken version.
Verify the visible result: printed output, tensor shapes, loss values, logged metrics, or saved checkpoints.
PyTorch often feels confusing because it combines new vocabulary with runtime behavior. The confusion drops when you trace the input, the rule applied to it, the result, and the failure path.