# 2. Load the h5 file
import h5py
import numpy as np
import torch

# NOTE(review): '.h5' looks like a placeholder path -- TODO point this at the real dataset file.
# The handle is intentionally left open (as in the original) so later code can still read
# other datasets from `file`; call file.close() once nothing else needs it.
file = h5py.File('.h5', 'r')
# To see which datasets the file contains, inspect list(file.keys()); a bare
# `file.keys()` expression statement is simply discarded when run as a script.

# Read the 'train_x' dataset into memory and scale it by 1/255
# (presumably 8-bit pixel values -> [0, 1]; confirm against the data).
x = np.array(file['train_x'])
x = x.astype('float32') / 255
x = torch.from_numpy(x)

# x has shape m * h * w * c; convert it to channels-first m * c * h * w, the layout
# torch.nn.Conv2d expects. permute() returns a new view instead of modifying x in
# place, so the result must be assigned back (the original discarded it).
x = x.permute(0, 3, 1, 2)
def forward(self, x):
    """Run the network's forward pass on ``x``.

    Applies, in order: conv1 -> avg_pool1 -> conv2 -> avg_pool2 -> flatten,
    then the fully connected layer l1 followed by ReLU, and finally l2.

    Args:
        x: Input batch. NOTE(review): presumably a channels-first
            (m, c, h, w) image tensor -- confirm against the caller.

    Returns:
        The raw output of ``self.l2``; no activation is applied to it here.
    """
    # Convolution / pooling feature extractor.
    # NOTE(review): no activation is applied between conv and pooling here;
    # confirm conv1/conv2 include one if a non-linearity is intended.
    x = self.conv1(x)
    x = self.avg_pool1(x)
    x = self.conv2(x)
    x = self.avg_pool2(x)

    # Classifier head: flatten, hidden layer with ReLU, output layer.
    x = self.flatten(x)
    x = self.relu(self.l1(x))
    x = self.l2(x)
    return x
# Pick the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Build the network (Model is defined elsewhere) and move it onto the chosen device;
# .to() returns the module itself, so the result can be bound directly.
model = Model().to(device)