import os
import pickle
from glob import glob

import h5py
import numpy as np
from easydict import EasyDict
from scipy.io import loadmat, savemat
from tqdm.notebook import tqdm

from dist_curve.transforms import getOptimalTransform

# Pin this process to CPUs 0-9.
os.sched_setaffinity(0, set(range(10)))
def getFile(dsPath):
    """Load a .mat dataset, falling back to h5py for MATLAB v7.3 files."""
    try:
        ds = loadmat(dsPath)
    except NotImplementedError:
        # scipy.io.loadmat cannot read v7.3 (HDF5-based) .mat files; read them
        # with h5py instead (note: arrays come back transposed relative to loadmat).
        with h5py.File(dsPath, "r") as f:
            ds = {k: np.array(v) for k, v in f.items()}
    return ds
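
# Quick sanity check (sketch): peek at the first raw dataset, if the raw-data
# directory used later in this script is present, and report array shapes.
_raw_paths = sorted(glob("/data/dzeiberg/ClassPriorEstimation/rawDatasets/*.mat"))
if _raw_paths:
    _raw = getFile(_raw_paths[0])
    for _k, _v in _raw.items():
        print(_k, getattr(_v, "shape", type(_v)))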
def getPUDatasetInstance(data):
    """Sample a single positive-unlabeled (PU) instance from a labeled dataset."""
    # Indices of all positives and negatives
    posIdxs = np.where(data.y == 1)[0]
    negIdxs = np.where(data.y == 0)[0]
    # Size of the labeled positive component: 1000 when available, otherwise 100
    numPos = 1000 if len(posIdxs) >= 1000 else 100
    # Split the positives into the labeled component and the unlabeled mixture
    numUnlabeledPos = len(posIdxs) - numPos
    unlabeledPosIdxs = np.random.choice(posIdxs, replace=False, size=numUnlabeledPos)
    posComponentIdxs = np.setdiff1d(posIdxs, unlabeledPosIdxs)
    posInstances = data.X[posComponentIdxs]
    # Downsample the mixture to at most 10,000 points, preserving its class ratio
    if len(negIdxs) + len(unlabeledPosIdxs) > 10000:
        n0 = int(10000 * len(negIdxs) / (len(negIdxs) + len(unlabeledPosIdxs)))
        n1 = 10000 - n0
        unlabeledNegIdxs = np.random.choice(negIdxs, replace=False, size=n0)
        unlabeledPosIdxs = np.random.choice(unlabeledPosIdxs, replace=False, size=n1)
    else:
        unlabeledNegIdxs = negIdxs
    unlabeledInstances = data.X[np.concatenate((unlabeledPosIdxs, unlabeledNegIdxs))]
    # Ground-truth labels of the mixture points (hidden from the estimator)
    hiddenLabels = np.concatenate((np.ones_like(unlabeledPosIdxs),
                                   np.zeros_like(unlabeledNegIdxs)))
    return EasyDict({
        "positiveInstances": posInstances,
        "unlabeledInstances": unlabeledInstances,
        "hiddenLabels": hiddenLabels,
        "alpha": hiddenLabels.sum() / len(hiddenLabels)})
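
# Minimal sketch on synthetic data (not part of the pipeline): checks that
# getPUDatasetInstance returns the expected fields and a sensible alpha.
# The feature matrix and labels here are made up purely for illustration.
_synthetic = EasyDict({"X": np.random.randn(2500, 3),
                       "y": np.random.binomial(1, 0.5, size=2500)})
_demo = getPUDatasetInstance(_synthetic)
print(_demo.positiveInstances.shape, _demo.unlabeledInstances.shape, _demo.alpha)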
def addOptimalTransformToInstance(inst):
    """Score a PU instance with the optimal transform and attach the results."""
    p = inst.positiveInstances
    u = inst.unlabeledInstances
    x = np.concatenate((p, u))
    # s = 1 for labeled positives, 0 for unlabeled points
    s = np.concatenate((np.ones(p.shape[0]),
                        np.zeros(u.shape[0])))
    probs, aucPU = getOptimalTransform(x, s)
    # probs is ordered [positives, unlabeled], so split exactly at p.shape[0]
    posScores = probs[:p.shape[0]]
    unlabeledScores = probs[p.shape[0]:]
    out = EasyDict(inst)
    out.posScores = posScores
    out.unlabeledScores = unlabeledScores
    out.aucPU = aucPU
    return out
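
# Sketch, continuing the synthetic example above: attach scores to the demo
# instance. This assumes getOptimalTransform accepts any (n_samples, n_features)
# array; it may fit a model internally, so this step can take a moment.
_demo_scored = addOptimalTransformToInstance(_demo)
print(_demo_scored.posScores.shape, _demo_scored.unlabeledScores.shape, _demo_scored.aucPU)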
NInstances = 10
filenames = glob("/data/dzeiberg/ClassPriorEstimation/rawDatasets/*.mat")
for filename in tqdm(filenames, total=len(filenames), leave=False):
    dsname = os.path.basename(filename).split(".")[0]
    outPkl = "/data/dzeiberg/ClassPriorEstimation/processedDatasets/{}.pkl".format(dsname)
    outMat = "/data/dzeiberg/ClassPriorEstimation/processedDatasets/{}.mat".format(dsname)
    if not os.path.isfile(outPkl):
        ds = EasyDict(getFile(filename))
        ds.instances = []
        for inst_num in tqdm(range(NInstances), total=NInstances, leave=False):
            inst = getPUDatasetInstance(ds)
            ds.instances.append(addOptimalTransformToInstance(inst))
            # Checkpoint after every instance so partial progress survives interruption
            with open(outPkl, "wb") as f:
                pickle.dump(ds, f)
            savemat(outMat, ds)
# Reload a processed dataset and inspect the fields of its first instance.
with open("/data/dzeiberg/ClassPriorEstimation/processedDatasets_partial/abalone.pkl", "rb") as f:
    ds2 = EasyDict(pickle.load(f))
ds2.instances[0].keys()
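
# Sketch: summarize the processed instances loaded above by printing the hidden
# mixture proportion (alpha) and the PU AUC of each sampled PU instance.
for _i, _inst in enumerate(ds2.instances):
    print(_i, float(_inst.alpha), float(_inst.aucPU))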