import numpy as np
import pandas as pd


def createTable_datasets(name= "./Results/results_exper_QDA_lr0.1.csv"):
    list_to_print= list()
    df= pd.read_csv(name)
    df= df[['data', 'm', 'n']].copy()
    df= df.sort_values(by=["data"])
    df.drop_duplicates(inplace=True)
    df.reset_index(drop=True, inplace=True)
    df.index += 1
    df.insert(0, 'index', df.index)

    print("\\begin{table}[h]\n\centering")
    print(df.to_latex(float_format=lambda x: '{:.3f}'.format(x), index = False))
    print("\caption{Data sets. \"Index\" column contains the identifier used in the tables with empirical results, "
          "and the columns \"$m$\" and \"$n$\" correspond to the number of instances and variables of each dataset,"
          "respectively.}\n\end{table}")

def createTable_minVals_reach(classif= "LR", name= "./Results/results_exper_LR_lr0.1.csv",  algorithms= ["RC", "GD"],
                              print_index= True, print_iter= True, print_reach= True, print_average= True):

    score_to_print = "0-1"

    df= pd.read_csv(name)
    df= df.sort_values(by=["data","iter"])

    for type in ["ML","MAP"]:
        list_to_print= list()

        df_type= df[(df["type"]== type) & (df["loss"] == score_to_print)]

        ind= 0
        avg_error_init= list()
        avg_error_1= list()
        avg_iter_1= list()
        avg_error_2= list()
        avg_iter_2= list()
        avg_reach= list()

        for data, group in df_type.groupby(["data"]):
            ind+= 1
            try:
                row=list()
                if print_index:
                    row.append(ind)
                else:
                    row.append(data)
                init_val= group[group['iter']==1]['val'].values[0]
                row.append('{:.3f}'.format(init_val))
                avg_error_init.append(init_val)

                # get best values
                df_alg_1 = group[(group['alg'] == algorithms[0])]
                # Find the minimum value(s) of the "val" column
                min_val_1 = df_alg_1['val'].min()
                min_iter_1 = df_alg_1[df_alg_1['val'] == min_val_1]['iter'].min()
                avg_error_1.append(min_val_1)
                avg_iter_1.append(min_iter_1)

                df_alg_2 = group[(group['alg'] == algorithms[1])]
                # Find the minimum value(s) of the "val" column
                min_val_2 = df_alg_2['val'].min()
                min_iter_2 = df_alg_2[df_alg_2['val'] == min_val_2]['iter'].min()
                avg_error_2.append(min_val_2)
                avg_iter_2.append(min_iter_2)

                # highlight in bold best values
                if min_val_1< min_val_2:
                    row.append(r"ttextbf{"+str('{:.3f}'.format(min_val_1))+"}t")
                else:
                    row.append(str('{:.3f}'.format(min_val_1)))

                if print_iter:
                    row.append(min_iter_1)

                if min_val_2< min_val_1:
                    row.append('ttextbf{'+str('{:.3f}'.format(min_val_2))+"}t")
                else:
                    row.append(str('{:.3f}'.format(min_val_2)))

                if print_iter:
                    row.append(min_iter_2)

                # Get the number of iterations in which the first algorithm at least equals the second in terms of error
                if print_reach:
                    for iter in range(1,np.max(df_alg_1["iter"])+1):
                        if (df_alg_1[(df_alg_1["iter"] == iter) & (df_alg_1["loss"] == score_to_print)]["val"].values[0]
                                <= min_val_2):
                            row.append(iter)
                            avg_reach.append(iter)
                            break
                        if iter== np.max(df_alg_1["iter"]):
                            row.append("-")

                list_to_print.append(row)
            except Exception as e:
                # Handling the exception by printing its description
                print(f"Exception {e} in data {data}")

        if print_average:
            list_to_print.append(["avg.", str('{:.3f}'.format(np.average(avg_error_init))),
                                              str('{:.3f}'.format(np.average(avg_error_1))),
                                  str('{:.0f}'.format(np.average(avg_iter_1))),
                                  str('{:.3f}'.format(np.average(avg_error_2))),
                                  str('{:.0f}'.format(np.average(avg_iter_2))),
                                     str('{:.0f}'.format(np.average(avg_reach)))])

        columns= ["Data", type]
        columns += [algorithms[0]]
        if print_iter:
            columns+= ["Iter"]
        columns += [algorithms[1]]
        if print_iter:
            columns+= ["Iter"]
        if print_reach:
            columns+= ["Reach"]
        df_to_print= pd.DataFrame(list_to_print, columns=columns)
        print("\\begin{table}[t]\n\centering")
        print(df_to_print.to_latex(index = False))
        print("\caption{Empirical error of" + f" {classif} with {type}. The column {type} shows the error "
                                    f"for the {type} learning algorithm. The columns {algorithms[0]} and {algorithms[1]} "
                                    f"show the minimum errors obtained, and the columns Iter the number of iterations "
                                    f"required to reach the minimum." + "}\n\end{table}\n\n")



if __name__ == '__main__':
    # Script entry point: print the data-set table first, then one pair of
    # result tables (ML and MAP) per experiment listed below.
    createTable_datasets()

    # (classifier label, results CSV path, the two algorithms to compare)
    experiments = (
        ("NB", "./Results/results_exper_NB_lr0.1.csv", ["RC", "GD"]),
        ("QDA", "./Results/results_exper_QDA_lr0.1.csv", ["RC", "GD"]),
        ("LR", "./Results/results_exper_LR_lr0.1.csv", ["RC", "GD"]),
        ("NB", "./Results/results_exper_NB_lr0.1_DFE.csv", ["RC", "DFE"]),
        ("LR", "./Results/results_exper_LR_lr0.1_Unif.csv", ["RC", "GD"]),
    )
    for classifier, csv_path, compared in experiments:
        createTable_minVals_reach(classif=classifier, name=csv_path, algorithms=compared)