ImageNet has 1,000 classes and over 14 million images; a large number of models are trained against ImageNet as a benchmark dataset.
The ImageNet dataset has over 14 million images and is maintained by Stanford University. It is extensively used for a large variety of image-related deep learning projects. The images belong to various classes or labels; even though the two terms can be used interchangeably, we will stick to classes. The aim of pre-trained models like AlexNet and ResNet101 is to take an image as input and predict its class.
- Reading the input image
- Performing transformations on the image. For example – resize, center crop, normalization, etc.
- Forward Pass: Use the pre-trained weights to compute the output vector. Each element of this vector describes the confidence with which the model predicts the input image to belong to a particular class.
- Based on the scores obtained (the elements of the output vector from step 3), display the predictions.
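The score-to-confidence conversion behind steps 3 and 4 is a softmax over the output vector. A minimal pure-Python sketch, using a toy 4-class score vector in place of the real 1000-element output:

```python
import math

def softmax_percentages(scores):
    """Convert raw model scores (logits) to confidence percentages.
    Subtracting the max score first keeps exp() numerically stable."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [100.0 * e / total for e in exps]

# Toy score vector standing in for the model's 1000-element output
scores = [1.0, 3.0, 0.5, 2.0]
pct = softmax_percentages(scores)
# The percentages sum to 100, and the largest one marks the predicted class
```

This is the same computation `torch.nn.functional.softmax` performs later in the tutorial, just written out by hand for a small vector.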
```python
from torchvision import models
import torch

dir(models)  # lists the model architectures available in torchvision

alexnet = models.alexnet(pretrained=True)
# You will see a similar output as below:
# Downloading: "https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth"
# to /home/hp/.cache/torch/checkpoints/alexnet-owt-4df8aa71.pth
```
```python
from torchvision import transforms

transform = transforms.Compose([        #[1]
    transforms.Resize(256),             #[2]
    transforms.CenterCrop(224),         #[3]
    transforms.ToTensor(),              #[4]
    transforms.Normalize(               #[5]
        mean=[0.485, 0.456, 0.406],     #[6]
        std=[0.229, 0.224, 0.225])])    #[7]
```
Line [1]: Here we define a variable `transform`, which is a combination of all the image transformations to be carried out on the input image.

Line [2]: Resize the image so that its smaller side is 256 pixels.

Line [3]: Crop the image to 224×224 pixels about the center.

Line [4]: Convert the image to the PyTorch tensor data type.

Lines [5-7]: Normalize the tensor channel-wise with the specified mean and standard deviation (the ImageNet statistics).
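To make the normalization step concrete, here is the per-channel arithmetic that `transforms.Normalize` applies, shown on a single hypothetical pixel whose RGB values have already been scaled to [0, 1] by `ToTensor()`:

```python
# Hypothetical pixel value; mean/std are the ImageNet statistics from the
# Compose pipeline above
pixel = [0.5, 0.5, 0.5]
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# For each channel c: output = (input - mean[c]) / std[c]
normalized = [(v, m, s) for v, m, s in zip(pixel, mean, std)]
normalized = [(v - m) / s for v, m, s in zip(pixel, mean, std)]
```

Because the three channels have different means and standard deviations, the same input value 0.5 maps to three different normalized values.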
```python
# Import Pillow
from PIL import Image

img = Image.open("dog.jpg")
img_t = transform(img)
batch_t = torch.unsqueeze(img_t, 0)  # add a batch dimension: [1, 3, 224, 224]

alexnet.eval()  # put the model in inference mode
out = alexnet(batch_t)
print(out.shape)  # torch.Size([1, 1000]) -- one score per ImageNet class
```
```python
with open('imagenet_classes.txt') as f:
    classes = [line.strip() for line in f.readlines()]

# Class with the highest score
_, index = torch.max(out, 1)
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
print(classes[index[0]], percentage[index[0]].item())

# Top-5 predictions
_, indices = torch.sort(out, descending=True)
print([(classes[idx], percentage[idx].item()) for idx in indices[0][:5]])
```
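The top-5 selection at the end can be sketched without torch as a plain descending sort over (class, confidence) pairs. The class names and percentages below are made up for illustration:

```python
def top_k(percentages, class_names, k=5):
    """Pair each class name with its confidence and keep the k largest."""
    ranked = sorted(zip(class_names, percentages), key=lambda p: p[1], reverse=True)
    return ranked[:k]

# Toy 6-class output standing in for the real 1000-class prediction
names = ["cat", "dog", "fox", "owl", "hen", "bee"]
pct = [1.0, 55.0, 20.0, 3.0, 15.0, 6.0]
top5 = top_k(pct, names)
# top5[0] is the same class that torch.max would report
```

`torch.sort(out, descending=True)` does exactly this ranking over the score tensor, returning the sorted values and their indices.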