1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94
|
require 'nn'; local matio = require 'matio'
-- Layer widths of the classifier: input features, hidden units, output classes.
ninput = 310
nhidden = 20
noutput = 3

-- Feed-forward model: ninput -> nhidden -> noutput.
-- NOTE(review): there is no non-linearity between the two Linear layers, so
-- the stack can only represent a linear map of the input; confirm whether an
-- activation (e.g. nn.Tanh) was intended between them.
net = nn.Sequential()
net:add(nn.Linear(ninput, nhidden))
net:add(nn.Linear(nhidden, noutput))

-- Loss handed to the trainer; it is applied to the raw outputs of `net`.
criterion = nn.CrossEntropyCriterion()
-- Read the four .mat files. Each result is indexed below by a field name
-- (train_data, train_label, test_data, test_label).
train_input = matio.load('hw1/train_test/train_data.mat')
train_label = matio.load('hw1/train_test/train_label.mat')
test_input = matio.load('hw1/train_test/test_data.mat')
test_label = matio.load('hw1/train_test/test_label.mat')

-- Shift every label by +2.
-- NOTE(review): presumably the raw labels are -1/0/1 and this maps them onto
-- the 1-based class indices 1..3 that the criterion expects - confirm against
-- the data files.
train_label.train_label = train_label.train_label + 2
test_label.test_label = test_label.test_label + 2
-- Training set in the shape nn.StochasticGradient consumes: trainset[i]
-- yields the pair {input row, label row} and trainset:size() the row count.
trainset = {}
trainset.data = train_input.train_data:double()
trainset.label = train_label.train_label
setmetatable(trainset, {
  -- Only reached for keys that are not real fields, i.e. the sample index.
  __index = function(t, i)
    return {t.data[i], t.label[i]}
  end
})

-- Number of training samples (rows of the data matrix).
function trainset:size()
  return self.data:size(1)
end

-- Standardise every feature column in place and keep the per-column
-- statistics in the globals `mean` and `stdv`.
mean = {}
stdv = {}
for i = 1, trainset.data:size(2) do
  -- The indexed result shares storage with trainset.data, so the in-place
  -- add/div below modify the data set itself.
  local column = trainset.data[{{}, {i}}]

  mean[i] = column:mean()
  column:add(-mean[i])

  local spread = column:std()
  -- A constant column has std 0 (and a single-row column yields NaN);
  -- dividing by either would fill the feature with NaN/inf and poison
  -- training. Fall back to 1 so the centred column is left unscaled.
  if spread == 0 or spread ~= spread then
    spread = 1
  end
  stdv[i] = spread
  column:div(stdv[i])
end
-- Train `net` against `criterion` with plain stochastic gradient descent.
trainer = nn.StochasticGradient(net, criterion)
trainer.maxIteration = 10     -- passes over the training set
trainer.learningRate = 0.005  -- step size
trainer:train(trainset)
-- Test set with the same access pattern as the training set: testset[i]
-- yields the pair {input row, label row} and testset:size() the row count.
testset = {}
testset.data = test_input.test_data:double()
testset.label = test_label.test_label
setmetatable(testset, {
  -- Only reached for keys that are not real fields, i.e. the sample index.
  __index = function(t, i)
    return {t.data[i], t.label[i]}
  end
})

-- Number of test samples (rows of the data matrix).
function testset:size()
  return self.data:size(1)
end

-- Standardise the test features with the statistics computed on the TRAINING
-- data (globals `mean` / `stdv` filled by the training normalisation loop).
-- Re-estimating mean/std on the test set would feed the network inputs on a
-- different scale than the one it was trained on.
for i = 1, testset.data:size(2) do
  assert(mean[i] ~= nil and stdv[i] ~= nil,
         'no training statistics for feature column ' .. i)

  local scale = stdv[i]
  -- Guard against a zero/NaN training std so the column is not turned
  -- into NaN/inf; such a column is only centred.
  if scale == 0 or scale ~= scale then
    scale = 1
  end

  -- The indexed result shares storage with testset.data, so these in-place
  -- operations modify the data set itself.
  local column = testset.data[{{}, {i}}]
  column:add(-mean[i])
  column:div(scale)
end
-- Count how many test samples the network classifies correctly and print
-- the count together with the accuracy in percent.
correct = 0
local total = testset:size()
for sample = 1, total do
  local scores = net:forward(testset.data[sample])
  -- Descending sort: the first index is the highest-scoring class.
  local _, ranking = torch.sort(scores, true)
  local target = testset.label[sample]
  if ranking[1] == target[1] then
    correct = correct + 1
  end
end
print(correct, 100*correct/testset:size() .. ' % ')
|
近期评论