import time
import torch
import torch.nn.functional as F
from torch.autograd import Variable
def roi_pooling(input, rois, size=(3, 3), spatial_scale=1.0, has_backward=None):
    """Max-pool every region of interest of ``input`` to a fixed output size.

    Args:
        input: Tensor whose dimension 0 selects the image and whose last two
            dimensions are sliced as rows and columns; each crop is handed to
            ``F.adaptive_max_pool2d``.
        rois: 2-D tensor with five columns per row. Column 0 is the index
            into dimension 0 of ``input``; columns 1 and 3 are the first and
            last column of the crop, columns 2 and 4 its first and last row
            (both ends inclusive). The tensor is not modified.
        size: Output size forwarded to ``F.adaptive_max_pool2d``.
        spatial_scale: Factor applied to columns 1-4 of ``rois`` (not to the
            image index) before they are converted to integers.
        has_backward: When true, ``output.sum().backward()`` is executed
            before returning. ``None`` (the default) falls back to a
            module-level global named ``has_backward`` when one exists and
            to ``False`` otherwise, so scripts that set that global keep
            working.

    Returns:
        The pooled crops concatenated along dimension 0, one entry per ROI.
    """
    assert rois.dim() == 2
    assert rois.size(1) == 5

    if has_backward is None:
        # Older callers toggled this through a module-level flag; honour it
        # when present instead of raising NameError when it is missing.
        has_backward = globals().get("has_backward", False)

    # Work on a private copy: ``.float()`` returns the very same storage when
    # ``rois`` is already float32, so scaling in place would silently rescale
    # the caller's tensor.
    boxes = rois.detach().clone().float()
    boxes[:, 1:].mul_(spatial_scale)

    pooled = []
    # ``tolist()`` yields plain Python ints, which are valid slice bounds and
    # ``narrow`` arguments on every PyTorch version.
    for im_idx, x1, y1, x2, y2 in boxes.long().tolist():
        crop = input.narrow(0, im_idx, 1)[..., y1:y2 + 1, x1:x2 + 1]
        pooled.append(F.adaptive_max_pool2d(crop, size))
    output = torch.cat(pooled, 0)

    if has_backward:
        output.sum().backward()
    return output
def create_rois(config):
    """Generate random regions of interest with ordered corner coordinates.

    Args:
        config: Indexable ``[num_images, extent, num_rois]``. Column 0 of the
            result is drawn from ``[0, num_images)``, columns 1-4 from
            ``[0, extent)``, and ``num_rois`` rows are produced.

    Returns:
        A float tensor of shape ``(num_rois, 5)`` holding floored values in
        which column 1 <= column 3 and column 2 <= column 4. It does not
        require grad.
    """
    num_images, extent, num_rois = config[0], config[1], config[2]

    rois = torch.rand((num_rois, 5))
    rois[:, 0] = rois[:, 0] * num_images
    rois[:, 1:] = rois[:, 1:] * extent

    # Order each coordinate pair. ``torch.min``/``torch.max`` allocate fresh
    # tensors; the builtin ``max``/``min`` return views into ``rois``, so a
    # tuple-assignment swap overwrites the larger value before it is read and
    # collapses the box to min == max.
    first_corner, second_corner = rois[:, 1:3], rois[:, 3:5]
    lower = torch.min(first_corner, second_corner)
    upper = torch.max(first_corner, second_corner)
    rois[:, 1:3] = lower
    rois[:, 3:5] = upper

    # A freshly created tensor already has requires_grad=False, so no
    # Variable wrapper is needed.
    return torch.floor(rois)
if __name__ == '__main__':
    # Timing demo for roi_pooling on random data.
    # config = [batch_size, size, num_rois], as labelled in the report below.
    config = [1, 6, 6]
    # Number of timed repetitions.
    T = 1
    # Module-level flag read by roi_pooling to decide whether to run backward.
    has_backward = True

    start = time.time()
    # The spatial side must be config[1]: create_rois draws its coordinates
    # from [0, config[1]), so any other side length could produce boxes that
    # fall outside the feature map.
    x = torch.randn((config[0], 2, config[1], config[1]), requires_grad=True)
    rois = create_rois(config)

    for _ in range(T):
        print('\n')
        output = roi_pooling(x, rois)

    print(x, '\n\n\n', output)
    print(rois)
    print(x.shape, output.shape)
    elapsed = (time.time() - start) / T
    print(f'time: {elapsed}, batch_size: {config[0]}, size: {config[1]}, num_rois: {config[2]}')
# Tensor.narrow
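'''
## narrow(dim0, dim1, dim2)  -- i.e. Tensor.narrow(dim, start, length)
dim0: the dimension to operate on. 0 is the outermost pair of brackets; each
      increment of 1 moves one bracket level further in.
dim1: the index along that dimension at which the slice starts.
dim2: the number of elements to keep along that dimension, so the slice
      covers indices dim1 .. dim1 + dim2 - 1.

data.shape = [2,5,5]
data.narrow(0,0,2) == data[0:2,:,:]
data.narrow(1,0,2) == data[:,0:2,:]
data.narrow(2,0,3) == data[:,:,0:3]
'''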
import torch
# Compare Tensor.narrow with the equivalent basic slice on a random tensor.
data = torch.randn(2, 5, 5)
print(data, '\n')

# Keep the first three entries of the last dimension, once through narrow and
# once through ordinary indexing.
tmp = torch.narrow(data, 2, 0, 3)
tmp2 = data[..., :3]

# Show the shapes first, then the values of each result.
print(tmp.size())
print(tmp2.size())
print(tmp, '\n')
print(tmp2, '\n')
''' tensor([[[-0.6959, 0.4332, 1.4319, -0.0854, -1.6173], [-0.8515, -0.2990, 0.4964, 1.6507, 0.1273], [ 0.2099, 0.1537, -0.1838, 0.1569, -0.7338], [-1.8595, -1.3069, -1.8491, -0.9116, 0.5553], [ 1.3213, -1.0289, -1.5322, 0.7186, -1.9894]], [[-1.1485, -1.5543, 1.3556, -0.1875, 0.8493], [ 0.4699, 0.5768, 1.0837, -0.8480, 0.1903], [ 0.3805, -0.0149, 1.2274, 0.1739, 0.8447], [ 0.3930, 0.1219, 2.6473, -0.0492, 0.5728], [ 0.1649, -0.1448, 1.0176, 0.0481, -0.2358]]]) torch.Size([2, 5, 3]) torch.Size([2, 5, 3]) tensor([[[-0.6959, 0.4332, 1.4319], [-0.8515, -0.2990, 0.4964], [ 0.2099, 0.1537, -0.1838], [-1.8595, -1.3069, -1.8491], [ 1.3213, -1.0289, -1.5322]], [[-1.1485, -1.5543, 1.3556], [ 0.4699, 0.5768, 1.0837], [ 0.3805, -0.0149, 1.2274], [ 0.3930, 0.1219, 2.6473], [ 0.1649, -0.1448, 1.0176]]]) tensor([[[-0.6959, 0.4332, 1.4319], [-0.8515, -0.2990, 0.4964], [ 0.2099, 0.1537, -0.1838], [-1.8595, -1.3069, -1.8491], [ 1.3213, -1.0289, -1.5322]], [[-1.1485, -1.5543, 1.3556], [ 0.4699, 0.5768, 1.0837], [ 0.3805, -0.0149, 1.2274], [ 0.3930, 0.1219, 2.6473], [ 0.1649, -0.1448, 1.0176]]]) '''