numpy和torch数据类型转化问题

在实际计算过程中,float类型使用最多,因此这里重点介绍numpy和torch数据float类型转化遇到的问题,其他类型同理。

numpy数据类型转化

  • numpy使用astype转化数据类型,float默认转化为64位,可以使用np.float32指定为32位
1
2
3
4
5
#numpy转化float类型
a= np.array([1,2,3])
a = a.astype(np.float)
print(a)
print(a.dtype)

[1. 2. 3.]
float64

  • 不要使用a.dtype指定数据类型,会使数据丢失
1
2
3
4
5
#numpy转化float类型
b= np.array([1,2,3])
b.dtype= np.float32
print(b)
print(b.dtype)

[1.e-45 3.e-45 4.e-45]
float32

  • 不要用float代替np.float,否则可能出现意想不到的错误
  • 不能从np.float64位转化np.float32,会报错
  • np.float64与np.float32相乘,结果为np.float64

在实际使用过程中,可以指定为np.float,也可以指定具体的位数,如np.float,不过直接指定np.float更方便。

torch数据类型转化

  • torch使用torch.float()转化数据类型,float默认转化为32位,torch中没有torch.float64()这个方法
1
2
3
4
# torch转化float类型
b = torch.tensor([4,5,6])
b = b.float()
b.dtype
torch.float32
  • np.float64使用torch.from_numpy转化为torch后也是64位的
1
2
3
print(a.dtype)
c = torch.from_numpy(a)
c.dtype

float64
torch.float64

  • 不要用float代替torch.float,否则可能出现意想不到的错误
  • torch.float32与torch.float64数据类型相乘会出错,因此相乘的时候注意指定或转化数据float具体类型

np和torch数据类型转化大体原理一样,只有相乘的时候,torch.float不一致不可相乘,np.float不一致可以相乘,并且转化为np.float64

numpy和tensor互转

  • tensor转化为numpy
1
2
3
4
5
6
import torch
b = torch.tensor([4.0,6])
# b = b.float()
print(b.dtype)
c = b.numpy()
print(c.dtype)

torch.int64
int64

  • numpy转化为tensor
1
2
3
4
5
6
7
import torch
import numpy as np
b= np.array([1,2,3])
# b = b.astype(np.float)
print(b.dtype)
c = torch.from_numpy(b)
print(c.dtype)

int32
torch.int32

可以看到,torch默认int型是64位的,numpy默认int型是32位的