CIFAR-10 Image Classifier

This model classifies images into one of 10 categories from the CIFAR-10 dataset:

  • airplane
  • automobile
  • bird
  • cat
  • deer
  • dog
  • frog
  • horse
  • ship
  • truck

Model Description

  • Architecture: ResNet18 backbone (pretrained on ImageNet)
  • Classification head: Linear(512, 256) → ReLU → Dropout(0.3) → Linear(256, 10)
  • Fine-tuned on: CIFAR-10 dataset
  • Test accuracy: approximately 85-90% (CIFAR-10 test set)

Usage

import torch
import torchvision.models as models
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download

# Load model
model = models.resnet18(weights=None)  # architecture only; `pretrained=` is deprecated in recent torchvision
num_features = model.fc.in_features
model.fc = nn.Sequential(
    nn.Linear(num_features, 256),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(256, 10)
)

# Download and load weights
model_path = hf_hub_download(repo_id="sabarsbb/cifar10-image-classifier-v1", filename="best_model.pth")
checkpoint = torch.load(model_path, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Inference
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

image = Image.open('your_image.jpg').convert('RGB')  # force 3 channels (grayscale/RGBA files would break Normalize)
image_tensor = transform(image).unsqueeze(0)

with torch.no_grad():
    outputs = model(image_tensor)
    _, predicted = torch.max(outputs, 1)

classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']
print(f"Predicted class: {classes[predicted.item()]}")

Training

The model was trained using:

  • Optimizer: Adam
  • Loss function: CrossEntropyLoss
  • Learning rate: 0.001 (initial), reduced by a ReduceLROnPlateau scheduler
  • Data augmentation: Random horizontal flip, rotation, color jitter

License

This model is released under the MIT License.
