from glob import glob

import scipy.stats as ss
import matplotlib.pyplot as plt
import numpy as np
from easydict import EasyDict
from tqdm.notebook import tqdm

from dist_curve.curve_constructor import makeCurve
# NOTE: the IPython source-inspection line `makeCurve??` was here. It is a
# SyntaxError when this file is run or imported as plain Python; run it in a
# notebook cell (or use inspect.getsource(makeCurve)) to view the source.
def getUnlabeled(aNeg, bNeg, aPos, bPos, alpha, size_mixture):
    """Draw an unlabeled sample from a two-component Beta mixture.

    Each point independently comes from the positive component
    Beta(aPos, bPos) with probability ``alpha`` and from the negative
    component Beta(aNeg, bNeg) otherwise.

    Args:
        aNeg, bNeg: shape parameters of the negative Beta component.
        aPos, bPos: shape parameters of the positive Beta component.
        alpha: probability in [0, 1] that a point is drawn from the positive
            component.
        size_mixture: number of points to draw; converted with ``int()``.
            Non-positive values yield empty arrays.

    Returns:
        ``(samples, labels)``: a float array of shape ``(size_mixture, 1)``
        holding the drawn values, and an int array of the same shape that is
        1 where the point came from the positive component and 0 where it
        came from the negative one.
    """
    n = max(int(size_mixture), 0)
    # Vectorized: draw all component labels at once, then fill each
    # component's positions with a single Beta draw of exactly the right size
    # (instead of one RNG call per point in a Python loop).
    labels = np.random.binomial(1, alpha, size=n)
    is_pos = labels.astype(bool)
    n_pos = int(is_pos.sum())
    samples = np.empty(n, dtype=float)
    samples[is_pos] = np.random.beta(aPos, bPos, size=n_pos)
    samples[~is_pos] = np.random.beta(aNeg, bNeg, size=n - n_pos)
    return samples.reshape((-1, 1)), labels.reshape((-1, 1))
def sampleData(file):
    """Generate one synthetic dataset from a file of Beta shape parameters.

    The file must contain four comma-separated numbers, read in the order
    ``aNeg, bNeg, aPos, bPos``. The sample sizes and the mixing proportion
    are drawn at random on every call.

    Returns:
        EasyDict with
          alpha               -- positive-component weight, uniform on [0.01, 1)
          unlabeled           -- (n_mixture, 1) mixture sample, n_mixture in [1000, 10000)
          hiddenMixtureLabels -- (n_mixture, 1) component labels (1 = positive)
          positive            -- (n_positive, 1) Beta(aPos, bPos) sample,
                                 n_positive in [100, 5000)
          aNeg, bNeg, aPos, bPos -- the parameters read from ``file``
    """
    with open(file) as fh:
        params = tuple(map(float, fh.read().split(",")))
    aNeg, bNeg, aPos, bPos = params

    # Half-open [low, high) ranges for the random sample sizes and the prior.
    mixture_size_bounds = (1000, 10000)
    positive_size_bounds = (100, 5000)
    alpha_bounds = (0.01, 1)

    n_mixture = np.random.randint(*mixture_size_bounds)
    n_positive = np.random.randint(*positive_size_bounds)

    out = EasyDict()
    out.alpha = np.random.uniform(*alpha_bounds)
    out.unlabeled, out.hiddenMixtureLabels = getUnlabeled(
        aNeg, bNeg, aPos, bPos, out.alpha, n_mixture
    )
    # The labeled positives are drawn from the positive component only.
    out.positive = np.random.beta(aPos, bPos, size=n_positive).reshape((-1, 1))
    out.aNeg, out.bNeg, out.aPos, out.bPos = aNeg, bNeg, aPos, bPos
    return out
# Beta-parameter files consumed by sampleData. glob() returns matches in
# arbitrary, filesystem-dependent order, so sort before shuffling: with a
# fixed np.random seed the resulting order is then reproducible.
paramFiles = sorted(glob("/ssdata/ClassPriorEstimationPrivate/trainParamGroups/*/param*.csv"))
np.random.shuffle(paramFiles)
def minmax(c):
    """Linearly rescale ``c`` so its minimum maps to 0 and its maximum to 1.

    Args:
        c: array-like of numbers (non-empty).

    Returns:
        Float ndarray with the same shape as ``c``. A constant input (max ==
        min) has no meaningful scale and returns all zeros instead of the
        NaN/inf produced by dividing by zero.

    Raises:
        ValueError: if ``c`` is empty (from ``np.min``).
    """
    c = np.asarray(c, dtype=float)
    minC = np.min(c)
    span = np.max(c) - minC
    if span == 0:
        return np.zeros_like(c)
    return (c - minC) / span
# Number of independent datasets sampled per parameter file.
NRepsPerSet = 2
# Grid (step 0.001 on [0, 1)) at which every curve is evaluated.
quantiles = np.arange(0, 1, .001)
if not paramFiles:
    # Without this check an empty glob only fails later, as a NameError on `fn`.
    raise FileNotFoundError("no parameter files matched the paramFiles glob pattern")
curves = np.zeros((len(paramFiles) * NRepsPerSet,
                   len(quantiles)))
data = []
for fn, file in tqdm(enumerate(paramFiles), total=len(paramFiles)):
    for rep in tqdm(range(NRepsPerSet), leave=False):
        d = sampleData(file)
        curve = makeCurve(d.positive, d.unlabeled,
                          num_curves_to_average=25,
                          quantiles=quantiles)
        # Row layout: all reps of file 0, then all reps of file 1, ...
        curves[fn * NRepsPerSet + rep] = minmax(curve)
        data.append(d)

# Pick one generated dataset uniformly at random; every row is eligible
# (np.arange(last_index) would never select the final row).
r = np.random.randint(len(curves))

# Curve against quantile index; the vertical line marks the true alpha
# mapped onto the same index axis.
plt.plot(curves[r])
plt.vlines(data[r].alpha * len(quantiles), 0, 1)
plt.show()

# Densities of the positive component and of the unlabeled mixture.
# alpha is the probability of the positive component in getUnlabeled, so
# the mixture density is alpha * f_pos + (1 - alpha) * f_neg.
xs = np.arange(0, 1, .01)
pos_pdf = ss.beta.pdf(xs, a=data[r].aPos, b=data[r].bPos)
neg_pdf = ss.beta.pdf(xs, a=data[r].aNeg, b=data[r].bNeg)
plt.plot(xs, pos_pdf, label="positive")
plt.plot(xs, data[r].alpha * pos_pdf + (1 - data[r].alpha) * neg_pdf, label="mixture")
plt.legend()
plt.show()