Using Your Own Dataset in torch

To use your own dataset in torch:

  • First, convert the data to tensors.

  • Then wrap the tensors in a Dataset that torch can work with.
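The second step does not strictly require TensorDataset. DataLoader accepts any subclass of `torch.utils.data.Dataset` that implements `__len__` and `__getitem__`. The following is a minimal sketch, assuming the features are in an array `x` and the integer class labels are in `y`. The class name `ArrayDataset` and the random sample data are hypothetical, for illustration only.

```python
import numpy as np
import torch
from torch.utils.data import Dataset


class ArrayDataset(Dataset):
    """Hypothetical example: wraps a feature array x and a label array y."""

    def __init__(self, x, y):
        # Convert once up front: features as float32, labels as int64 class ids
        self.x = torch.as_tensor(np.asarray(x), dtype=torch.float32)
        self.y = torch.as_tensor(np.asarray(y), dtype=torch.long)
        if len(self.x) != len(self.y):
            raise ValueError("x and y must contain the same number of samples")

    def __len__(self):
        # DataLoader uses this to know how many samples exist
        return len(self.x)

    def __getitem__(self, idx):
        # Return one (sample, label) pair; DataLoader stacks these into batches
        return self.x[idx], self.y[idx]


# Hypothetical data: 10 samples with 3 features, alternating labels 0/1
x = np.random.rand(10, 3)
y = np.arange(10) % 2
ds = ArrayDataset(x, y)
sample, label = ds[0]
print(len(ds), sample.shape, label.item())
```

An instance of this class can be passed to `DataLoader` exactly where a `TensorDataset` would go. Writing a custom class pays off when samples must be loaded lazily, for example from files on disk, rather than held in memory all at once.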

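    import torch
    import torch.utils.data as Data

    # x, y: your own data, e.g. NumPy arrays of features and labels.
    # torch.Tensor always produces float32; for class labels used with
    # CrossEntropyLoss, use torch.tensor(y, dtype=torch.long) instead
    tensor_x, tensor_y = torch.Tensor(x), torch.Tensor(y)
    # Tensors are passed positionally; older releases used the
    # data_tensor= / target_tensor= keywords, which no longer exist
    dataset = Data.TensorDataset(tensor_x, tensor_y)

    loader = Data.DataLoader(
        dataset=dataset,
        batch_size=128,
        shuffle=True,
        num_workers=2,  # number of worker subprocesses that load data in parallel
    )

    # Train over the whole dataset 3 times (3 epochs)
    for epoch in range(3):
        for step, (batch_x, batch_y) in enumerate(loader):
            print(batch_x.size(), batch_y.size())
            # Older PyTorch versions required wrapping batches in
            # torch.autograd.Variable before training; since 0.4, Variable is
            # merged into Tensor, so batch_x and batch_y can be used directly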
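The following self-contained sketch runs the same pipeline end to end. Random synthetic data stands in for `x` and `y`, and a minimal training step shows where the batches go. The linear model, the loss, and the learning rate are arbitrary placeholders, not part of the original. `num_workers=0` is used so the example runs in the main process on any platform.

```python
import torch
import torch.nn as nn
import torch.utils.data as Data

torch.manual_seed(0)

# Synthetic stand-in for "your own data": 1000 samples, 4 features, 3 classes
x = torch.randn(1000, 4)
y = torch.randint(0, 3, (1000,))

dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=dataset,
    batch_size=128,
    shuffle=True,
    num_workers=0,  # load in the main process; avoids multiprocessing setup
)

# Placeholder model and optimizer, chosen only to demonstrate a training step
model = nn.Linear(4, 3)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

batch_sizes = []
for epoch in range(3):
    for step, (batch_x, batch_y) in enumerate(loader):
        # Batches are plain tensors and go straight into the model
        logits = model(batch_x)
        loss = loss_fn(logits, batch_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if epoch == 0:
            batch_sizes.append(batch_x.size(0))  # record sizes for one epoch

print(batch_sizes)  # → [128, 128, 128, 128, 128, 128, 128, 104]
```

With 1000 samples and `batch_size=128`, each epoch yields seven full batches plus a final batch of 104. Passing `drop_last=True` to `DataLoader` discards that partial batch.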