• 首页 首页 icon
  • 工具库 工具库 icon
    • IP查询 IP查询 icon
  • 内容库 内容库 icon
    • 快讯库 快讯库 icon
    • 精品库 精品库 icon
    • 问答库 问答库 icon
  • 更多 更多 icon
    • 服务条款 服务条款 icon

增减维度numpy和torch的squeeze、unsqueeze理解

武飞扬头像
寻找永不遗憾
帮助1

1 为何要增减维度

神经网络conv2d的输入必须是四维的(batch,channel,height,width),前处理或者后处理通常需要维度扩充或者维度压缩,必须维度匹配!
一个减少维度,一个增加维度,增加和减少的维度只能是1(单维度)。

numpy中squeeze函数,无unsqueeze函数,numpy中增加维度用np.expand_dims(x, axis)函数,可参考链接
torch的tensor中,两个函数都有。

2 numpy中的squeeze 函数

解释:
从数组的形状中删除单维度条目,即把shape中为1的维度去掉,相当于减少维度

用法:

arr_1 = numpy.squeeze(arr, axis = None)

arr表示输入的数组;
axis的取值可为None或0,默认为None,表示删除所有shape为1的维度。axis为0表示删除 一层 shape为1的维度

举例:

import numpy as np

arr = np.array([[[[1,2,3],[4,5,6]]]])
print(type(arr), arr, arr.shape, sep='\n')
print("==========================")

arr_1 = np.squeeze(arr, axis=0)
print(type(arr_1), arr_1, arr_1.shape, sep='\n')
print("==========================")

arr_2 = np.squeeze(arr, axis=None)
print(type(arr_2), arr_2, arr_2.shape, sep='\n')

输出:

<class 'numpy.ndarray'>
[[[1 2 3]
  [4 5 6]]]
(1, 2, 3)
==========================
<class 'numpy.ndarray'>
[[1 2 3]
 [4 5 6]]
(2, 3)
==========================
<class 'numpy.ndarray'>
[[1 2 3]
 [4 5 6]]
(2, 3)

3 torch中的squeeze 函数

举例:

import torch

arr = torch.Tensor(1, 3, 1, 5)
print(type(arr), arr, arr.shape, sep='\n')
print("==========================")

# 里面的数字表示压缩哪个维度,依旧只有维度为1才能压
arr_1 = arr.squeeze(0)          # 压缩第一维度,且第一维度是1,可压缩
print(type(arr_1), arr_1, arr_1.shape, sep='\n')
print("==========================")

arr_2 = arr.squeeze(1)        # 压缩第二维度,但第二维度不是1,故不可压缩
print(type(arr_2), arr_2, arr_2.shape, sep='\n')
print("==========================")

arr_3 = arr.squeeze(2)        # 压缩第三维度,且第三维度是1,可压缩
print(type(arr_3), arr_3, arr_3.shape, sep='\n')
学新通

输出:

<class 'torch.Tensor'>
tensor([[[[1.9349e-19, 4.5445e 30, 4.7429e 30, 7.1354e 31, 7.1118e-04]],

         [[1.7444e 28, 7.3909e 22, 1.8727e 31, 1.4182e-19, 4.6168e 24]],

         [[4.2964e 24, 1.2514e-14, 8.9634e-33, 7.1345e 31, 7.1118e-04]]]])
torch.Size([1, 3, 1, 5])
==========================
<class 'torch.Tensor'>
tensor([[[1.9349e-19, 4.5445e 30, 4.7429e 30, 7.1354e 31, 7.1118e-04]],

        [[1.7444e 28, 7.3909e 22, 1.8727e 31, 1.4182e-19, 4.6168e 24]],

        [[4.2964e 24, 1.2514e-14, 8.9634e-33, 7.1345e 31, 7.1118e-04]]])
