def forward(self, x):
    """Run ``x`` through the network's layers in a fixed order and return the result.

    Layers are applied strictly in sequence:
    conv1 -> conv2 -> conv3 -> conv4 -> conv5 -> average_pool -> flatten -> fc.

    Args:
        x: Input batch. Presumably an image tensor shaped (N, C, H, W);
            TODO confirm against the layer definitions in ``__init__``.

    Returns:
        Whatever ``self.fc`` returns for the flattened, pooled features.
    """
    # Convolutional feature extractor stages.
    x = self.conv1(x)
    x = self.conv2(x)
    x = self.conv3(x)
    x = self.conv4(x)
    x = self.conv5(x)
    # Pool, flatten, and hand the features to the final fully-connected head.
    x = self.average_pool(x)
    x = self.flatten(x)
    x = self.fc(x)
    return x
# a = torch.ones((1000,3,64,64)) # model = Resnet() # output = model(a) # print(output.shape)
epoch = 50 for i in range(epoch): print("----------第{}轮次开始--------".format(i+1)) loss_cnt = 0 acc = 0 for i, (imgs, labels) in enumerate(train_loader): imgs = imgs.to(device) labels = labels.to(device)
#64 * 6 output = model(imgs) loss = loss_func(output, labels)
loss_cnt += loss acc += (output.argmax(1) == labels).sum()