技术交流QQ群:1027579432,欢迎你的加入!

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Date    : 2019-03-01 10:02:11
# @Author  : cdl (1217096231@qq.com)
# @Link    : https://github.com/cdlwhm1217096231/python3_spider
# @Version : $Id$

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.transforms as transforms
import torchvision.models as models
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import numpy as np

# 1.pytorch基础:数据的加载与预处理
"""
pytorch通过torch.utils.data对一般常用的数据进行了封装,可以很容易实现对线程数据预读和批量加载
torchvision已经预先实现了常用图像的数据集,包括cifar10 imagenet coco mnist lsun等数据集,可以通过torchvision.datasets来进行调用。
"""

# 1.1 DataSet类来自torch.utils.data中
"""
为了能方便读取,需要将使用的数据包封装为DataSet类,自定义DataSet时,需要继承该类并重写两个成员方法:
    __getitem()__:定义了每次怎么读数据
    __len()__:定义了自定义数据集的大小
"""


class BulldozerDataSet(Dataset):
    def __init__(self, cvs_file):
        # 实现初始化方法,在初始化时将数据进行加载
        self.df = pd.read_csv(cvs_file)

    def __len__(self):
        # 返回df的长度
        return len(df)

    def __getitem(self, idx):
        # 根据索引,返回一列数据
        return self.df.iloc[idx].SalePrice


# 数据集定义已经完成,可以实例化一个对象进行访问
ds_demo = BulldozerDataSet("median_benchmark.csv")
# 由于实现了__len__()方法可以直接使用len获取数据总数
print(len(ds_demo))
# 用索引可以直接访问对应的数据
print(ds_demo[0])


# 1.2 DataLoader
"""
DataLoader提供了对DataSet的读取操作,常用的参数有:batch_size(每个batch的大小)、shuffle(是否打乱)、num_workers(加载数据的时候使用几个子线程)
"""
dl = DataLoader(ds_demo, batch_size=10, shuffle=True, num_workers=0)

# !!! DataLoader返回的是一个迭代器,可以使用迭代器分次获取数据
iter_data = iter(dl)
print(next(iter_data))
# 或者使用for循环来对其进行遍历
for i, data in enumerate(dl):
    print(i, '--->', data)
    break

# 1.3 torchvision包
"""
torchvision是pytorch专门用来处理图像的库
"""

# 1.3.1 torchvision.datasets
"""
可以理解为pytorch团队自定义的dataset,拿来就可以使用:CIFAR10、COCO、MNIST等数据集
"""
# train_set = datasets.MNIST(root="./data", train=True,
# download=True, transform=None)
"""
root:表示MNIST数据集的加载目录
train:表示是否加载数据库的训练集,false时加载测试集
download:表示是否自动下载MNIST数据集
transform:表示是否需要对数据进行预处理,none表示不进行预处理
"""

# 1.3.2 torchvision.models
"""
torchvision不仅提供常用的图片数据集,还提供训练好的模型,可以在加载后使用,或者在进行迁移学习时使用
torchvision.models模块的子模块中包含以下模型结构: AlexNet VGG ResNet SqueezeNet DenseNet
"""

# resnet18 = models.resnet18(pretrained=True)


# 1.3.3 trochvision.transforms
"""
transforms模块提供了一般的图像转换操作类,用于数据的处理和增广
"""
transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # 先四周填充0,在吧图像随机裁剪成32*32
    transforms.RandomHorizontalFlip(),  # 图像一半的概率翻转,一半的概率不翻转
    transforms.RandomRotation(-45, 45),  # 随机旋转
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.229, 0.224, 0.225))  # R,G,B每层的归一化用到的均值和方差
])