Pytorch入门 (8)

Pytorch入门 (8)

机器学习模型测试代码详解(面向编程小白)

这部分代码是对上面几篇文章的一个补充,旨在得到每张图像的预测和真实标签的对比。

# 测试模型

model.eval() # 设置为评估模式

with torch.no_grad():

correct = 0

total = 0

all_preds = []

all_labels = []

all_images = []

for images, labels in test_dataloader:

images = images.to(device)

labels = labels.to(device)

outputs = model(images)

_, predicted = torch.max(outputs.data, 1)

total += labels.size(0)

correct += (predicted == labels).sum().item()

# 保存一些图像、预测和标签用于可视化

all_images.append(images.cpu())

all_preds.append(predicted.cpu())

all_labels.append(labels.cpu())

# 只保存足够的样本用于可视化

if len(all_images) >= 1: # 只需要一个batch就足够了

break

print(f'测试集准确率: {100 * correct / total:.2f}%')

# 可视化预测结果

def visualize_predictions(images, predictions, labels, num_samples=9):

# 类别名称

classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',

'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# 创建一个3x3的子图

fig, axes = plt.subplots(3, 3, figsize=(12, 12))

# 展平axes数组以便于索引

axes = axes.flatten()

# 确保我们只展示指定数量的样本

images = torch.cat(all_images, 0)[:num_samples]

predictions = torch.cat(all_preds, 0)[:num_samples]

labels = torch.cat(all_labels, 0)[:num_samples]

for i in range(num_samples):

# 获取图像并转换为numpy数组

img = images[i].squeeze().numpy()

# 获取预测和真实标签

pred = classes[predictions[i]]

true_label = classes[labels[i]]

# 在子图上显示图像

axes[i].imshow(img, cmap='gray')

# 设置标题为预测和真实标签

title_color = 'green' if pred == true_label else 'red'

axes[i].set_title(f'Pred: {pred}\nTrue: {true_label}', color=title_color)

# 关闭坐标轴

axes[i].axis('off')

plt.tight_layout()

plt.show()

# 可视化一些预测结果

visualize_predictions(all_images, all_preds, all_labels)

这段代码主要是在测试一个已经训练好的机器学习模型,并且将预测结果进行可视化展示。我会分步骤解释这段代码的功能。

第一部分:测试模型性能

model.eval() # 设置为评估模式

这行代码告诉模型:“嘿,现在不是学习的时候,而是考试的时候了!”

eval()表示将模型设置为评估模式,这样模型就不会更新它的参数了

with torch.no_grad():

这行代码告诉PyTorch:“接下来的操作不需要计算梯度”

在测试时我们只需要得到预测结果,不需要反向传播来更新模型,这样可以节省内存并加速计算

correct = 0 # 记录正确预测的样本数

total = 0 # 记录总样本数

all_preds = [] # 存储所有预测结果

all_labels = [] # 存储所有真实标签

all_images = [] # 存储所有图像

初始化一些变量来记录测试结果

for images, labels in test_dataloader:

这是一个循环,从测试数据加载器中一批一批地获取数据

images是图像数据,labels是对应的真实标签

images = images.to(device)

labels = labels.to(device)

将数据送到计算设备上(可能是GPU或CPU)

outputs = model(images)

用模型对图像进行预测,得到输出

_, predicted = torch.max(outputs.data, 1)

torch.max找出每一行中最大值的索引

这行代码的意思是找出模型预测的最可能的类别

total += labels.size(0)

correct += (predicted == labels).sum().item()

累加样本总数

计算有多少预测和真实标签相匹配,并累加到correct变量

# 保存一些图像、预测和标签用于可视化

all_images.append(images.cpu())

all_preds.append(predicted.cpu())

all_labels.append(labels.cpu())

将当前批次的图像、预测结果和真实标签保存下来,用于后面的可视化

.cpu()是将数据从GPU移回CPU,因为可视化需要在CPU上进行

# 只保存足够的样本用于可视化

if len(all_images) >= 1: # 只需要一个batch就足够了

break

这段代码限制了我们只保存一批数据,因为我们只需要少量样本来展示

print(f'测试集准确率: {100 * correct / total:.2f}%')

计算并打印模型在测试集上的准确率

第二部分:可视化预测结果

def visualize_predictions(images, predictions, labels, num_samples=9):

定义一个函数来可视化预测结果

参数分别是图像、预测结果、真实标签,以及要展示的样本数量(默认为9个)

# 类别名称

classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',

'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

这是一个列表,包含了所有可能的类别名称

这段代码是用于Fashion-MNIST数据集,包含了10类不同的服装或鞋子

# 创建一个3x3的子图

fig, axes = plt.subplots(3, 3, figsize=(12, 12))

# 展平axes数组以便于索引

axes = axes.flatten()

创建一个3行3列的图表布局,总共可以显示9张图像

flatten()是将二维数组转换成一维数组,便于后面的循环操作

# 确保我们只展示指定数量的样本

images = torch.cat(all_images, 0)[:num_samples]

predictions = torch.cat(all_preds, 0)[:num_samples]

labels = torch.cat(all_labels, 0)[:num_samples]

torch.cat是将列表中的所有张量连接起来

这里取前num_samples个样本用于展示

for i in range(num_samples):

对每一个样本进行处理

# 获取图像并转换为numpy数组

img = images[i].squeeze().numpy()

获取第i个图像

squeeze()去掉尺寸为1的维度

.numpy()将PyTorch张量转换为NumPy数组,因为Matplotlib使用NumPy数组来显示图像

# 获取预测和真实标签

pred = classes[predictions[i]]

true_label = classes[labels[i]]

根据预测和真实标签的索引,获取对应的类别名称

# 在子图上显示图像

axes[i].imshow(img, cmap='gray')

在第i个子图上显示图像

cmap='gray'表示使用灰度颜色映射,因为Fashion-MNIST的图像是灰度的

# 设置标题为预测和真实标签

title_color = 'green' if pred == true_label else 'red'

axes[i].set_title(f'Pred: {pred}\nTrue: {true_label}', color=title_color)

设置子图的标题,包含预测和真实标签

如果预测正确,标题显示为绿色;如果预测错误,标题显示为红色

# 关闭坐标轴

axes[i].axis('off')

关闭坐标轴,使图像看起来更整洁

plt.tight_layout()

plt.show()

tight_layout()调整子图之间的间距,使显示更美观

show()显示整个图表

# 可视化一些预测结果

visualize_predictions(all_images, all_preds, all_labels)

调用前面定义的函数,使用收集到的图像、预测和标签来可视化结果

总结:这段代码测试了一个机器学习模型在时尚物品识别任务上的表现,并将部分预测结果以可视化的方式展示出来,方便我们直观地了解模型的性能。

相关推荐

千图网VIP价格调整通知
腾讯QQ在哪里看成为好友多少天 手机QQ查看加好友时间方法
摩拜单车怎么样?
Category:亚洲各国建筑物
分类:牵绊概念礼装
小米电视设置在哪打开 小米电视设置界面打开方法
2022年中国智慧农业TOP100企业榜单发布
五笔输入法哪款比较好?6款五笔输入法下载推荐
《天國拯救》偷竊狀態持續時間介紹