Pytorch中torch.stack() 函数解析_python torch.stack_cv_lhp的博客-程序员宅基地

技术标签: python  Pytorch基础  深度学习  pytorch  

一. torch.stack()函数解析

1. 函数说明:

1.1 官网torch.stack(),函数定义及参数说明如下图所示:

函数定义及参数说明

1.2 函数功能

沿一个新维度对输入一系列张量进行连接,序列中所有张量应为相同形状,stack 函数返回的结果会新增一个维度。也即是把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度上面进行堆叠。

1.3 参数列表

  • tensors :为一系列输入张量,类型为turple和List
  • dim :新增维度的(下标)位置,当dim = -1时默认最后一个维度;范围必须介于 0 到输入张量的维数之间,默认是dim=0,在第0维进行连接
  • 返回值:输出新增维度后的张量

2. 代码举例

2.1 dim = 0 : 在第0维进行连接,相当于在行上进行组合(输入张量为一维,输出张量为两维)

import torch
#二维输入张量a,b
a = torch.tensor([1, 2, 3])
b = torch.tensor([11, 22, 33])
c = torch.stack([a, b],dim=0)#在第0维进行连接,相当于在行上进行组合(输入张量为一维,输出张量为两维)
print(a)
print(b)
print(c)
输出结果如下:
tensor([1, 2, 3])
tensor([11, 22, 33])
tensor([[ 1,  2,  3],
        [11, 22, 33]])

2.2 dim = 1 :在第1维进行连接,相当于在对应行上面对列元素进行组合(输入张量为一维,输出张量为两维)

import torch
#二维输入张量a,b
a = torch.tensor([1, 2, 3])
b = torch.tensor([11, 22, 33])
c = torch.stack([a, b],dim=1)#在第1维进行连接,相当于在对应行上面对列元素进行组合(输入张量为一维,输出张量为两维)
print(a)
print(b)
print(c)
输出结果如下:
tensor([1, 2, 3])
tensor([11, 22, 33])
tensor([[ 1, 11],
        [ 2, 22],
        [ 3, 33]])

2.3 dim=0:表示在第0维进行连接,相当于在通道维度上进行组合(输入张量为两维,输出张量为三维),注意:此处输入张量维度为二维,因此dim最大只能为2。

import torch
#二维输入张量a,b
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])
c = torch.stack([a, b],dim=0)#在第0维进行连接,相当于在通道维度上进行组合(输入张量为两维,输出张量为三维)
print(a)
print(b)
print(c)
输出结果如下所示:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
tensor([[11, 22, 33],
        [44, 55, 66],
        [77, 88, 99]])
tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9]],

        [[11, 22, 33],
         [44, 55, 66],
         [77, 88, 99]]])

2.4 dim=1:表示在第1维进行连接,相当于对相应通道中每个行进行组合,注意:此处输入张量维度为二维,因此dim最大只能为2。

import torch
#二维输入张量a,b
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])
c = torch.stack([a, b], 1)#在第1维进行连接,相当于对相应通道中每个行进行组合
print(a)
print(b)
print(c)
输出结果如下所示:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
tensor([[11, 22, 33],
        [44, 55, 66],
        [77, 88, 99]])
tensor([[[ 1,  2,  3],
         [11, 22, 33]],

        [[ 4,  5,  6],
         [44, 55, 66]],

        [[ 7,  8,  9],
         [77, 88, 99]]])

2.5 dim=2:表示在第2维进行连接,相当于对相应行中每个列元素进行组合,注意:此处输入张量维度为二维,因此dim最大只能为2。

import torch
#二维输入张量a,b
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])
c = torch.stack([a, b], 2)#在第2维进行连接,相当于对相应行中每个列元素进行组合
print(a)
print(b)
print(c)
输出结果如下所示:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
tensor([[11, 22, 33],
        [44, 55, 66],
        [77, 88, 99]])
tensor([[[ 1, 11],
         [ 2, 22],
         [ 3, 33]],

        [[ 4, 44],
         [ 5, 55],
         [ 6, 66]],

        [[ 7, 77],
         [ 8, 88],
         [ 9, 99]]])

