张量拼接

torch.cat(tensors,
			dim=0,
			out=None)

功能:将张量按维度dim进行拼接,不创建新的维度
tensors:张量序列
dim:要拼接的维度

torch.stack(tensors,
			dim=0,
			out=None)

功能:在新创建的维度dim上进行拼接,创建新的维度
tensors:张量序列
dim:要拼接的维度

张量切分

torch.chunk(input,
			chunks,
			dim=0)

功能:将张量按维度dim进行平均切分
返回值:张量列表
注意事项:若不能整除,最后一份张量小于其他张量
input:要切分的张量
chunks:要切分的份数
dim:要切分的维度

torch.split(tensor,
			split_size_or_sections,
			dim=0)

功能:将张量按维度dim进行切分
返回值:张量列表
tensor:要切分的张量
split_size_or_sections:为int时,表示每一份的长度;为list时,按list元素切分
dim: 要切分的维度

import torch.nn as nn
unfold = nn.Unfold(kernels=(3,3),padding=1,stride=1)

该函数能够实现对张量的滑窗切分,若原始张量维度为[N,C,H,W],则经过unfold后得到[N,C33,H*W]

张量索引

torch.index_select(
				input,
				dim,
				index,
				out=None)

功能:在维度dim上,按index索引数据
返回值:依index索引数据拼接的张量
input:要索引的张量
dim:要索引的维度
index:要索引数据的序号

t = torch.randint(0,9,size=(3,3))
idx = torch.tensor([0,2],dtype=torch.long) #注意这里必须是long型
t_select = torch.index_select(t,dim=0,index=idx)
torch.masked_select(input,
					mask,
					out=None)

功能:按mask中的True进行索引
返回值:一维张量
input: 要索引的张量
mask: 与input同形状的布尔类型张量

t = torch.randint(0,9,size=(3,3))
mask = t.ge(5)
t_select = torch.mask_select(t,mask)

张量变换

torch.transpose(input,
				dim0,
				dim1)
torch.t(input) #直接转置2维矩阵

功能:交换张量的两个维度
input:要交换的张量
dim0: 要交换的维度
dim1: 要交换的维度

torch.squeeze(input,dim=None,out=None)
torch.unsqueeze(input,dim,out=None)

功能:压缩和扩展维度