ResNet18 for CIFAR-10
A ResNet18 model adapted for and trained on CIFAR-10.
Performance
| Metric | Value |
|---|---|
| Accuracy | 0.8763 |
| Precision | 0.8772 |
| Recall | 0.8763 |
| F1-Score | 0.8757 |
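The card does not say how Precision, Recall and F1 were averaged across the 10 classes. The sketch below shows how all four metrics can be computed from lists of true and predicted labels. It assumes macro averaging (an unweighted mean over classes), and `classification_metrics` is a hypothetical helper:

```python
from collections import Counter

def classification_metrics(y_true, y_pred, num_classes=10):
    """Accuracy plus macro-averaged precision, recall and F1 (averaging is an assumption)."""
    pairs = list(zip(y_true, y_pred))
    accuracy = sum(t == p for t, p in pairs) / len(pairs)
    tp = Counter(t for t, p in pairs if t == p)   # correct predictions per class
    pred_count = Counter(y_pred)                  # how often each class was predicted
    true_count = Counter(y_true)                  # how often each class actually occurs
    precisions, recalls, f1s = [], [], []
    for c in range(num_classes):
        prec = tp[c] / pred_count[c] if pred_count[c] else 0.0
        rec = tp[c] / true_count[c] if true_count[c] else 0.0
        f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
        precisions.append(prec)
        recalls.append(rec)
        f1s.append(f1)
    # Macro average: every class counts equally, regardless of its size.
    return {
        "accuracy": accuracy,
        "precision": sum(precisions) / num_classes,
        "recall": sum(recalls) / num_classes,
        "f1": sum(f1s) / num_classes,
    }
```

On a class-balanced evaluation set, macro recall equals accuracy. The CIFAR-10 test split is balanced (1,000 images per class), so this would be consistent with Recall = Accuracy = 0.8763 above. Weighted-average recall always equals accuracy, so the table alone cannot distinguish the two schemes.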
Usage with template_antrenare_pytorch.py
```python
# In the template, set:
HUGGINGFACE_REPO_ID = "Tudorx95/resnet18-cifar10"
MODEL_FILENAME = "ResNet18_CIFAR10.pth"
# The model is then loaded automatically in create_model()
```
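Direct usage
```python
import torch
import torch.nn as nn
from torchvision.models import resnet18

# Option 1: load the complete pickled model.
# weights_only=False is required on PyTorch >= 2.6 to unpickle a full model;
# only do this for files from a source you trust.
model = torch.load('ResNet18_CIFAR10.pth', map_location='cpu', weights_only=False)
model.eval()

# Option 2: rebuild the CIFAR-10 variant of ResNet18 and load only the weights.
model = resnet18(weights=None)
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)  # 3x3 stem, no downsampling
model.maxpool = nn.Identity()   # remove the initial max-pooling
model.fc = nn.Linear(512, 10)   # 10 CIFAR-10 classes
model.load_state_dict(torch.load('resnet18_cifar10_weights.pth', map_location='cpu'))
model.eval()
```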
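The card does not explain why `conv1` and `maxpool` are replaced. One way to see their effect is to track the spatial size of a 32x32 CIFAR-10 image through ResNet18's stem and four stages. The helpers `conv_out` and `resnet18_spatial_sizes` are hypothetical, and the sketch covers only this size arithmetic, not the model itself:

```python
def conv_out(size, kernel, stride, padding):
    """Output height/width of a conv or pooling layer: floor((H + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

def resnet18_spatial_sizes(size=32, cifar_stem=True):
    """Spatial size at the output of layer1..layer4 of ResNet18."""
    if cifar_stem:
        size = conv_out(size, 3, 1, 1)   # adapted conv1: 3x3, stride 1, padding 1
        # maxpool replaced by Identity: size unchanged
    else:
        size = conv_out(size, 7, 2, 3)   # torchvision default conv1: 7x7, stride 2, padding 3
        size = conv_out(size, 3, 2, 1)   # default maxpool: 3x3, stride 2, padding 1
    sizes = [size]                        # layer1 keeps the spatial size
    for _ in range(3):                    # layer2..layer4 each downsample with a stride-2 3x3 conv
        size = conv_out(size, 3, 2, 1)
        sizes.append(size)
    return sizes

print(resnet18_spatial_sizes(cifar_stem=False))  # [8, 4, 2, 1]
print(resnet18_spatial_sizes(cifar_stem=True))   # [32, 16, 8, 4]
```

With the default stem, layer4 operates on a 1x1 map; with the adapted stem it operates on a 4x4 map. In both cases torchvision's adaptive average pooling reduces the 512 output channels to a 512-dimensional vector, which is why `fc` is `Linear(512, 10)`.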
CIFAR-10 classes
0: airplane, 1: automobile, 2: bird, 3: cat, 4: deer, 5: dog, 6: frog, 7: horse, 8: ship, 9: truck
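Here is a minimal sketch for turning the model's 10 output logits into a class name, using the index order listed above. `CIFAR10_CLASSES` and `predict_label` are hypothetical helpers, not part of the repository:

```python
# Index-to-name mapping, in the order listed on this card.
CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]

def predict_label(logits):
    """Return the class name of the largest logit (argmax over 10 scores)."""
    if len(logits) != len(CIFAR10_CLASSES):
        raise ValueError(f"expected {len(CIFAR10_CLASSES)} logits, got {len(logits)}")
    best = max(range(len(logits)), key=lambda i: logits[i])
    return CIFAR10_CLASSES[best]
```

Suppose a model is loaded as shown above and `x` is a single input tensor of shape `(1, 3, 32, 32)`. Then `predict_label(model(x)[0].tolist())` would return the predicted class name. The card does not specify the input normalization used during training.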
Training
- Epochs: 10
- Batch Size: 128
- Learning Rate: 0.001
- Optimizer: Adam
- Scheduler: StepLR (step_size=5, gamma=0.5)
- Augmentation: RandomCrop, RandomHorizontalFlip
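The card does not say how often `scheduler.step()` is called. Assuming it is called once per epoch, StepLR with step_size=5 and gamma=0.5 halves the learning rate once during the 10 epochs. The sketch below uses StepLR's closed form. It also assumes the full 50,000-image CIFAR-10 training split is used without `drop_last`, which gives the number of batches per epoch:

```python
import math

BASE_LR, STEP_SIZE, GAMMA, EPOCHS, BATCH_SIZE = 0.001, 5, 0.5, 10, 128
TRAIN_IMAGES = 50_000  # assumption: standard CIFAR-10 train split, used in full

def step_lr(epoch, base_lr=BASE_LR, step_size=STEP_SIZE, gamma=GAMMA):
    """Closed form of StepLR: lr = base_lr * gamma ** (epoch // step_size)."""
    return base_lr * gamma ** (epoch // step_size)

# Learning rate in effect during each epoch (assuming one scheduler step per epoch).
schedule = [step_lr(e) for e in range(EPOCHS)]

# Last partial batch kept (no drop_last): ceil(50000 / 128) = 391 batches.
batches_per_epoch = math.ceil(TRAIN_IMAGES / BATCH_SIZE)

print(schedule)           # epochs 0-4 at 0.001, epochs 5-9 at 0.0005
print(batches_per_epoch)  # 391
```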