Evaluation of DistCurve on several datasets
from easydict import EasyDict
from dist_curve.curve_constructor import makeCurve
from dist_curve.model import getTrainedEstimator
import pickle
# Weights file used to build the trained estimator on this machine.
# Alternative location used elsewhere: /data/dzeiberg/ClassPriorEstimation/model.hdf5
weights_path = "/home/dz/research/ClassPriorEstimation/model.hdf5"
model = getTrainedEstimator(weights_path=weights_path)
from glob import glob
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
def getDSAbsErrs(ds):
    """Return the absolute estimation error for every instance in a dataset.

    For each instance, a curve is built with ``makeCurve`` from that
    instance's positive and unlabeled scores, flattened to a single row,
    normalized to sum to one, and passed to the module-level ``model``.
    The prediction is compared against the instance's ``alpha``.

    Args:
        ds: Object exposing an ``instances`` iterable. Each instance must
            provide ``posScores`` and ``unlabeledScores`` (arrays that
            support ``reshape``) and ``alpha``.

    Returns:
        list: One ``np.abs(inst.alpha - alphaHat)`` entry per instance, in
        iteration order. The shape of each entry follows whatever
        ``model.predict`` returns.
    """
    absErrs = []
    for inst in ds.instances:
        # Use the current instance's scores (not ds.instances[0]) so that
        # each estimate is paired with its own inst.alpha.
        curve = makeCurve(
            inst.posScores.reshape((-1, 1)),
            inst.unlabeledScores.reshape((-1, 1)),
        ).reshape((1, -1))
        # The estimator is fed the curve normalized to sum to 1.
        alphaHat = model.predict(curve / curve.sum())
        absErrs.append(np.abs(inst.alpha - alphaHat))
    return absErrs
# Print a summary of the loaded estimator before running the evaluation.
model.summary()

# Evaluate every pickled dataset found in the processed-datasets directory.
files = glob("/data/dzeiberg/ClassPriorEstimation/processedDatasets_partial/*.pkl")

# Flat list of per-instance absolute errors across all datasets. Keeping it
# flat (rather than one sub-list per dataset) lets np.mean work even when
# datasets contain different numbers of instances.
absErrs = []
for file in tqdm(files, total=len(files)):
    # Dataset name = file name without directory or ".pkl".
    name = file.split("/")[-1].replace(".pkl", "")
    # NOTE: pickle can execute arbitrary code while loading; only point this
    # script at dataset files from a trusted source.
    with open(file, "rb") as f:  # close the handle as soon as loading is done
        ds = EasyDict(pickle.load(f))
    aes = getDSAbsErrs(ds)
    absErrs.extend(aes)
    # Per-dataset mean absolute error.
    print(name, "{:.3f}".format(np.mean(aes)))
# Mean absolute error over every instance of every dataset.
print("Overall MAE: {:.3f}".format(np.mean(absErrs)))