2.6 dim=3:表示在第3维进行连接,相当于对相应行中每个列元素进行组合(输入维度大小为3维,因此dim=3最后一维始终代表为列),注意:此处输入张量维度为三维,因此dim最大只能为3。

import torch
#三维输入张量a,b
a = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],[[10, 20, 30], [40, 50, 60], [70, 80, 90]]])
b = torch.tensor([[[11, 22, 33], [44, 55, 66], [77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]])
c = torch.stack([a, b], 3)#表示在第3维进行连接,相当于对相应行中每个列元素进行组合(最后一维是第三维,始终代表为列)
print(a)
print(b)
print(c)
输出结果如下所示:
tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9]],

        [[10, 20, 30],
         [40, 50, 60],
         [70, 80, 90]]])
tensor([[[ 11,  22,  33],
         [ 44,  55,  66],
         [ 77,  88,  99]],

        [[110, 220, 330],
         [440, 550, 660],
         [770, 880, 990]]])
tensor([[[[  1,  11],
          [  2,  22],
          [  3,  33]],

         [[  4,  44],
          [  5,  55],
          [  6,  66]],

         [[  7,  77],
          [  8,  88],
          [  9,  99]]],


        [[[ 10, 110],
          [ 20, 220],
          [ 30, 330]],

         [[ 40, 440],
          [ 50, 550],
          [ 60, 660]],

         [[ 70, 770],
          [ 80, 880],
          [ 90, 990]]]])

2.7 dim=4 (错误维度:因为此处输入张量维度为三维,所以dim最大只能为3,此处维度为4,因此会报错)

