pytorch-torchvision

Overview

  这篇文章将介绍pytorch中的torchvision模块。torchvision主要包含三个模块:datasets,models和transforms。
  
->官方英文介绍
  ->pytorch中文文档

Function

torchvision.datasets

torchvision.models

  目前,models模块支持Alexnet,VGG,resnet,squeezeNet,denseNet和inception V3这几个模型。可以直接构建一个网络结构(参数随机化):

import torchvision.models as models

resnet18 = models.resnet18()
alexnet = models.alexnet()
vgg16 = models.vgg16()
squeezenet = models.squeezenet1_0()
densenet = models.densenet161()
inception = models.inception_v3()

  如果想得到预训练后的网络,就把pretrained这个参数置成true

import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)

以VGG网络模型为例,有以下结构:

torchvision.models.vgg11(pretrained=False, **kwargs)
torchvision.models.vgg11_bn(pretrained=False, **kwargs)
torchvision.models.vgg13(pretrained=False, **kwargs)
torchvision.models.vgg13_bn(pretrained=False, **kwargs)
torchvision.models.vgg16(pretrained=False, **kwargs)
torchvision.models.vgg16_bn(pretrained=False, **kwargs)
torchvision.models.vgg19(pretrained=False, **kwargs)
torchvision.models.vgg19_bn(pretrained=False, **kwargs)

Threads:addjob([id], callback, [endcallback], […])