1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84
| import torch import torchvision from torch import nn from torch.utils.data import DataLoader
# MNIST dataset and loader setup.
# Both splits live under the same root directory, are downloaded on demand,
# and have their images converted to tensors by ToTensor().
_MNIST_ROOT = 'data/mnist'
_BATCH_SIZE = 64
_to_tensor = torchvision.transforms.ToTensor()

train_data, test_data = (
    torchvision.datasets.MNIST(
        _MNIST_ROOT,
        transform=_to_tensor,
        train=is_train,
        download=True,
    )
    for is_train in (True, False)
)

# NOTE(review): the test loader is shuffled as well; that is not needed to
# compute accuracy, but it is kept so the behaviour matches the original.
train_loader, test_loader = (
    DataLoader(split, batch_size=_BATCH_SIZE, shuffle=True)
    for split in (train_data, test_data)
)

# Sample counts, used later as the denominators of the accuracy figures.
train_len, test_len = len(train_data), len(test_data)
class Model(nn.Module):
    """Small convolutional classifier producing 10 output scores per sample.

    Two convolution + average-pooling stages feed a two-layer fully
    connected head. No non-linearity is applied between the convolution
    and pooling stages; the only ReLU sits between the two linear layers.

    The first linear layer expects 256 flattened features, i.e. 16 channels
    of 4x4, which is what a single-channel 28x28 input yields after the two
    5x5 convolutions and two 2x2 poolings.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as they are applied.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.avg_pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.avg_pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.l1 = nn.Linear(in_features=256, out_features=64)
        self.l2 = nn.Linear(in_features=64, out_features=10)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw (un-normalised) class scores for a batch ``x``."""
        feature_stages = (
            self.conv1,
            self.avg_pool1,
            self.conv2,
            self.avg_pool2,
            self.flatten,
        )
        features = x
        for stage in feature_stages:
            features = stage(features)

        hidden = self.relu(self.l1(features))
        return self.l2(hidden)
# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Network, loss and optimizer; the network and the loss module are moved to
# the selected device.
model = Model().to(device)
loss_func = nn.CrossEntropyLoss().to(device)

# Plain SGD over all model parameters with a fixed learning rate.
optim = torch.optim.SGD(model.parameters(), lr=0.01)
# Train for a fixed number of epochs; after each epoch report the summed
# training loss, the training accuracy, and the accuracy on the test set.
epoch = 10
for i in range(epoch):
    print('----------------第{}次迭代开始-----------'.format(i+1))

    # Running totals are plain Python numbers. Accumulating the loss tensor
    # itself would keep every batch's autograd graph alive until the end of
    # the epoch and make memory usage grow with the number of batches.
    loss_cnt = 0.0
    acc = 0

    # Explicit mode switch: a no-op for layers without train/eval-specific
    # behaviour, but keeps the loop correct if such layers are added.
    model.train()
    for imgs, labels in train_loader:
        imgs = imgs.to(device)
        labels = labels.to(device)

        y_hat = model(imgs)
        loss = loss_func(y_hat, labels)

        # Count correct predictions (arg-max over the class dimension).
        acc += (y_hat.argmax(1) == labels).sum().item()
        loss_cnt += loss.item()

        optim.zero_grad()
        loss.backward()
        optim.step()

    print("第{}次的loss为{}".format(i + 1, loss_cnt))
    print("第{}次的acc为{}".format(i + 1, acc/train_len))

    # Evaluation pass: no gradients are needed, so autograd is disabled.
    model.eval()
    with torch.no_grad():
        acc = 0
        for imgs, labels in test_loader:
            imgs = imgs.to(device)
            labels = labels.to(device)

            y_hat = model(imgs)

            acc += (y_hat.argmax(1) == labels).sum().item()

    print("acc={}".format(acc/test_len))
|