CIFAR-10 Image Classifier
This model classifies images into one of 10 categories from the CIFAR-10 dataset:
- airplane
- automobile
- bird
- cat
- deer
- dog
- frog
- horse
- ship
- truck
Model Description
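- Architecture: ResNet18 initialized from ImageNet-pretrained weights, with the final fully connected layer replaced by a 10-class head (Linear 512→256, ReLU, Dropout 0.3, Linear 256→10)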
- Fine-tuned on: CIFAR-10 dataset
- Test Accuracy: ~85-90%
Usage
```python
import torch
import torch.nn as nn
import torchvision.models as models
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms

# Load model: ResNet18 architecture with the custom 10-class head.
# No ImageNet weights are needed here because the fine-tuned weights are loaded below.
# (weights=None requires torchvision >= 0.13; older versions use pretrained=False.)
model = models.resnet18(weights=None)
num_features = model.fc.in_features  # 512 for ResNet18
model.fc = nn.Sequential(
    nn.Linear(num_features, 256),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(256, 10),
)

# Download and load the fine-tuned weights
model_path = hf_hub_download(repo_id="sabarsbb/cifar10-image-classifier-v1", filename="best_model.pth")
checkpoint = torch.load(model_path, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # disable Dropout for inference

# Preprocessing: resize to 224x224 and normalize with ImageNet statistics
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

# Inference: convert to RGB so grayscale or RGBA images also yield 3 channels
image = Image.open('your_image.jpg').convert('RGB')
image_tensor = transform(image).unsqueeze(0)  # add batch dimension: (1, 3, 224, 224)
with torch.no_grad():
    outputs = model(image_tensor)
    _, predicted = torch.max(outputs, 1)

print(f"Predicted class: {classes[predicted.item()]}")
```
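The usage code reports only the single most likely class. If you also want per-class scores, one option is to apply a softmax to the logits and keep the top k. The sketch below assumes `outputs` and `classes` are defined as in the usage code. `top_k_predictions` is a hypothetical helper and is not part of this repository. Softmax scores from a fine-tuned network are not guaranteed to be well-calibrated probabilities, so treat them as relative confidences.

```python
import torch


def top_k_predictions(logits: torch.Tensor, classes: list[str], k: int = 3):
    """Return the k most likely (class name, softmax score) pairs for one image.

    `logits` is expected to have shape (1, num_classes), as produced by the model.
    """
    probs = torch.softmax(logits, dim=1)[0]  # scores for the single image in the batch
    values, indices = torch.topk(probs, k)   # sorted from most to least likely
    return [(classes[i], p) for i, p in zip(indices.tolist(), values.tolist())]
```

For example, `for name, p in top_k_predictions(outputs, classes): print(f"{name}: {p:.1%}")` prints the three highest-scoring classes.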
Training
The model was trained using:
- Optimizer: Adam
- Loss function: CrossEntropyLoss
- Learning rate: 0.001 (initial), reduced by a ReduceLROnPlateau scheduler
- Data augmentation: random horizontal flip, random rotation, color jitter
License
This model is released under the MIT License.