torch.Size([3, 1, 5])
==========================
<class 'torch.Tensor'>
tensor([[[[1.9349e-19, 4.5445e 30, 4.7429e 30, 7.1354e 31, 7.1118e-04]],

         [[1.7444e 28, 7.3909e 22, 1.8727e 31, 1.4182e-19, 4.6168e 24]],

         [[4.2964e 24, 1.2514e-14, 8.9634e-33, 7.1345e 31, 7.1118e-04]]]])
torch.Size([1, 3, 1, 5])
==========================
<class 'torch.Tensor'>
tensor([[[1.9349e-19, 4.5445e 30, 4.7429e 30, 7.1354e 31, 7.1118e-04],
         [1.7444e 28, 7.3909e 22, 1.8727e 31, 1.4182e-19, 4.6168e 24],
         [4.2964e 24, 1.2514e-14, 8.9634e-33, 7.1345e 31, 7.1118e-04]]])
torch.Size([1, 3, 5])
学新通

4 torch中的unsqueeze 函数

解释:
通过unsuqeeze(int)中的int整数,增加一个维度,int整数表示维度增加到哪儿去,且维度为1。

举例:

import torch

arr = torch.Tensor(3, 5)
print(type(arr), arr, arr.shape, sep='\n')
print("==========================")

# 本身是二维,增加一维变三维,可通过0,1,2三个数字来控制维度增加到哪
arr_1 = arr.unsqueeze(0)
print(type(arr_1), arr_1, arr_1.shape, sep='\n')
print("==========================")

arr_2 = arr.unsqueeze(1)
print(type(arr_2), arr_2, arr_2.shape, sep='\n')
print("==========================")

arr_3 = arr.unsqueeze(2)        # 数字再大就报错了
print(type(arr_3), arr_3, arr_3.shape, sep='\n')
学新通

输出:

<class 'torch.Tensor'>
tensor([[3.2483e 33, 1.9690e-19, 6.8589e 22, 1.3340e 31, 1.1708e-19],
        [7.2128e 22, 9.2216e 29, 7.5546e 31, 1.6932e 22, 3.0728e 32],
        [2.9514e 29, 2.8940e 12, 7.5338e 28, 1.8037e 28, 3.4740e-12]])
torch.Size([3, 5])
==========================
<class 'torch.Tensor'>
tensor([[[3.2483e 33, 1.9690e-19, 6.8589e 22, 1.3340e 31, 1.1708e-19],
         [7.2128e 22, 9.2216e 29, 7.5546e 31, 1.6932e 22, 3.0728e 32],
         [2.9514e 29, 2.8940e 12, 7.5338e 28, 1.8037e 28, 3.4740e-12]]])
torch.Size([1, 3, 5])
==========================
<class 'torch.Tensor'>
tensor([[[3.2483e 33, 1.9690e-19, 6.8589e 22, 1.3340e 31, 1.1708e-19]],

        [[7.2128e 22, 9.2216e 29, 7.5546e 31, 1.6932e 22, 3.0728e 32]],

        [[2.9514e 29, 2.8940e 12, 7.5338e 28, 1.8037e 28, 3.4740e-12]]])
torch.Size([3, 1, 5])
==========================
<class 'torch.Tensor'>
tensor([[[3.2483e 33],
         [1.9690e-19],
         [6.8589e 22],
         [1.3340e 31],
         [1.1708e-19]],

        [[7.2128e 22],
         [9.2216e 29],
         [7.5546e 31],
         [1.6932e 22],
         [3.0728e 32]],

        [[2.9514e 29],
         [2.8940e 12],
         [7.5338e 28],
         [1.8037e 28],
         [3.4740e-12]]])
torch.Size([3, 5, 1])
学新通

这篇好文章是转载于:学新通技术网

  • 版权申明: 本站部分内容来自互联网,仅供学习及演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,请提供相关证据及您的身份证明,我们将在收到邮件后48小时内删除。
  • 本站站名: 学新通技术网
  • 本文地址: /boutique/detail/tanhgfeecg
系列文章
更多 icon
同类精品
更多 icon
继续加载