# Source code for svgbit.core.combinations

from collections import Counter
from typing import Optional, Union

import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors
from sklearn.mixture import BayesianGaussianMixture

from .STDataset import STDataset


def _neighbor_clusters(
    center_cluster,
    spot_type: pd.DataFrame,
    coordinate_df: pd.DataFrame,
    n_neighbors: int = 7,
    min_coverage: float = 0.03,
    max_neighbors: int = 3,
) -> list:
    """
    Return the SVG clusters adjacent to ``center_cluster`` plus itself.

    For every spot assigned to ``center_cluster`` whose ``"spot_type"`` is not
    ``"uncertain"``, the ``n_neighbors - 1`` nearest other spots in
    ``coordinate_df`` are collected, each neighbor spot counted once. Up to
    ``max_neighbors`` of the most frequent foreign clusters among them are
    kept, provided each was seen at least
    ``min_coverage * n_certain_center_spots`` times. ``center_cluster`` is
    always appended last.

    Raises
    ======
    KeyError
        If ``center_cluster`` does not occur in ``spot_type["type_1"]``.
    """
    if center_cluster not in set(spot_type["type_1"]):
        raise KeyError(center_cluster)

    certain = spot_type.loc[spot_type["spot_type"] != "uncertain", "type_1"]
    center_spot_ids = certain[certain == center_cluster].index

    # NOTE(review): ``kneighbors`` returns row *positions* within
    # ``coordinate_df``. They are looked up positionally in column 1 of
    # ``spot_type`` (presumably ``"type_1"``). This assumes both frames list
    # spots in the same order -- confirm against the caller.
    labels_by_position = spot_type.iloc[:, 1].to_numpy()

    neighbor_labels = []
    # kneighbors raises on an empty query, so a cluster without any confidently
    # assigned spot simply has no neighbors.
    if len(center_spot_ids) > 0:
        nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(coordinate_df)
        _, indices = nbrs.kneighbors(
            coordinate_df.reindex(index=center_spot_ids))
        seen = set()
        for record in indices:
            # Skip the first hit, which is normally the query spot itself.
            for position in record[1:]:
                if position in seen:
                    continue
                seen.add(position)
                label = labels_by_position[position]
                if label != center_cluster:
                    neighbor_labels.append(label)

    min_count = len(center_spot_ids) * min_coverage
    clusters = [
        label for label, count in
        Counter(neighbor_labels).most_common()[:max_neighbors]
        if count >= min_count
    ]
    clusters.append(center_cluster)
    return clusters


def _rank_by_gmm(scores: pd.Series) -> list:
    """
    Label each score ``"low"``, ``"middle"`` or ``"high"``.

    A 3-component ``BayesianGaussianMixture`` is fitted to the scores. Each
    component is named by the rank of its fitted mean, lowest mean first.
    """
    gmm = BayesianGaussianMixture(
        n_components=3,
        max_iter=500,
        n_init=5,
    )
    components = gmm.fit_predict(scores.to_numpy(dtype=float).reshape(-1, 1))
    order = gmm.means_.argsort(axis=0).flatten()
    rank_of = dict(zip(order, ("low", "middle", "high")))
    return [rank_of[component] for component in components]


