transforms.ToPILImage(): pic should be Tensor or ndarray
参考资料
问题描述
对UNet网络完成训练,需要输入一张图片,测试输出的概率图是否接近分割出的人脸。网络的输入为(num, 3, w, h)
4维度tensor。输出仍为(num, 1, w, h)
4维度的tensor,我的目标是将输出的tensor转为图像,代码如下:
1 | inputs = inputs.view(-1, c, h, w) |
运行出现错误transforms.ToPILImage(): pic should be Tensor or ndarray
。
解决问题
查阅资料1,出错是因为
All images in torchvision have to be represented as 3-dimensional tensors of the form [Channel, Height, Width]. I’m guessing your float tensor is a 2d tensor (height x width). For example, this works:
即torchvision中的所有图像必须是三维的tensor表示的,而我代码中的outputs未经处理时,是四维的。所以需要进行维度转换。修改后代码如下:
1 | inputs = inputs.view(-1, c, h, w) |