import torch
#三维输入张量a,b
a = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],[[10, 20, 30], [40, 50, 60], [70, 80, 90]]])
b = torch.tensor([[[11, 22, 33], [44, 55, 66], [77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]])
c = torch.stack([a, b], 4)
print(a)
print(b)
print(c)
输出错误:
IndexError: Dimension out of range (expected to be in range of [-4, 3], but got 4)

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/flyingluohaipeng/article/details/125034358

智能推荐

IOS 推送消息 php做推送服务端(试验过可行)-程序员宅基地

IOS 推送消息 php做推送服务端 博客分类:ios IOS推送消息是许多IOS应用都具备的功能,最近也在研究这个功能,参考了很多资料终于搞定了,下面就把步骤拿出来分享下: iOS消息推送的工作机制可以简单的用下图来概括: Provider是指某个iPhone软件的Push服务器,APNS是Apple Push

【个人博客图片库】_个人博客图库_嗨Sirius的博客-程序员宅基地

为防止GitHub上面图片加载很慢的情况,决定以后将图片转存至CSDN服务器上,通过qiang内服务器的外链引用就能避免加载慢了_个人博客图库

python tornado 简单的form表单操作_tornado form-程序员宅基地

在web中最主要的应用就是前后端的数据交互,本文给出的是一个从前端form表单传递数据至后端,并返回给浏览器,在浏览器上显示的简单应用。import tornadoimport tornado.webimport tornado.ioloopclass indexHandler(tornado.web.RequestHandler): def get(self, *args, *_tornado form

GDB (中文速查表) CHEATSHEET_gdb cheatsheet-程序员宅基地

##############################################################################GDB CHEATSHEET (中文速查表) - by skywind (created on 2018/02/20)Version: 8, Last Modified: 2018/02/28 17:13https://github...._gdb cheatsheet

C语言-字符串函数_要保证*dest足够长,以容纳被复制进来的*src。*src中原有的字符不变。返回指向dest-程序员宅基地

字符串函数strlensize_t strlen ( const char * str );字符串已经 ‘\0’ 作为结束标志,strlen函数返回的是在字符串中 ‘\0’ 前面出现的字符个数(不包含 ‘\0’ )。参数指向的字符串必须要以 ‘\0’ 结束。注意函数的返回值为size_t,是无符号的(易错)strcpychar* strcpy(char * destination, const char * source );源字符串必须以 ‘\0’ 结束。会将源字符串中的_要保证*dest足够长,以容纳被复制进来的*src。*src中原有的字符不变。返回指向dest

如何去转载一篇博客_转载一:http://blog.csdn.net/guang_mang/article/detail-程序员宅基地

     就程序员在学习的时候,经常会遇到好的文章,想要转载,然后以后需要用的时候就能很快找到文章了,下面演示一遍转载,一定一定不能选原创,请尊重作者目标: 转载学习博客学习内容:1、 点开某篇文章如何右键查看源代码,或者审查元素,或者键盘F122、 ctrl + F 全文查找 article_content3、 右键 copy->copy outerHTML4、 打开markdown粘贴,选转载然后发布文章尊重原创一定要标注转载->超链接选上原作者的链接..._转载一:http://blog.csdn.net/guang_mang/article/details/78724142

随便推点

dnf 跨服 服务器 位置,dnf跨区怎么跨_dnf国服跨区表_快吧游戏-程序员宅基地

DNF国服跨区须知,什么时候跨区?跨区需要什么东西?跨区有什么限制?角色跨区有什么东西能转移?下面快吧小编就给大家带来7.28更新内容AFQ付费跨区解答,希望小伙伴们别懵了哦!DNF国服跨区须知 跨区表一览付费转区(已阉)Q:角色转服活动持续多久?A:角色转服开启时间:2016年7月28日~ 2016年8月11日,本次活动只允许转移一个角色。Q:转服费用是否可以使用代币券支付?A:不能,转服费用只...

利用Vim进行文件夹对比的三种方式_gvim 对比两个文件-程序员宅基地

前言最近经常使用vim, 心血来潮想研究了一下如何用Vim进行代码merge. 在Windows下有Beyond Compare和WinMerge等软件,可以比较两个目录结构及文件内容的异同,并以图形界面的形式呈现给用户。Vim有的vimdiff可以进行文件内容的对比和merge操作,但是很遗憾的是只针对单个文件比较,文件夹的比较就无能为力了。在网上这方面的介绍不多,只搜到https://bl..._gvim 对比两个文件

Activity的启动模式与startActivityForResult的关系-程序员宅基地

Activity的启动方式分为四种,分别为standard,singleTop,singleTask,singleInstancestartActivityForResult方法能够起效:standard和singleTopstartActivityForResult方法不能够起效:singleTask和singleInstance1、只要将被启动的Activity属性设置为singl

计算机开机后黑屏一闪一闪怎么办,电脑屏幕老一闪一闪的,一会黑屏一会又亮了,有时...-显示屏闪黑屏重新开机...-程序员宅基地

电脑屏幕老一闪一闪的,一会黑屏一会又亮了,有时...电脑屏幕闪烁,一会黑屏的原因和解决方法一:电脑主机故障(1)主机电源引起的故障,检查电源并修理。(2)配件质量引起的故障。这时用替换法更换下显示卡,内存,甚至主板,CPU试试,是最快捷的解决办。(3)配件间的连接质量 。内存显卡等等与主板间的插接不正确或有松动造成接触不良是引发黑屏故障的主要原因,可以更换显卡。(4)超频引起的黑屏故障,过度超频或..._电脑屏幕黑屏一闪一闪

Myeclipse2016 安装与破解_myeclipse2016安装好慢-程序员宅基地

一、准备(下载)所需文件 文件名:Myeclipse2016 百度云链接:http://pan.baidu.com/s/1sl8sd33 提取密码:ZUFE二、安装软件(1).在下载的文件夹中,打开MyEclipse2016,进行安装 (2).安装简单,不作赘述 三、破解软件(1).Myecplise破解工具文件夹,打_myeclipse2016安装好慢

利用SOLR搭建企业搜索平台 之九(solr的查询语法)-程序员宅基地

solr的一些查询语法 1. 首先假设我的数据里fields有:name, tel, address 预设的搜寻是name这个字段, 如果要搜寻的数据刚好就是 name 这个字段,就不需要指定搜寻字段名称. 2. 查询规则: 如欲查询特定字段(非预设字段),请在查询词前加上该字段名称加 “:” (不包含”号) 符号, 例如: address:北京市海淀区上地软件园 tel:88x...

推荐文章

热门文章

相关标签