from __future__ import annotations
from functools import partial
from multiprocessing import Pool, cpu_count
from typing import Tuple, Union
import numpy as np
import pandas as pd
from libpysal.weights import W as libpysal_W
from .utils import pysal_to_pandas
def _hotspot_AI(
    hotspot_series: pd.Series,
    weight_df: pd.DataFrame,
    knn: Union[pd.DataFrame, libpysal_W],
) -> Tuple[pd.Series, pd.Series]:
    """
    Calculate Spatial Transcriptomics AI value for a single gene.

    For every hotspot (non-zero entry of ``hotspot_series``) the local Di
    value is the summed weight of its KNN neighbors that are hotspots
    themselves, divided by the summed weight of all of its KNN neighbors.
    The AI value is the mean of Di over all hotspots.

    Parameters
    ==========
    hotspot_series : pd.Series
        Hotspot values of one gene, indexed by spot. Non-zero entries are
        treated as hotspots and the Series name is used as the gene name.
        A sparse-backed Series is converted to dense first.

    weight_df : pd.DataFrame
        Weight used by AI. In default, svgbit uses local Moran's I p value
        as weight. Only the column named after the gene is read, and it is
        looked up by spot label.

    knn : pd.DataFrame or libpysal.weights.W
        KNN network used for neighbor identification. A value of 1 in the
        column of a spot marks the row's spot as its neighbor. Rows and
        columns are relabelled positionally with ``hotspot_series.index``.
        A DataFrame passed in is relabelled on a shallow copy, so the
        caller's labels are left untouched.

    Returns
    =======
    ai_series : pd.Series
        A one-element Series named "AI", indexed by the gene name.

    di_series : pd.Series
        A Series for local Di value, indexed by spot and named after the
        gene. Spots that are not hotspots get 0.
    """
    try:
        hotspot_series = hotspot_series.sparse.to_dense()
    except AttributeError:
        # Not sparse-backed: the ``.sparse`` accessor is unavailable.
        pass

    gene = hotspot_series.name
    spots = hotspot_series.index
    hotspots = hotspot_series[hotspot_series != 0].index.tolist()
    n_hotspots = len(hotspots)

    # Return before touching ``knn`` or ``weight_df``: neither is needed
    # when the gene has no hotspot at all.
    if n_hotspots == 0:
        ai_series = pd.Series([0], index=[gene], name="AI")
        di_series = pd.Series(
            [0] * len(hotspot_series),
            index=spots,
            name=gene,
        )
        return ai_series, di_series

    if isinstance(knn, libpysal_W):
        knn = pysal_to_pandas(knn)
    else:
        # Shallow copy shares the data but owns its axes, so relabelling
        # below does not leak into the caller's DataFrame.
        knn = knn.copy(deep=False)
    knn.index = spots
    knn.columns = spots

    # Only this gene's weights are used; densify that single column
    # instead of the whole frame.
    gene_weights = weight_df[gene]
    try:
        gene_weights = gene_weights.sparse.to_dense()
    except AttributeError:
        pass

    hotspot_set = set(hotspots)
    local_di = {}
    for spot in hotspots:
        neighbor_flags = knn[spot]
        neighbors = neighbor_flags[neighbor_flags == 1].index.tolist()
        # Keep neighbor order (instead of a set intersection) so the
        # floating point summation order is deterministic.
        hotspot_neighbors = [s for s in neighbors if s in hotspot_set]
        local_di[spot] = (
            gene_weights[hotspot_neighbors] / gene_weights[neighbors].sum()
        ).sum()

    di_mean = np.array(list(local_di.values())).sum() / n_hotspots
    ai_series = pd.Series([di_mean], index=[gene], name="AI")
    di_series = pd.Series(
        local_di,
        name=gene,
    ).reindex(index=spots).fillna(0)
    return ai_series, di_series
def hotspot_AI(
    hotspot_df: pd.DataFrame,
    weight_df: pd.DataFrame,
    knn: Union[pd.DataFrame, libpysal_W],
    cores: int = cpu_count(),
) -> Tuple[pd.Series, pd.DataFrame]:
    """
    Calculate Spatial Transcriptomics AI value for all genes.

    Every column of ``hotspot_df`` is handed to ``_hotspot_AI`` in a pool
    of worker processes and the per-gene results are concatenated.

    Parameters
    ==========
    hotspot_df : pd.DataFrame
        A hotspot DataFrame generated by svgbit. Each column is processed
        as one gene; the index is used to label ``knn``.

    weight_df : pd.DataFrame
        Weight used by AI. In default, svgbit uses local Moran's I p value
        as weight.

    knn : pd.DataFrame or libpysal.weights.W
        KNN network used for neighbor identification. Rows and columns are
        relabelled positionally with ``hotspot_df.index``. A DataFrame
        passed in is relabelled on a shallow copy, so the caller's labels
        are left untouched.

    cores : int
        Number of worker processes to run svgbit. Use all available cpus
        by default.

    Returns
    =======
    ai_series : pd.Series
        A Series for AI value, one entry per gene.

    di_df : pd.DataFrame
        A DataFrame for local Di value, one column per gene. Both results
        are empty when ``hotspot_df`` has no columns.
    """
    # Without gene columns there is nothing to map, and pd.concat rejects
    # an empty list; answer directly instead of spawning a pool.
    if hotspot_df.shape[1] == 0:
        return (
            pd.Series([], index=[], name="AI", dtype=float),
            pd.DataFrame(index=hotspot_df.index),
        )

    if isinstance(knn, libpysal_W):
        knn = pysal_to_pandas(knn)
    else:
        # Shallow copy shares the data but owns its axes, so relabelling
        # below does not leak into the caller's DataFrame.
        knn = knn.copy(deep=False)
    knn.index = hotspot_df.index
    knn.columns = hotspot_df.index

    per_gene = partial(
        _hotspot_AI,
        weight_df=weight_df,
        knn=knn,
    )
    gene_columns = [hotspot_df[gene] for gene in hotspot_df.columns]

    # The context manager tears the workers down if ``map`` raises; on
    # success the pool is still shut down gracefully with close/join.
    with Pool(processes=cores) as pool:
        results = pool.map(per_gene, gene_columns)
        pool.close()
        pool.join()

    ai_series = pd.concat([ai for ai, _ in results], axis=0)
    di_df = pd.concat([di for _, di in results], axis=1)
    return ai_series, di_df
# This module is meant to be imported; running it as a script does nothing.
if __name__ == "__main__":
    pass