ResNet18 for CIFAR-10

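A ResNet18 model adapted to and trained on CIFAR-10. The adaptation replaces the standard 7x7 stride-2 stem convolution with a 3x3 stride-1 convolution, removes the initial max-pooling layer, and uses a 10-class output layer (see the weights-loading example under Direct usage).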

Performance

Metric      Value
Accuracy    0.8763
Precision   0.8772
Recall      0.8763
F1-Score    0.8757
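The card does not say which averaging was used for precision, recall and F1, or which split the numbers come from. Here is a minimal sketch that assumes macro averaging over the classes that occur in either label list, which mirrors scikit-learn's `average='macro'` label handling. It computes the four metrics from plain lists of true and predicted labels. The function name `classification_metrics` is hypothetical.

```python
from collections import Counter


def classification_metrics(y_true, y_pred):
    """Accuracy plus macro-averaged precision, recall and F1.

    Macro averaging is an assumption; the card does not state the averaging mode.
    """
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("y_true and y_pred must be non-empty and of equal length")

    pairs = list(zip(y_true, y_pred))
    accuracy = sum(t == p for t, p in pairs) / len(pairs)

    true_pos = Counter(t for t, p in pairs if t == p)  # correct predictions per class
    pred_count = Counter(y_pred)                        # how often each class was predicted
    true_count = Counter(y_true)                        # how often each class truly occurs

    # Average over every class seen in either list (missing counts read as 0)
    labels = sorted(set(y_true) | set(y_pred))
    precisions, recalls, f1s = [], [], []
    for c in labels:
        prec = true_pos[c] / pred_count[c] if pred_count[c] else 0.0
        rec = true_pos[c] / true_count[c] if true_count[c] else 0.0
        f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
        precisions.append(prec)
        recalls.append(rec)
        f1s.append(f1)

    n = len(labels)
    return {
        "accuracy": accuracy,
        "precision": sum(precisions) / n,
        "recall": sum(recalls) / n,
        "f1": sum(f1s) / n,
    }


print(classification_metrics([0, 0, 1, 1], [0, 1, 1, 1]))
```

On a class-balanced set such as the CIFAR-10 test split (1,000 images per class), macro recall equals accuracy. Weighted recall always equals accuracy. The identical Accuracy and Recall values in the table are therefore consistent with either averaging mode.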

Usage with template_antrenare_pytorch.py

# In the template, set:
HUGGINGFACE_REPO_ID = "Tudorx95/resnet18-cifar10"
MODEL_FILENAME = "ResNet18_CIFAR10.pth"

# The model is then loaded automatically in create_model()

Direct usage

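import torch
from torchvision.models import resnet18

# Option 1: load the full pickled model (architecture + weights).
# PyTorch 2.6+ defaults to weights_only=True, which rejects full model objects,
# so pass weights_only=False; only do this for files you trust.
model = torch.load('ResNet18_CIFAR10.pth', map_location='cpu', weights_only=False)
model.eval()

# Option 2: rebuild the architecture and load only the weights (state_dict)
model = resnet18(weights=None)
# CIFAR-10 adaptation: 3x3 stride-1 stem and no initial max-pooling for 32x32 inputs
model.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = torch.nn.Identity()
model.fc = torch.nn.Linear(512, 10)  # 10 CIFAR-10 classes
state_dict = torch.load('resnet18_cifar10_weights.pth', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()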

CIFAR-10 classes

0: airplane, 1: automobile, 2: bird, 3: cat, 4: deer, 5: dog, 6: frog, 7: horse, 8: ship, 9: truck
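The index list above gives the order of the 10 outputs of the final `fc` layer. It matches the standard torchvision `CIFAR10` label order. The pure-Python sketch below turns one row of logits into a class name. `CIFAR10_CLASSES` and `decode_prediction` are hypothetical helper names.

```python
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]


def decode_prediction(logits):
    """Return (index, class name) for the highest-scoring of the 10 outputs."""
    if len(logits) != len(CIFAR10_CLASSES):
        raise ValueError(f"expected {len(CIFAR10_CLASSES)} scores, got {len(logits)}")
    best = max(range(len(logits)), key=lambda i: logits[i])  # argmax
    return best, CIFAR10_CLASSES[best]


print(decode_prediction([0.1, 2.3, -1.0, 0.0, 0.5, 0.2, 0.0, 0.3, 1.9, 0.4]))
# (1, 'automobile')
```

With a PyTorch output tensor of shape (batch, 10), `logits.argmax(dim=1)` gives the same indices for the whole batch.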

Training

  • Epochs: 10
  • Batch Size: 128
  • Learning Rate: 0.001
  • Optimizer: Adam
  • Scheduler: StepLR (step_size=5, gamma=0.5)
  • Augmentation: RandomCrop, RandomHorizontalFlip
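StepLR multiplies the learning rate by `gamma` every `step_size` epochs, so lr(epoch) = 0.001 · 0.5^floor(epoch / 5). Over the 10 listed epochs this gives 0.001 for epochs 0 to 4 and 0.0005 for epochs 5 to 9. This assumes `scheduler.step()` is called once per epoch, which the card does not state. Here is a closed-form sketch, with `step_lr` as a hypothetical helper name:

```python
def step_lr(base_lr: float, epoch: int, step_size: int = 5, gamma: float = 0.5) -> float:
    """Learning rate in effect during `epoch` (0-based) under StepLR stepped once per epoch."""
    return base_lr * gamma ** (epoch // step_size)


for epoch in range(10):
    print(epoch, step_lr(0.001, epoch))
```

In PyTorch the listed settings correspond to `torch.optim.Adam(model.parameters(), lr=0.001)` together with `torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)`.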
