# Evaluation of DistCurve on several datasets
from easydict import EasyDict
from dist_curve.curve_constructor import makeCurve
from dist_curve.model import getTrainedEstimator
import pickle
# Load the estimator through getTrainedEstimator from a weights file at an
# absolute, machine-specific path. The commented-out call is an alternate
# weights location (presumably for a different machine -- confirm before use).
# model = getTrainedEstimator(weights_path="/data/dzeiberg/ClassPriorEstimation/model.hdf5")
model = getTrainedEstimator(weights_path="/home/dz/research/ClassPriorEstimation/model.hdf5")
from glob import glob
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
def getDSAbsErrs(ds):
    """Return the absolute class-prior estimation error for each instance of ``ds``.

    For every instance, a curve is built with ``makeCurve`` from that
    instance's positive and unlabeled scores, normalized so it sums to one,
    and passed to the module-level ``model`` to predict ``alphaHat``.

    Args:
        ds: Object exposing ``instances``; each item must provide
            ``posScores`` and ``unlabeledScores`` (arrays with a ``reshape``
            method) and ``alpha`` (presumably the true class prior the
            prediction is compared against).

    Returns:
        list: One entry per instance holding ``np.abs(inst.alpha - alphaHat)``,
        where ``alphaHat`` is whatever ``model.predict`` returns for the
        single normalized curve.
    """
    absErrs = []
    for inst in ds.instances:
        # Build the curve from *this* instance's scores so the prediction is
        # compared against the matching inst.alpha (not always instance 0).
        curve = makeCurve(inst.posScores.reshape((-1, 1)),
                          inst.unlabeledScores.reshape((-1, 1))).reshape((1, -1))
        # Normalize the curve to unit sum before feeding it to the model.
        alphaHat = model.predict(curve / curve.sum())
        absErrs.append(np.abs(inst.alpha - alphaHat))
    return absErrs
# Print the architecture of the loaded estimator.
model.summary()

# Evaluate every pickled dataset found in the processed-datasets directory.
files = glob("/data/dzeiberg/ClassPriorEstimation/processedDatasets_partial/*.pkl")
absErrs = []  # one list of per-instance absolute errors per dataset file
for file in tqdm(files, total=len(files)):
    # Dataset name = file name without directory and ".pkl".
    name = file.split("/")[-1].replace(".pkl", "")
    # NOTE(review): pickle.load can execute arbitrary code; only point this
    # at dataset files from a trusted source.
    with open(file, "rb") as fh:  # context manager closes the handle promptly
        ds = EasyDict(pickle.load(fh))
    aes = getDSAbsErrs(ds)
    absErrs.append(aes)
    print(name, "{:.3f}".format(np.mean(aes)))

# Datasets may contain different numbers of instances, which makes `absErrs`
# ragged; np.mean on a ragged nested list fails on recent NumPy. Flatten every
# per-instance error into one vector so the overall MAE is the mean over all
# instances of all datasets.
allErrs = [np.ravel(ae) for aes in absErrs for ae in aes]
overallMAE = np.mean(np.concatenate(allErrs)) if allErrs else float("nan")
print("Overall MAE: {:.3f}".format(overallMAE))