def _find_combinations(
    hotspot_df: pd.DataFrame,
    coordinate_df: pd.DataFrame,
    spot_type: pd.DataFrame,
    svg_cluster: pd.Series,
    center_spots: Union[int, list],
    selected_genes: Optional[list] = None,
    use_neighbor: bool = True,
) -> pd.DataFrame:
    """
    Find gene pairs in certain SVG cluster.

    Parameters
    ==========
    hotspot_df : pd.DataFrame
        A hotspot DataFrame generated by svgbit (rows are spots, columns are
        genes; values are presumably 0/1 hotspot indicators).

    coordinate_df : pd.DataFrame
        A pd.DataFrame for coordinate files (one row per spot).

    spot_type : pd.DataFrame
        A pd.DataFrame for assigned type_df. Must contain the columns
        ``"spot_type"`` and ``"type_1"``.

    svg_cluster : pd.Series
        SVG cluster label of every gene, indexed by gene.

    center_spots : int or list
        If a ``int`` is given, find gene pairs for this SVG cluster. If a
        ``list`` of spots is given, find gene pairs within those spots.
        ``selected_genes`` should be given if ``center_spots`` is a ``list``.

    selected_genes : list or None, default None
        If a ``list`` of genes is given, find gene pairs within given genes.
        Otherwise the genes of the selected SVG cluster(s) are used.

    use_neighbor : bool, default True
        Whether to find gene pairs with neighbor SVG clusters. Ignored when
        ``center_spots`` is a ``list``.

    Returns
    =======
    gene_pairs_df : pd.DataFrame
        One row per ordered pair of distinct genes, with the columns
        ``SVG_cluster``, ``gene_1``, ``gene_2``, ``colocalization_score``
        (fraction of selected spots where both genes are hotspots),
        ``exclusive_score`` (fraction where ``gene_1`` is a hotspot and
        ``gene_2`` is not), and the GMM-derived ``colocalization_degree`` and
        ``exclusive_degree`` (``"low"``/``"middle"``/``"high"``).

    Raises
    ======
    TypeError
        If ``center_spots`` is neither list-like nor convertible to ``int``.
    KeyError
        If ``use_neighbor`` is set and the cluster is not in ``spot_type``.
    """
    try:
        center_spots = list(center_spots)
    except TypeError:
        try:
            center_spots = int(center_spots)
        except TypeError as err:
            raise TypeError(
                "center_spots must be a number or a list-like object, "
                f"not '{type(center_spots)}'") from err

    if isinstance(center_spots, list):
        selected_spots = center_spots
        # Long spot lists are abbreviated in the SVG_cluster label column.
        if len(center_spots) <= 10:
            center_spots = ", ".join(map(str, center_spots))
        else:
            center_spots = ", ".join(map(str, center_spots[:3])) + "..."
    elif use_neighbor:
        clusters = _neighbor_clusters(center_spots, spot_type, coordinate_df)
        selected_spots = [
            spot for cluster in clusters
            for spot in spot_type[spot_type["type_1"] == cluster].index
        ]
        # Respect explicitly requested genes, as the non-neighbor branch does.
        if selected_genes is None:
            selected_genes = [
                gene for cluster in clusters
                for gene in svg_cluster[svg_cluster == cluster].index
            ]
    else:
        selected_spots = spot_type[spot_type["type_1"] == center_spots].index
        if selected_genes is None:
            selected_genes = svg_cluster[svg_cluster == center_spots].index

    # Spots/genes missing from hotspot_df count as "not a hotspot". Leaving
    # them NaN would poison the matrix products below, and the final fillna
    # would then zero every affected score. Casting to float also keeps ``@``
    # a real count for boolean inputs.
    count_sub = hotspot_df.reindex(
        index=selected_spots,
        columns=selected_genes,
        fill_value=0,
    ).astype(float)
    n_spots = len(selected_spots)

    # con_matrix[a, b]: number of spots where both a and b are hotspots.
    con_matrix = count_sub.T @ count_sub
    con_matrix = con_matrix - np.diag(np.diag(con_matrix.to_numpy()))
    # men_matrix[a, b]: number of spots where b is a hotspot and a is not.
    men_matrix = (1 - count_sub.T) @ count_sub

    con_ratio = (con_matrix / n_spots).to_numpy()
    men_ratio = (men_matrix / n_spots).to_numpy()
    genes = list(count_sub.columns)

    # Build all rows at once instead of growing a DataFrame with pd.concat
    # per pair (quadratic, and it leaves the score columns as object dtype).
    records = [
        {
            "SVG_cluster": center_spots,
            "gene_1": gene_1,
            "gene_2": gene_2,
            "colocalization_score": con_ratio[row, col],
            "exclusive_score": men_ratio[row, col],
        }
        for col, gene_1 in enumerate(genes)
        for row, gene_2 in enumerate(genes)
        if gene_1 != gene_2
    ]
    gene_pairs_df = pd.DataFrame(records, columns=[
        "SVG_cluster", "gene_1", "gene_2", "colocalization_score",
        "exclusive_score"
    ]).fillna(0)

    # NOTE(review): the 3-component mixture presumably needs at least three
    # gene pairs; scikit-learn raises otherwise.
    gene_pairs_df["colocalization_degree"] = _rank_by_gmm(
        gene_pairs_df["colocalization_score"])
    gene_pairs_df["exclusive_degree"] = _rank_by_gmm(
        gene_pairs_df["exclusive_score"])

    return gene_pairs_df


def find_combinations(
    dataset: STDataset,
    center_spots: Union[int, list],
    selected_genes: Optional[list] = None,
    use_neighbor: bool = True,
) -> pd.DataFrame:
    """
    Find gene pairs in certain SVG cluster.

    Thin wrapper that pulls ``hotspot_df``, ``coordinate_df``, ``spot_type``
    and ``svg_cluster`` from ``dataset`` and forwards them to
    ``_find_combinations``.

    Parameters
    ==========
    dataset : STDataset
        A STDataset with all steps finished.

    center_spots : int or list
        If a ``int`` is given, find gene pairs for this SVG cluster. If a
        ``list`` of spots is given, find gene pairs within those spots.
        ``selected_genes`` should be given if ``center_spots`` is a ``list``.

    selected_genes : list or None, default None
        If a ``list`` of genes is given, find gene pairs within given genes.

    use_neighbor : bool, default True
        Whether to find gene pairs with neighbor SVG clusters.

    Returns
    =======
    gene_pairs_df : pd.DataFrame
        A pd.DataFrame for gene pairs in ``center_spots`` with weights.
    """
    return _find_combinations(
        dataset.hotspot_df,
        dataset.coordinate_df,
        dataset.spot_type,
        dataset.svg_cluster,
        center_spots,
        selected_genes,
        use_neighbor,
    )