最终正确率89左右 ,以后继续改进一下
导入相关并读取数据
1 | import torch |
将树叶的类别由字符串转换为数字
1 | #注意一定要sort一下,以确保每次运行的结果都是唯一的。因为set的结果不唯一 |
自定义dataset
1 | class Mydata(Dataset): |
定义dataloder
1 | train_data = Mydata(train_file, img_file, 'train') |
网络模型及超参数
1 | def get_net(): |
训练
1 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
预测
1 | net = get_net() |
其他
1 | #将Image转换为np,注意输出维度的变化,224*224*3 |