# state estimation with poisson convex mixture model
from uncurl.clustering import kmeans_pp
from uncurl.state_estimation import initialize_from_assignments, nolips_update_w, _create_w_objective, _create_m_objective
import numpy as np
from scipy.optimize import minimize
from scipy.special import gammaln
# Small positive constant. Nothing in the visible code reads it (the
# log-likelihood helper uses its own 1e-30 offset); presumably kept as a
# numerical-stability epsilon for other code -- TODO confirm before removing.
eps = 1e-10
def _poisson_calculate_lls(X, M, W, use_constant=True, add_eps=True):
    """
    For hard thresholding: this calculates the Poisson log-likelihood of
    each gene given the means M.dot(W), and returns one value per gene.

    Args:
        X (array): genes x cells - observed data
        M (array): genes x clusters - state means
        W (array): clusters x cells - state mixing weights
        use_constant (bool, optional): if True, the log(x!) term
            (gammaln(X+1)) is subtracted, giving the full Poisson
            log-likelihood. Default: True
        add_eps (bool, optional): if True, 1e-30 is added to M.dot(W)
            before the log is taken, so that zero means do not produce
            log(0). Default: True

    Returns:
        L (array): 1d array with one entry per row (gene) of X: the
            log-likelihood summed over cells. Entries that evaluate to NaN
            are replaced with -inf.
    """
    d = M.dot(W)
    if add_eps:
        # Out-of-place add: an in-place `+=` raises a casting error when M
        # and W (and therefore d) have an integer dtype.
        d = d + 1e-30
    LLs = X*np.log(d) - d
    if use_constant:
        LLs -= gammaln(X+1)
    L = np.sum(LLs, 1)
    # NaN arises e.g. from 0*log(0) when add_eps is False; such genes are
    # ranked last by mapping them to -inf.
    L[np.isnan(L)] = -np.inf
    return L
