transforms.ToPILImage(): pic should be Tensor or ndarray

参考资料

  1. transforms.ToPILImage(): pic should be Tensor or ndarray

问题描述

对UNet网络完成训练,需要输入一张图片,测试输出的概率图是否接近分割出的人脸。网络的输入为(num, 3, w, h)4维度tensor。输出仍为(num, 1, w, h)4维度的tensor,我的目标是将输出的tensor转为图像,代码如下:

1
2
3
4
5
6
7
inputs = inputs.view(-1, c, h, w)
#inputs = inputs.cuda()

inputs = Variable(inputs, volatile=True)
outputs = net(inputs)
print(outputs.shape)
transforms.ToPILImage()(outputs).convert('L').save('test2.jpg')

运行出现错误transforms.ToPILImage(): pic should be Tensor or ndarray

解决问题

查阅资料1,出错是因为

All images in torchvision have to be represented as 3-dimensional tensors of the form [Channel, Height, Width]. I’m guessing your float tensor is a 2d tensor (height x width). For example, this works:

即torchvision中的所有图像必须是三维的tensor表示的,而我代码中的outputs未经处理时,是四维的。所以需要进行维度转换。修改后代码如下:

1
2
3
4
5
6
7
8
9
inputs = inputs.view(-1, c, h, w)
#inputs = inputs.cuda()

inputs = Variable(inputs, volatile=True)
outputs = net(inputs)
print(outputs.shape)
outputs = outputs.view(1, h, w) # 转成3维的 很重要!!!
print(type(outputs))
transforms.ToPILImage()(outputs).convert('L').save('test2.jpg')