import argparse

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import Lasso
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

import matplotlib.pyplot as plt
import pickle as pkl


def get_model(class_name: str):
    """Return an unfitted estimator together with its grid-search parameters.

    Args:
        class_name: Name of the estimator to build. Supported values are
            ``"Lasso"`` and ``"RandomForestRegressor"``.

    Returns:
        A ``(model, params)`` tuple. ``model`` is a freshly constructed
        estimator and ``params`` is a parameter grid whose keys carry the
        ``Estimator__`` prefix, i.e. they address a pipeline step named
        ``"Estimator"``.

    Raises:
        ValueError: If ``class_name`` is not one of the supported names.
    """
    if class_name == "Lasso":
        model = Lasso()
        params = {
            "Estimator__alpha": [0.1, 1.0, 10.0],
        }
    elif class_name == "RandomForestRegressor":
        # n_jobs=-1 lets the forest use every available core.
        model = RandomForestRegressor(n_jobs=-1)
        params = {
            "Estimator__n_estimators": [10, 50, 100],
            "Estimator__max_depth": [10, 30, 50],
            "Estimator__max_features": ["sqrt", "log2"],
        }
    else:
        # Without this branch an unknown name fell through to the return
        # statement and failed with an opaque UnboundLocalError.
        raise ValueError(
            f"Unsupported class_name {class_name!r}; "
            "expected 'Lasso' or 'RandomForestRegressor'."
        )

    return model, params


def _load_xy(path, explanatory, objective):
    """Read the CSV at ``path`` once and split it into features and target."""
    frame = pd.read_csv(path)
    return frame.loc[:, explanatory], frame.loc[:, objective]


def _score_row(data_name, observed, predicted):
    """Build one row (MAE, RMSE, R2) of the score table for a data split."""
    return {
        "Data": data_name,
        "MAE": mean_absolute_error(observed, predicted),
        "RMSE": np.sqrt(mean_squared_error(observed, predicted)),
        "R2": r2_score(observed, predicted),
    }


def _result_frame(features, observed, predicted, data_name):
    """Join observed/predicted values with their feature rows and tag the split."""
    obs_pred = pd.DataFrame({
        "Obs": observed.values.flatten(),
        "Pred": predicted,
    })
    # Index-on-index join: relies on `features` carrying the default
    # positional index produced by pd.read_csv.
    return obs_pred.merge(
        features, left_index=True, right_index=True
    ).assign(Label=data_name)


def train(path_train: str, path_test: str, class_name: str) -> None:
    """Fit a grid-searched regression pipeline and write its artefacts.

    The training and test CSV files must contain the columns
    ``C``, ``Mn``, ``Ni``, ``Cr``, ``Mo``, ``Si`` (features) and ``Ms``
    (target). A ``StandardScaler`` + estimator pipeline is tuned with a
    10-fold ``GridSearchCV`` scored by negative mean squared error.

    Three files are written using relative paths:

    * ``ms_prediction.pkl`` - the pickled fitted ``GridSearchCV`` object,
    * ``ms_prediction.csv`` - observed/predicted values with features and a
      ``Label`` column (``Train`` / ``Test``),
    * ``ms_prediction.svg`` - observed-vs-predicted scatter plot with an
      embedded score table.

    Args:
        path_train: Path to the training CSV file.
        path_test: Path to the test CSV file.
        class_name: Estimator name forwarded to ``get_model``.

    Returns:
        None.
    """
    label = "ms_prediction"

    objective = "Ms"
    explanatory = ["C", "Mn", "Ni", "Cr", "Mo", "Si"]

    estimator, param_grid = get_model(class_name=class_name)

    # Each file is parsed exactly once (previously once for X and once for y).
    X_train, y_train = _load_xy(path_train, explanatory, objective)
    X_test, y_test = _load_xy(path_test, explanatory, objective)

    # Training
    # The step name "Estimator" is what the "Estimator__*" grid keys address.
    pipeline = Pipeline(steps=[
        ("StandardScaler", StandardScaler()),
        ("Estimator", estimator),
    ])

    model = GridSearchCV(
        pipeline, cv=10, param_grid=param_grid, scoring="neg_mean_squared_error"
    )
    model.fit(X_train, y_train)

    with open(label + ".pkl", "wb") as f:
        pkl.dump(model, f)

    # Evaluation
    # Predict once per split and reuse the arrays for scores, CSV and plot.
    pred_train = model.predict(X_train)
    pred_test = model.predict(X_test)

    df_score = pd.DataFrame([
        _score_row("Train", y_train, pred_train),
        _score_row("Test", y_test, pred_test),
    ]).round(2)

    df_result = pd.concat([
        _result_frame(X_train, y_train, pred_train, "Train"),
        _result_frame(X_test, y_test, pred_test, "Test"),
    ])
    df_result.to_csv(f"{label}.csv", index=False)

    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(1, 1, 1)
    ax.scatter(y_train, pred_train, color="dimgray", alpha=0.8, s=4, label="Train")
    ax.scatter(y_test, pred_test, color="#00bfc4", alpha=0.8, s=4, label="Test")
    # Identity line: points on it are perfect predictions.
    ax.axline([0, 0], [1, 1], color="black", linewidth=1, alpha=0.6)
    ax.grid(linestyle="--", alpha=0.5)
    ax.set_ylim([300, 600])
    ax.set_xlim([300, 600])
    ax.table(
        colWidths=[0.1] * 4,
        colLabels=df_score.columns.to_list(),
        cellText=df_score.values.tolist(),
        loc="lower right"
    )
    ax.legend(bbox_to_anchor=(0, 1), loc='upper left', borderaxespad=1)
    ax.set_title(label)
    ax.set_xlabel(f"Obs {objective}")
    ax.set_ylabel(f"Pred {objective}")
    ax.set_aspect("equal")
    # Save and close this specific figure instead of relying on pyplot's
    # implicit "current figure" state.
    fig.savefig(label + ".svg")
    plt.close(fig)


def main():
    """Parse the command-line options and launch a training run.

    Three string options are mandatory: ``--path_train``, ``--path_test``
    and ``--method_name``. Their values are handed to ``train``.
    """
    cli = argparse.ArgumentParser(description="")
    # All options share the same type and are required; register them in order.
    for option in ("--path_train", "--path_test", "--method_name"):
        cli.add_argument(option, type=str, required=True)
    options = cli.parse_args()

    train(
        path_train=options.path_train,
        path_test=options.path_test,
        class_name=options.method_name,
    )


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
