from paddle import fluid
import numpy as np
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid import dygraph as dygraph
from paddle.fluid.dygraph import BatchNorm, Conv2D, Sequential, Pool2D
from paddle.fluid.layers import relu


class BasicBlock(dygraph.Layer):
    """Residual block made of two 3x3 convolutions plus a shortcut path.

    Main path: conv3x3 -> BatchNorm (with ReLU) -> conv3x3 -> BatchNorm.
    The main-path output is added to the shortcut output and the sum is
    passed through a final ReLU.

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Channel count produced by both convolutions.
        stride: Stride of the first convolution and, when a projection
            shortcut is created, of that projection as well. Defaults to 1.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()

        # Main path. Only the first convolution may change the spatial size.
        self.conv1 = Conv2D(in_channels, out_channels, filter_size=3, stride=stride, padding=1)
        self.bn1 = BatchNorm(out_channels, act='relu')
        self.conv2 = Conv2D(out_channels, out_channels, filter_size=3, padding=1)
        self.bn2 = BatchNorm(out_channels)

        # A 1x1 projection (followed by its own BatchNorm) is created only
        # when the main path changes the channel count or the stride;
        # otherwise the input is added back unchanged and `shortcut` is None.
        if in_channels != out_channels or stride != 1:
            self.shortcut = Conv2D(in_channels, out_channels, filter_size=1, stride=stride)
            self.bn3 = BatchNorm(out_channels)
        else:
            self.shortcut = None

    def forward(self, x):
        """Apply the block to ``x`` and return the activated residual sum."""
        residual = x
        x = self.bn1(self.conv1(x))
        x = self.bn2(self.conv2(x))

        # Compare against the None sentinel explicitly instead of relying on
        # the truthiness of a Layer object, which is not a documented contract
        # (container layers, for example, are falsy when empty).
        if self.shortcut is not None:
            residual = self.bn3(self.shortcut(residual))

        return relu(x + residual)

class Resnet18(dygraph.Layer):
    """ResNet-18-style feature extractor built from ``BasicBlock`` stages.

    Pipeline: 7x7 stride-2 convolution -> 3x3 stride-2 max pooling ->
    four stages of residual blocks -> global average pooling.
    No classification layer is attached; ``forward`` returns the pooled
    feature map.

    NOTE(review): the stem applies no BatchNorm or activation between the
    first convolution and the max pooling -- confirm this is intentional.
    """

    def __init__(self):
        super(Resnet18, self).__init__()
        # Per-stage configuration: output channels, number of blocks, and the
        # stride of the first block of each stage.
        self.inplanes = [64, 128, 256, 512]
        self.blocknums = [2, 2, 2, 2]
        self.stride = [1, 2, 2, 2]

        # Stem: 3 input channels -> 64 channels, followed by max pooling.
        self.FirstConv = Conv2D(3, 64, filter_size=7, stride=2, padding=3)
        self.Pool = Pool2D(pool_type='max', pool_size=3, pool_stride=2, pool_padding=1)

        # Each stage consumes the previous stage's output channels; the first
        # stage is fed by the stem, whose width equals inplanes[0].
        stage_inputs = [self.inplanes[0]] + self.inplanes[:-1]
        stage_specs = zip(self.blocknums, stage_inputs, self.inplanes, self.stride)
        for stage_index, (repeats, in_ch, out_ch, first_stride) in enumerate(stage_specs):
            # Stages are registered as block1 .. block4, in that order.
            setattr(
                self,
                'block{}'.format(stage_index + 1),
                self.add_layer(stage_index, repeats, in_ch, out_ch, first_stride),
            )

        self.globalPool = Pool2D(pool_type='avg', global_pooling=True)

    def add_layer(self, block_index, block_repeats, in_channels, out_channels, stride=1):
        """Build one stage as a ``Sequential`` of ``BasicBlock`` layers.

        The first block maps ``in_channels`` to ``out_channels`` using
        ``stride``; each remaining block keeps ``out_channels`` with stride 1.
        Sublayers are named ``Block<block_index>_<position>``. The first block
        is always created, even when ``block_repeats`` is below 1.
        """
        # (input channels, stride) for every block of the stage, in order.
        block_specs = [(in_channels, stride)] + [(out_channels, 1)] * (block_repeats - 1)

        stage = Sequential()
        for position, (block_in, block_stride) in enumerate(block_specs):
            stage.add_sublayer(
                'Block{}_{}'.format(block_index, position),
                BasicBlock(block_in, out_channels, block_stride),
            )
        return stage

    def forward(self, x):
        """Run stem, the four residual stages and global average pooling."""
        features = self.Pool(self.FirstConv(x))
        for stage in (self.block1, self.block2, self.block3, self.block4):
            features = stage(features)
        return self.globalPool(features)

# Smoke run: push one random float32 tensor of shape (5, 3, 224, 224)
# through the network in dygraph mode and print the output shape.
with dygraph.guard():
    x = to_variable(np.random.randn(5, 3, 224, 224).astype('float32'))

    net = Resnet18()
    out = net(x)

    print(out.shape)