# -*- coding:UTF-8 -*-
import time
import torch
import torch.nn.functional as F
from torch.autograd import Variable

 
def roi_pooling(input, rois, size=(3, 3), spatial_scale=1.0, has_backward=None):
    """Max-pool every region of interest of ``input`` to a fixed ``size``.

    Args:
        input: Feature tensor indexed as ``input[batch, ..., row, col]``;
            presumably (N, C, H, W) -- TODO confirm against callers.
        rois: 2-D tensor of shape (num_rois, 5). Each row is
            ``[batch_idx, x1, y1, x2, y2]``; after scaling, the crop covers
            rows y1..y2 and columns x1..x2, both ends inclusive.
        size: Output spatial size handed to ``F.adaptive_max_pool2d``.
        spatial_scale: Factor applied to the four box coordinates before
            they are truncated to integer indices. ``rois`` itself is never
            modified.
        has_backward: When truthy, run ``output.sum().backward()`` before
            returning. ``None`` (the default) falls back to a module-level
            ``has_backward`` flag if one is defined, otherwise False.

    Returns:
        The pooled crops of all ROIs concatenated along dim 0.
    """
    assert rois.dim() == 2
    assert rois.size(1) == 5
    if has_backward is None:
        # Backward-compatible with the script, which sets a module-level flag;
        # without it, importing and calling this function used to raise NameError.
        has_backward = globals().get('has_backward', False)

    # Work on a detached float *copy*: Tensor.float() returns the very same
    # tensor when it is already float32, so scaling in place would overwrite
    # the caller's rois.
    boxes = rois.detach().float().clone()
    boxes[:, 1:] *= spatial_scale
    boxes = boxes.long()

    output = []
    # tolist() yields plain Python ints, which narrow() and slicing accept on
    # every PyTorch version.
    for batch_idx, x1, y1, x2, y2 in boxes.tolist():
        # First crop the ROI out of its image: narrow(0, b, 1) is input[b:b+1]
        # and both box corners are inclusive. Then adaptive_max_pool2d pools
        # the crop to the requested output size.
        im = input.narrow(0, batch_idx, 1)[..., y1:y2 + 1, x1:x2 + 1]
        output.append(F.adaptive_max_pool2d(im, size))

    output = torch.cat(output, 0)
    if has_backward:
        output.sum().backward()
    return output

def create_rois(config):
    """Generate random, whole-number ROIs for exercising ``roi_pooling``.

    Args:
        config: Sequence ``[batch_size, img_size, num_rois]``.

    Returns:
        Float tensor of shape (num_rois, 5) that does not require grad. Each
        row is ``[batch_idx, x1, y1, x2, y2]``. All values are floored
        uniform samples: batch_idx from [0, batch_size) and coordinates from
        [0, img_size). Each corner pair is ordered so that x1 <= x2 and
        y1 <= y2.
    """
    batch_size, img_size, num_rois = config[0], config[1], config[2]

    rois = torch.rand((num_rois, 5))
    rois[:, 0] = rois[:, 0] * batch_size
    rois[:, 1:] = rois[:, 1:] * img_size

    # Order each coordinate pair with out-of-place elementwise min/max.
    # The old per-row swap assigned 0-dim *views* of ``rois`` back into
    # ``rois``. Whenever a swap was needed, the "max" view had already been
    # overwritten with the minimum, so the box collapsed to width or height 1.
    for lo, hi in ((1, 3), (2, 4)):
        a, b = rois[:, lo], rois[:, hi]
        rois[:, lo], rois[:, hi] = torch.min(a, b), torch.max(a, b)

    # Variable(...) is a no-op wrapper in modern PyTorch; the plain tensor
    # already has requires_grad=False.
    return torch.floor(rois)
 
 
if __name__ == '__main__':
    # Benchmark configuration: [batch_size, img_size, num_rois].
    config = [1, 6, 6]
    T = 1  # number of timed iterations
    # roi_pooling reads this module-level flag; when True each call also runs
    # output.sum().backward(), so the timing includes the backward pass.
    has_backward = True

    start = time.time()

    # Feature map: (batch_size, 2, img_size, img_size). The spatial size must
    # be img_size (config[1]) because create_rois draws box coordinates from
    # [0, img_size); using num_rois (config[2]) only worked while both were 6.
    x = torch.randn((config[0], 2, config[1], config[1]), requires_grad=True)

    # Manual example of cropping a single ROI (kept for reference):
    # rois = torch.tensor([[0, 0, 0, 4, 4],
    #                      [0, 0, 0, 1, 1]])
    # roi = rois[0]
    # im = x.narrow(0, 0, 1)[..., roi[2]:(roi[4] + 1), roi[1]:(roi[3] + 1)]
    # print(im)

    # Each row of rois is [batch_idx, x1, y1, x2, y2]: columns 1-2 give the
    # top-left corner and columns 3-4 the bottom-right corner (inclusive).
    rois = create_rois(config)

    for t in range(T):
        print('\n')
        output = roi_pooling(x, rois)
        print(x, '\n\n\n', output)
        print(rois)
        print(x.shape, output.shape)

    print('time: {}, batch_size: {}, size: {}, num_rois: {}'.format(
        (time.time() - start) / T, config[0], config[1], config[2]))
    

# ---- Tensor.narrow demo ----


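'''
Tensor.narrow(dim, start, length)
    dim:    the dimension to slice. 0 is the outermost bracket level; each
            increment moves one bracket level inward.
    start:  index along ``dim`` at which the slice begins.
    length: number of elements to keep, so the result is the slice
            start:start + length along ``dim``.

Examples, with data.shape == [2, 5, 5]:
    data.narrow(0, 0, 2) == data[0:2, :, :]
    data.narrow(1, 0, 2) == data[:, 0:2, :]
    data.narrow(2, 0, 3) == data[:, :, 0:3]
'''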

import torch

# Random sample tensor of shape (2, 5, 5) to slice.
data = torch.randn((2, 5, 5))
print(data, '\n')

# narrow(dim=2, start=0, length=3) keeps indices 0..2 of the last axis ...
tmp = data.narrow(2, 0, 3)
print(tmp.shape)

# ... which selects the same elements as ordinary slicing on that axis.
_last_cols = slice(0, 3)
tmp2 = data[:, :, _last_cols]
print(tmp2.shape)

# Print both results so they can be compared by eye.
for _shown in (tmp, tmp2):
    print(_shown, '\n')

''' tensor([[[-0.6959, 0.4332, 1.4319, -0.0854, -1.6173], [-0.8515, -0.2990, 0.4964, 1.6507, 0.1273], [ 0.2099, 0.1537, -0.1838, 0.1569, -0.7338], [-1.8595, -1.3069, -1.8491, -0.9116, 0.5553], [ 1.3213, -1.0289, -1.5322, 0.7186, -1.9894]], [[-1.1485, -1.5543, 1.3556, -0.1875, 0.8493], [ 0.4699, 0.5768, 1.0837, -0.8480, 0.1903], [ 0.3805, -0.0149, 1.2274, 0.1739, 0.8447], [ 0.3930, 0.1219, 2.6473, -0.0492, 0.5728], [ 0.1649, -0.1448, 1.0176, 0.0481, -0.2358]]]) torch.Size([2, 5, 3]) torch.Size([2, 5, 3]) tensor([[[-0.6959, 0.4332, 1.4319], [-0.8515, -0.2990, 0.4964], [ 0.2099, 0.1537, -0.1838], [-1.8595, -1.3069, -1.8491], [ 1.3213, -1.0289, -1.5322]], [[-1.1485, -1.5543, 1.3556], [ 0.4699, 0.5768, 1.0837], [ 0.3805, -0.0149, 1.2274], [ 0.3930, 0.1219, 2.6473], [ 0.1649, -0.1448, 1.0176]]]) tensor([[[-0.6959, 0.4332, 1.4319], [-0.8515, -0.2990, 0.4964], [ 0.2099, 0.1537, -0.1838], [-1.8595, -1.3069, -1.8491], [ 1.3213, -1.0289, -1.5322]], [[-1.1485, -1.5543, 1.3556], [ 0.4699, 0.5768, 1.0837], [ 0.3805, -0.0149, 1.2274], [ 0.3930, 0.1219, 2.6473], [ 0.1649, -0.1448, 1.0176]]]) '''