""" Continuous wavelet transform (CWT). Reference paper: https://www.mdpi.com/2076-3417/8/7/1102/html The Morlet wavelet is commonly used in bearing fault diagnosis. """
import numpy as np
import pywt
import matplotlib.pyplot as plt
import pandas as pd
import math
import os


def CWT(data, fs=25600, file_name=None, out_dir="./cwt_figures/",
        wavename="morl", totalscale=256):
    """Plot a signal beside its continuous wavelet transform and save the figure.

    The left panel shows ``data`` against time, the right panel a filled
    contour plot of the magnitude of the CWT coefficients over time and
    frequency (log-scaled frequency axis). The figure is written to
    ``<out_dir>/<file_name>_CWT.png`` and then closed.

    Args:
        data: 1-D sequence of signal samples.
        fs: Sampling frequency; used to build the time axis and passed to
            ``pywt.cwt`` as the sampling period ``1 / fs``.
        file_name: Label used as the plot title and as the stem of the output
            image name. When omitted, the module-level variable ``file_name``
            is used (this is how the script's main loop supplies it).
        out_dir: Directory the image is saved into; created if missing.
        wavename: Name of the wavelet passed to PyWavelets (default Morlet).
        totalscale: Upper end of the divisor range used to build the scales.

    Raises:
        ValueError: If no ``file_name`` is passed and no module-level
            ``file_name`` exists.
    """
    if file_name is None:
        # Backward compatibility: the original implementation read the label
        # from a module-level loop variable instead of taking it as an argument.
        file_name = globals().get("file_name")
        if file_name is None:
            raise ValueError(
                "CWT() needs a file_name argument (or a module-level "
                "'file_name' variable) to title and name the output figure"
            )

    t = np.arange(0, len(data)) / fs

    fc = pywt.central_frequency(wavename)  # centre frequency of the wavelet
    cparam = 2 * fc * totalscale
    # Scales are inversely proportional to totalscale, totalscale-1, ..., 2.
    # NOTE(review): per PyWavelets' scale-to-frequency relation this should map
    # to frequencies from about fs/2 down to fs/totalscale -- confirm in docs.
    scales = cparam / np.arange(totalscale, 1, -1)
    cwtmatr, frequencies = pywt.cwt(data, scales, wavename, 1.0 / fs)

    fig = plt.figure(figsize=(12, 6))
    try:
        ax1 = fig.add_subplot(1, 2, 1)
        ax1.plot(t, data)
        ax1.set_xlabel("Time(s)", fontsize=14)
        ax1.set_ylabel("Amplitude(g)", fontsize=14)

        ax2 = fig.add_subplot(1, 2, 2)
        # Magnitude of the coefficients as a filled contour map.
        ax2.contourf(t, frequencies, np.abs(cwtmatr))

        yt = [15.625, 31.25, 62.5, 125, 250, 500, 1000, 2000, 4000, 8000, 16000, 32000]
        ax2.set_yscale('log')
        ax2.set_yticks(yt)
        ax2.set_yticklabels(yt)
        # Clip the axis to the frequencies actually computed; ticks outside
        # this range are simply not shown.
        ax2.set_ylim([min(frequencies), max(frequencies)])

        ax2.set_xlabel("Time(s)", fontsize=14)
        ax2.set_ylabel("Frequency(Hz)", fontsize=14)
        ax2.set_title(file_name, fontsize=14)
        fig.tight_layout()

        # savefig does not create missing directories itself.
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(os.path.join(out_dir, file_name + "_CWT" + ".png"))
    finally:
        # Release the figure; this function is called once per input file and
        # unclosed pyplot figures would otherwise accumulate in memory.
        plt.close(fig)



def gener_simul_data():
    """Build a 1024-sample synthetic test signal made of three sine segments.

    The time axis runs over [0, 1) in steps of 1/1024. Segments are written
    from the widest time window to the narrowest, so later writes override
    earlier ones: frequency 300 for t < 0.3, 200 for 0.3 <= t < 0.8 and 100
    for the remaining samples.

    Returns:
        1-D float array with one sample per time step.
    """
    sample_rate = 1024
    instants = np.arange(0, 1.0, 1.0 / sample_rate)
    # (exclusive upper time bound, sine frequency), applied in this order.
    segments = ((1, 100), (0.8, 200), (0.3, 300))

    signal = np.zeros_like(instants)
    for upper_bound, freq in segments:
        window = instants < upper_bound
        signal[window] = np.sin(2 * np.pi * freq * instants[window])
    return signal

if __name__ == "__main__":
    # Print the wavelet families PyWavelets offers and the Morlet entries.
    print(pywt.families())
    print(pywt.wavelist('morl'))

    # Directory holding the acceleration CSV files for one bearing run.
    file_path = "../raw_data/Training_set/Bearing1_1/acc/"
    file_list = os.listdir(file_path)
    print("file_path:", file_path)
    print("num files:", len(file_list))

    # NOTE: the loop variable must stay named `file_name`; CWT() reads it as a
    # module-level global for the plot title and the output image name.
    for file_name in list(file_list):
        csv_path = "".join((file_path, file_name))
        frame = pd.read_csv(csv_path, header=None)  # files have no header row

        # Second-to-last and last columns of each file.
        # presumably horizontal / vertical acceleration -- verify against data.
        second_last_col = frame.iloc[:, -2]
        last_col = frame.iloc[:, -1]
        data_h = np.array(second_last_col.tolist())
        data_v = np.array(last_col.tolist())  # loaded but not used below

        # To try the synthetic signal instead: data = gener_simul_data()

        CWT(data_h, fs=25600)