def robust_estimate_state(data, clusters, dist='Poiss', init_means=None, init_weights=None, method='NoLips', max_iters=10, tol=1e-10, disp=True, inner_max_iters=100, reps=1, normalize=True, gene_portion=0.2, use_constant=True):
    """
    Uses a Poisson Convex Mixture model to estimate cell states and
    cell state mixing weights. After every outer iteration except the last,
    the genes are ranked by their log-likelihood and only the top
    `gene_portion` fraction is used for the next iteration (hard
    thresholding).

    Args:
        data (array): genes x cells
        clusters (int): number of mixture components
        dist (string, optional): Distribution used - only 'Poiss' is implemented. This argument is currently not read. Default: 'Poiss'
        init_means (array, optional): initial centers - genes x clusters. Default: kmeans++ initializations
        init_weights (array, optional): initial weights - clusters x cells, or assignments as produced by clustering. Default: random(0,1)
        method (str, optional): optimization method. Options include 'NoLips' or 'L-BFGS-B'. Default: 'NoLips'.
        max_iters (int, optional): number of outer iterations; must be at least 1. Default: 10
        tol (float, optional): accepted for interface compatibility; no tolerance-based stopping is performed, so all max_iters iterations are run. Default: 1e-10
        disp (bool, optional): whether or not to display optimization parameters. Default: True
        inner_max_iters (int, optional): Number of iterations to run in the optimization subroutine for M and W. Default: 100
        reps (int, optional): accepted for interface compatibility; currently not read. Default: 1
        normalize (bool, optional): True if the resulting W should sum to 1 for each cell. Default: True.
        gene_portion (float, optional): The proportion of genes to use for estimating W after hard thresholding. Default: 0.2
        use_constant (bool, optional): whether the log(x!) term is included in the per-gene log-likelihoods used for gene selection. Default: True

    Returns:
        M (array): genes x clusters - state means
        W (array): clusters x cells - state mixing components for each cell
        ll (float): value of the M objective function after the final M update
        genes (array): 1d array of all genes used in final iteration.

    Raises:
        ValueError: if method is not 'NoLips' or 'L-BFGS-B', or if
            max_iters is less than 1.
    """
    # Validate up front: otherwise an unknown method or an empty outer loop
    # only surfaces later as an undefined-variable error.
    if method not in ('NoLips', 'L-BFGS-B'):
        raise ValueError("method must be 'NoLips' or 'L-BFGS-B', got {0!r}".format(method))
    if max_iters < 1:
        raise ValueError('max_iters must be at least 1, got {0}'.format(max_iters))
    genes, cells = data.shape
    if init_means is None:
        means, _ = kmeans_pp(data, clusters)
    else:
        means = init_means
    # Float copy: the caller's array is never modified, and integer
    # initial means would otherwise truncate every update written back.
    means = np.array(means, dtype=float)
    clusters = means.shape[1]
    if init_weights is None:
        w_init = np.random.random((clusters, cells))
    else:
        if len(init_weights.shape) == 1:
            init_weights = initialize_from_assignments(init_weights, clusters)
        # Actually start from the supplied weights (they used to be
        # converted and then ignored in favour of the random start).
        w_init = np.array(init_weights, dtype=float).reshape((clusters, cells))
    included_genes = np.arange(genes)
    num_genes = int(np.ceil(gene_portion*genes))
    if disp:
        print('num_genes: {0}'.format(num_genes))
    X = data.astype(float)
    Xsum = X.sum(0)
    Xsum_m = X.sum(1)
    # The W bounds do not depend on the gene subset, so build them once.
    w_bounds = [(0, None)] * (clusters*cells)
    for i in range(max_iters):
        if disp:
            print('iter: {0}'.format(i))
        # Index arrays produce copies, so these stay fixed while `means`
        # is updated further down.
        sub_data = data[included_genes, :]
        sub_X = X[included_genes, :]
        sub_means = means[included_genes, :]
        n_included = len(included_genes)
        # step 1: given M, estimate W
        w_objective = _create_w_objective(sub_means, sub_data)
        if method == 'NoLips':
            w_new = w_init
            for _ in range(inner_max_iters):
                # NOTE(review): Xsum is the column sum over ALL genes, not
                # just the included ones; kept as in the original -- confirm
                # this is the intended scaling for nolips_update_w.
                w_new = nolips_update_w(sub_X, sub_means, w_new, Xsum)
        else:
            w_res = minimize(w_objective, w_init.flatten(),
                    method='L-BFGS-B', jac=True, bounds=w_bounds,
                    options={'disp':disp, 'maxiter':inner_max_iters})
            w_new = w_res.x.reshape((clusters, cells))
        w_init = w_new
        w_ll, w_deriv = w_objective(w_new.reshape(clusters*cells))
        if disp:
            print('Finished updating W. Objective value: {0}'.format(w_ll))
        # step 2: given W, update M
        m_objective = _create_m_objective(w_new, sub_data)
        if method == 'NoLips':
            # The roles of M and W are swapped by transposing, so the sums
            # must have one entry per *included* gene to match the columns
            # of sub_X.T.
            xsum_included = Xsum_m[included_genes]
            m_t = sub_means.T
            for _ in range(inner_max_iters):
                m_t = nolips_update_w(sub_X.T, w_new.T, m_t, xsum_included)
            means[included_genes, :] = m_t.T
        else:
            m_bounds = [(0, None)] * (clusters*n_included)
            m_res = minimize(m_objective, sub_means.flatten(),
                    method='L-BFGS-B', jac=True, bounds=m_bounds,
                    options={'disp':disp, 'maxiter':inner_max_iters})
            means[included_genes, :] = m_res.x.reshape((n_included, clusters))
        m_ll, m_deriv = m_objective(means[included_genes, :].reshape(n_included*clusters))
        if disp:
            print('Finished updating M. Objective value: {0}'.format(m_ll))
        # step 3: hard thresholding/gene subset selection. Skipped on the
        # last iteration so the returned genes are the ones actually used.
        if i < max_iters - 1:
            lls = _poisson_calculate_lls(data, means, w_new, use_constant)
            included_genes = lls.argsort()[::-1][:num_genes]
            if disp:
                print(lls[included_genes])
                print(included_genes)
                print('selected number of genes: ' + str(len(included_genes)))
    if normalize:
        w_new = w_new/w_new.sum(0)
    return means, w_new, m_ll, included_genes