
Overview
这篇文章将介绍pytorch中的torchvision模块。torchvision主要包含三个模块:datasets,models和transforms。
->官方英文介绍
->pytorch中文文档
Function
torchvision.datasets
torchvision.models
目前,models模块支持Alexnet,VGG,resnet,squeezeNet,denseNet和inception V3这几个模型。可以直接构建一个网络结构(参数随机化):
import torchvision.models as models
resnet18 = models.resnet18()
alexnet = models.alexnet()
vgg16 = models.vgg16()
squeezenet = models.squeezenet1_0()
densenet = models.densenet161()
inception = models.inception_v3()
如果想得到预训练后的网络,就把pretrained这个参数置成true:
import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
以VGG网络模型为例,有以下结构:
torchvision.models.vgg11(pretrained=False, **kwargs)
torchvision.models.vgg11_bn(pretrained=False, **kwargs)
torchvision.models.vgg13(pretrained=False, **kwargs)
torchvision.models.vgg13_bn(pretrained=False, **kwargs)
torchvision.models.vgg16(pretrained=False, **kwargs)
torchvision.models.vgg16_bn(pretrained=False, **kwargs)
torchvision.models.vgg19(pretrained=False, **kwargs)
torchvision.models.vgg19_bn(pretrained=False, **kwargs)




近期评论