Skip to content

Adjacency Simplex module

A class to process a GeoDataFrame, filter and sort it based on a variable, compute adjacency relationships, and form a simplicial complex.

Source code in spatial_tda/adjacency_simplex.py
class AdjacencySimplex:
    """
    A class to process a GeoDataFrame, filter and sort it based on a variable,
    compute adjacency relationships, and form a simplicial complex.

    Intended call order (each step raises ValueError if its prerequisite is missing):
    filter_sort_gdf() -> calculate_adjacent_countries() -> form_simplicial_complex()
    -> compute_persistence() and/or plot_simplicial_complex().
    """

    def __init__(self, geo_dataframe, variable, threshold=None, filter_method='up'):
        """
        Initialize with a GeoDataFrame.

        Parameters:
        - geo_dataframe: GeoDataFrame containing a 'geometry' column and attribute data.
        - variable: Column name used for filtering and sorting.
        - threshold: Optional tuple (min, max). Only rows whose value (after the
          inversion applied by filter_method='down') lies in [min, max] are kept.
        - filter_method: 'up' sorts by the variable in ascending order. 'down'
          replaces each value with (max - value) and then sorts ascending, so high
          original values come first.
        """
        self.gdf = geo_dataframe
        self.variable = variable
        self.filter_method = filter_method
        self.threshold = threshold
        # Results of the pipeline steps; None until the corresponding method runs.
        self.filtered_df = None
        self.adjacent_counties_dict = None
        self.merged_df = None
        self.simplicial_complex = None

    def filter_sort_gdf(self, return_original=False, return_filtered=False):
        """
        Filter and sort the GeoDataFrame based on the specified variable and method.

        Rows are sorted by the (possibly inverted) variable and numbered with a
        'sortedID' column *before* threshold filtering, so the IDs are ranks in the
        full ordering. The filtered result is stored in self.filtered_df.

        Parameters:
        - return_original: If True, return the sorted, unfiltered frame.
        - return_filtered: If True, return the filtered GeoDataFrame.

        Returns:
        - (original, filtered) if both flags are set, otherwise the single
          requested frame, or None if neither flag is set.

        Raises:
        - ValueError: if filter_method is not 'up' or 'down'.
        """
        if self.filter_method not in ('up', 'down'):
            raise ValueError("Invalid filter method. Use 'up' or 'down'.")

        gdf = self.gdf.copy()

        if self.filter_method == 'down':
            # Invert the scale so the largest original value maps to 0 and an
            # ascending sort visits high original values first. The inverted values
            # are always >= 0, whatever the sign of the inputs.
            max_value = gdf[self.variable].max()
            gdf[self.variable] = max_value - gdf[self.variable]

        # A stable sort keeps tied values in input order, so sortedID is reproducible.
        gdf = gdf.sort_values(by=self.variable, ascending=True, kind='stable')

        # Must be assigned before filtering so IDs refer to the full ordering.
        gdf['sortedID'] = range(len(gdf))

        filtered_df = gdf.copy()

        # `is not None` also works for array-like thresholds, whose truth value is ambiguous.
        if self.threshold is not None:
            lower, upper = self.threshold[0], self.threshold[1]
            values = filtered_df[self.variable]
            filtered_df = filtered_df[(values >= lower) & (values <= upper)]

        filtered_df = gpd.GeoDataFrame(filtered_df, geometry='geometry')

        # Default to EPSG:4326 only when no CRS is known. Overwriting an existing CRS
        # would relabel the coordinates without reprojecting them.
        if filtered_df.crs is None:
            filtered_df = filtered_df.set_crs("EPSG:4326")

        self.filtered_df = filtered_df

        if return_original and return_filtered:
            return gdf, filtered_df
        if return_filtered:
            return filtered_df
        if return_original:
            return gdf
        return None

    def calculate_adjacent_countries(self):
        """
        Compute adjacency relationships between geographic entities.

        Two entities are adjacent when their geometries intersect (predicate
        'intersects'). Results are stored on the instance:

        - self.adjacent_counties_dict: {sortedID: [sortedIDs of intersecting
          entities]}, with keys in ascending sortedID order. Entities with no
          neighbours are kept with an empty list, so they still become vertices.
        - self.merged_df: GeoDataFrame with 'county' and 'adjacent' columns joined
          to the matching rows of self.filtered_df.

        Raises:
        - ValueError: if filter_sort_gdf() has not been run.
        """
        filtered = getattr(self, 'filtered_df', None)
        # Check None first so the missing-prerequisite case fails with a clear error.
        if filtered is None or not isinstance(filtered, gpd.GeoDataFrame):
            raise ValueError("Run filter_sort_gdf() before calling this method.")

        # Self-join: every intersecting pair, including each geometry with itself.
        pairs = gpd.sjoin(filtered, filtered, predicate='intersects', how='inner')
        pairs = pairs[pairs['sortedID_left'] != pairs['sortedID_right']]

        neighbours = pairs.groupby('sortedID_left')['sortedID_right'].apply(list)

        # Iterate over every filtered entity (not just those with neighbours), so
        # isolated regions are not silently dropped from the complex.
        adjacent_dict = {
            sorted_id: list(neighbours.get(sorted_id, []))
            for sorted_id in sorted(filtered['sortedID'])
        }

        adjacent_entities = pd.DataFrame({
            'county': pd.Series(list(adjacent_dict.keys()), dtype=filtered['sortedID'].dtype),
            'adjacent': pd.Series(list(adjacent_dict.values()), dtype=object),
        })

        merged_df = pd.merge(adjacent_entities, filtered, left_on='county', right_on='sortedID', how='left')
        # Keep the same CRS as the filtered data instead of forcing a label.
        merged_df = gpd.GeoDataFrame(merged_df, geometry='geometry', crs=filtered.crs)

        self.adjacent_counties_dict = adjacent_dict
        self.merged_df = merged_df

    def form_simplicial_complex(self, return_simplicial_complex=False, max_dimension=3):
        """
        Construct a simplicial complex using adjacency relationships.

        Delegates to invr.incremental_vr with the adjacency dict, using its keys
        (ascending sortedID) as the vertex order. The result is stored in
        self.simplicial_complex.

        Parameters:
        - return_simplicial_complex: If True, also return the complex.
        - max_dimension: Passed unchanged to invr.incremental_vr (default 3, the
          previously hard-coded value). compute_persistence() and
          plot_simplicial_complex() only use simplices with at most 3 vertices.

        Raises:
        - ValueError: if calculate_adjacent_countries() has not been run.
        """
        adjacency = getattr(self, 'adjacent_counties_dict', None)
        # __init__ always defines this attribute (as None), so hasattr() could never
        # detect the missing step; check the value instead.
        if adjacency is None:
            raise ValueError("Run calculate_adjacent_countries() before calling this method.")

        simplicial_complex = invr.incremental_vr([], adjacency, max_dimension, list(adjacency.keys()))

        self.simplicial_complex = simplicial_complex

        if return_simplicial_complex:
            return simplicial_complex
        return None

    def compute_persistence(self, summaries=None):
        """
        Compute persistence diagrams for the simplicial complex and return selected topological summaries.

        Vertices enter at filtration 0.0. An edge or triangle enters at the largest
        value of self.variable (from self.filtered_df) among its vertices. Only
        0-dimensional intervals are summarised. Infinite deaths are replaced by the
        maximum of self.variable in self.filtered_df.

        :param summaries: List of summary names to return (e.g., ["H0", "TL", "AL"]). If None, return all.
        :return: Dictionary with the requested summaries:
            "H0"  - number of 0-dimensional intervals,
            "TL"  - total interval length (death - birth),
            "AL"  - average interval length (0 if there are no intervals),
            "TML" - total interval midpoint ((birth + death) / 2),
            "AML" - average interval midpoint (0 if there are no intervals).
        :raises ValueError: if form_simplicial_complex() has not been run.
        """
        if getattr(self, 'simplicial_complex', None) is None:
            raise ValueError("Run form_simplicial_complex() before calling this method.")

        # sortedID -> variable value, built once instead of masking the frame per simplex.
        vertex_value = dict(zip(self.filtered_df['sortedID'], self.filtered_df[self.variable]))

        st = gudhi.SimplexTree()
        st.set_dimension(2)

        # Insert vertices first, then edges, then triangles.
        for size in (1, 2, 3):
            for simplex in self.simplicial_complex:
                if len(simplex) != size:
                    continue
                if size == 1:
                    # NOTE(review): vertices enter at 0.0 rather than at their own value; confirm this is intended.
                    st.insert([simplex[0]], filtration=0.0)
                else:
                    # The max over all vertices equals the old "last vertex" value when
                    # vertices are listed in ascending sortedID order, and stays correct otherwise.
                    st.insert(list(simplex), filtration=max(vertex_value[v] for v in simplex))

        # compute_persistence() is all persistence_intervals_in_dimension() needs; the
        # former extra st.persistence() call only recomputed the same result.
        st.compute_persistence()

        # reshape(-1, 2) makes an empty result safe to column-index.
        intervals_dim0 = np.asarray(st.persistence_intervals_in_dimension(0), dtype=float).reshape(-1, 2)

        # Replace infinity with the max variable value.
        max_value = self.filtered_df[self.variable].max()
        deaths = intervals_dim0[:, 1]
        deaths[np.isinf(deaths)] = max_value

        lengths = intervals_dim0[:, 1] - intervals_dim0[:, 0]
        midpoints = (intervals_dim0[:, 1] + intervals_dim0[:, 0]) / 2

        H0_data_points = len(intervals_dim0)
        TL = sum(lengths)
        TML = sum(midpoints)
        AL = TL / H0_data_points if H0_data_points > 0 else 0
        AML = TML / H0_data_points if H0_data_points > 0 else 0

        results = {
            "H0": H0_data_points,
            "TL": TL,
            "AL": AL,
            "TML": TML,
            "AML": AML,
        }

        # Return only the requested summaries; unknown names are ignored.
        if summaries:
            return {key: results[key] for key in summaries if key in results}
        return results

    @staticmethod
    def fig2img(fig):
        """
        Convert a Matplotlib figure to a PIL Image.

        Parameters:
        - fig: A Matplotlib figure.

        Returns:
        - A PIL Image read from an in-memory copy of the saved figure.
        """
        buf = io.BytesIO()
        fig.savefig(buf, bbox_inches='tight', pad_inches=0)
        buf.seek(0)
        return Image.open(buf)

    def plot_simplicial_complex(self, save_dir=None):
        """
        Plot the simplicial complex and create a GIF animation showing its incremental construction.

        The base map (self.gdf) and its value labels are drawn once. Then each
        simplex of self.simplicial_complex is added in order to the same axes, and
        a frame is captured after each one: vertices as orange region fills, edges
        as red centroid-to-centroid lines, and triangles as green patches.

        Parameters:
        - save_dir: Directory path to save the GIF. If None, saves in the current directory.

        Raises:
        - ValueError: if form_simplicial_complex() has not been run, or if the
          complex contains no simplices with 1 to 3 vertices.
        """
        if self.simplicial_complex is None:
            raise ValueError("Run form_simplicial_complex() before calling this method.")

        # Precompute geometries and centroids keyed by sortedID for O(1) lookups.
        geometries = {}
        city_coordinates = {}
        for _, row in self.filtered_df.iterrows():
            geom = row['geometry']
            geometries[row['sortedID']] = geom
            city_coordinates[row['sortedID']] = np.array((geom.centroid.x, geom.centroid.y))

        fig, ax = plt.subplots(figsize=(10, 10))
        frames = []
        try:
            ax.set_axis_off()

            # Base map without any filtration.
            self.gdf.plot(ax=ax, edgecolor='black', linewidth=0.3, color="white")

            # Label each region's centroid with its variable value.
            for _, row in self.gdf.iterrows():
                centroid = row['geometry'].centroid
                ax.text(centroid.x, centroid.y, f"{row[self.variable]:.2f}",
                        fontsize=7, ha='center', color="black")

            for simplex in self.simplicial_complex:
                if len(simplex) == 1:
                    geometry = geometries[simplex[0]]
                    # MultiPolygons have no .exterior; fill each part separately.
                    parts = geometry.geoms if geometry.geom_type == 'MultiPolygon' else [geometry]
                    for part in parts:
                        ax.add_patch(plt.Polygon(np.array(part.exterior.coords), closed=True,
                                                 color='orange', alpha=0.3))
                elif len(simplex) == 2:
                    ax.plot(*zip(*[city_coordinates[v] for v in simplex]), color='red', linewidth=2)
                elif len(simplex) == 3:
                    ax.add_patch(plt.Polygon([city_coordinates[v] for v in simplex],
                                             color='green', alpha=0.2))
                else:
                    continue
                frames.append(self.fig2img(fig))
        finally:
            # Close exactly once, after all frames are captured (or on error).
            plt.close(fig)

        if not frames:
            raise ValueError("The simplicial complex contains no simplices to plot.")

        gif_filename = f'adj_simplex_{self.variable}_{self.filter_method}.gif'
        if save_dir:
            gif_filename = f'{save_dir}/{gif_filename}'

        frames[0].save(gif_filename, save_all=True, append_images=frames[1:],
                       optimize=False, duration=600, loop=0)
        print(f"GIF created and saved as {gif_filename}.")

__init__(self, geo_dataframe, variable, threshold=None, filter_method='up') special

Initialize with a GeoDataFrame.

  • geo_dataframe: GeoDataFrame containing geographic and attribute data.
  • variable: Column name used for filtering and sorting.
  • threshold: Tuple (min, max) for filtering values within a range.
  • filter_method: Sorting method, either 'up' (ascending values) or 'down' (descending: values are inverted as max − value and then sorted ascending).
Source code in spatial_tda/adjacency_simplex.py
def __init__(self, geo_dataframe, variable, threshold=None, filter_method='up'):
    """
    Initialize with a GeoDataFrame.

    Parameters:
    - geo_dataframe: GeoDataFrame containing a 'geometry' column and attribute data.
    - variable: Column name used for filtering and sorting.
    - threshold: Optional tuple (min, max). Only rows whose value (after the
      inversion applied by filter_method='down') lies in [min, max] are kept.
    - filter_method: 'up' sorts by the variable in ascending order. 'down'
      replaces each value with (max - value) and then sorts ascending, so high
      original values come first.
    """
    self.gdf = geo_dataframe
    self.variable = variable
    self.filter_method = filter_method
    self.threshold = threshold
    # Results of the pipeline steps; None until the corresponding method runs.
    self.filtered_df = None
    self.adjacent_counties_dict = None
    self.merged_df = None
    self.simplicial_complex = None

calculate_adjacent_countries(self)

Compute adjacency relationships between geographic entities.

Source code in spatial_tda/adjacency_simplex.py
def calculate_adjacent_countries(self):
    """
    Compute adjacency relationships between geographic entities.

    Two entities are adjacent when their geometries intersect (predicate
    'intersects'). Results are stored on the instance:

    - self.adjacent_counties_dict: {sortedID: [sortedIDs of intersecting
      entities]}, with keys in ascending sortedID order. Entities with no
      neighbours are kept with an empty list, so they still become vertices.
    - self.merged_df: GeoDataFrame with 'county' and 'adjacent' columns joined
      to the matching rows of self.filtered_df.

    Raises:
    - ValueError: if filter_sort_gdf() has not been run.
    """
    filtered = getattr(self, 'filtered_df', None)
    # Check None first so the missing-prerequisite case fails with a clear error.
    if filtered is None or not isinstance(filtered, gpd.GeoDataFrame):
        raise ValueError("Run filter_sort_gdf() before calling this method.")

    # Self-join: every intersecting pair, including each geometry with itself.
    pairs = gpd.sjoin(filtered, filtered, predicate='intersects', how='inner')
    pairs = pairs[pairs['sortedID_left'] != pairs['sortedID_right']]

    neighbours = pairs.groupby('sortedID_left')['sortedID_right'].apply(list)

    # Iterate over every filtered entity (not just those with neighbours), so
    # isolated regions are not silently dropped from the complex.
    adjacent_dict = {
        sorted_id: list(neighbours.get(sorted_id, []))
        for sorted_id in sorted(filtered['sortedID'])
    }

    adjacent_entities = pd.DataFrame({
        'county': pd.Series(list(adjacent_dict.keys()), dtype=filtered['sortedID'].dtype),
        'adjacent': pd.Series(list(adjacent_dict.values()), dtype=object),
    })

    merged_df = pd.merge(adjacent_entities, filtered, left_on='county', right_on='sortedID', how='left')
    # Keep the same CRS as the filtered data instead of forcing a label.
    merged_df = gpd.GeoDataFrame(merged_df, geometry='geometry', crs=filtered.crs)

    self.adjacent_counties_dict = adjacent_dict
    self.merged_df = merged_df

compute_persistence(self, summaries=None)

Compute persistence diagrams for the simplicial complex and return selected topological summaries.

:param summaries: List of summary names to return (e.g., ["H0", "TL", "AL"]). If None, return all.
:return: Dictionary with the requested summaries.

Source code in spatial_tda/adjacency_simplex.py
def compute_persistence(self, summaries=None):
    """
    Compute persistence diagrams for the simplicial complex and return selected topological summaries.

    Vertices enter at filtration 0.0. An edge or triangle enters at the largest
    value of self.variable (from self.filtered_df) among its vertices. Only
    0-dimensional intervals are summarised. Infinite deaths are replaced by the
    maximum of self.variable in self.filtered_df.

    :param summaries: List of summary names to return (e.g., ["H0", "TL", "AL"]). If None, return all.
    :return: Dictionary with the requested summaries:
        "H0"  - number of 0-dimensional intervals,
        "TL"  - total interval length (death - birth),
        "AL"  - average interval length (0 if there are no intervals),
        "TML" - total interval midpoint ((birth + death) / 2),
        "AML" - average interval midpoint (0 if there are no intervals).
    :raises ValueError: if form_simplicial_complex() has not been run.
    """
    if getattr(self, 'simplicial_complex', None) is None:
        raise ValueError("Run form_simplicial_complex() before calling this method.")

    # sortedID -> variable value, built once instead of masking the frame per simplex.
    vertex_value = dict(zip(self.filtered_df['sortedID'], self.filtered_df[self.variable]))

    st = gudhi.SimplexTree()
    st.set_dimension(2)

    # Insert vertices first, then edges, then triangles.
    for size in (1, 2, 3):
        for simplex in self.simplicial_complex:
            if len(simplex) != size:
                continue
            if size == 1:
                # NOTE(review): vertices enter at 0.0 rather than at their own value; confirm this is intended.
                st.insert([simplex[0]], filtration=0.0)
            else:
                # The max over all vertices equals the old "last vertex" value when
                # vertices are listed in ascending sortedID order, and stays correct otherwise.
                st.insert(list(simplex), filtration=max(vertex_value[v] for v in simplex))

    # compute_persistence() is all persistence_intervals_in_dimension() needs; the
    # former extra st.persistence() call only recomputed the same result.
    st.compute_persistence()

    # reshape(-1, 2) makes an empty result safe to column-index.
    intervals_dim0 = np.asarray(st.persistence_intervals_in_dimension(0), dtype=float).reshape(-1, 2)

    # Replace infinity with the max variable value.
    max_value = self.filtered_df[self.variable].max()
    deaths = intervals_dim0[:, 1]
    deaths[np.isinf(deaths)] = max_value

    lengths = intervals_dim0[:, 1] - intervals_dim0[:, 0]
    midpoints = (intervals_dim0[:, 1] + intervals_dim0[:, 0]) / 2

    H0_data_points = len(intervals_dim0)
    TL = sum(lengths)
    TML = sum(midpoints)
    AL = TL / H0_data_points if H0_data_points > 0 else 0
    AML = TML / H0_data_points if H0_data_points > 0 else 0

    results = {
        "H0": H0_data_points,
        "TL": TL,
        "AL": AL,
        "TML": TML,
        "AML": AML,
    }

    # Return only the requested summaries; unknown names are ignored.
    if summaries:
        return {key: results[key] for key in summaries if key in results}
    return results

fig2img(fig) staticmethod

Convert a Matplotlib figure to a PIL Image.

  • fig: A Matplotlib figure.
  • Returns: a PIL Image.
Source code in spatial_tda/adjacency_simplex.py
@staticmethod
def fig2img(fig):
    """
    Render a Matplotlib figure into a PIL Image.

    The figure is saved, tightly cropped with no padding, into an in-memory
    byte stream, which is then rewound and opened with PIL.

    Parameters:
    - fig: A Matplotlib figure.

    Returns:
    - A PIL Image backed by the in-memory stream.
    """
    stream = io.BytesIO()
    fig.savefig(stream, bbox_inches='tight', pad_inches=0)
    stream.seek(0)
    return Image.open(stream)

filter_sort_gdf(self, return_original=False, return_filtered=False)

Filter and sort the GeoDataFrame based on the specified variable and method.

Source code in spatial_tda/adjacency_simplex.py
def filter_sort_gdf(self, return_original=False, return_filtered=False):
    """
    Filter and sort the GeoDataFrame based on the specified variable and method.

    Rows are sorted by the (possibly inverted) variable and numbered with a
    'sortedID' column *before* threshold filtering, so the IDs are ranks in the
    full ordering. The filtered result is stored in self.filtered_df.

    Parameters:
    - return_original: If True, return the sorted, unfiltered frame.
    - return_filtered: If True, return the filtered GeoDataFrame.

    Returns:
    - (original, filtered) if both flags are set, otherwise the single
      requested frame, or None if neither flag is set.

    Raises:
    - ValueError: if filter_method is not 'up' or 'down'.
    """
    if self.filter_method not in ('up', 'down'):
        raise ValueError("Invalid filter method. Use 'up' or 'down'.")

    gdf = self.gdf.copy()

    if self.filter_method == 'down':
        # Invert the scale so the largest original value maps to 0 and an
        # ascending sort visits high original values first. The inverted values
        # are always >= 0, whatever the sign of the inputs.
        max_value = gdf[self.variable].max()
        gdf[self.variable] = max_value - gdf[self.variable]

    # A stable sort keeps tied values in input order, so sortedID is reproducible.
    gdf = gdf.sort_values(by=self.variable, ascending=True, kind='stable')

    # Must be assigned before filtering so IDs refer to the full ordering.
    gdf['sortedID'] = range(len(gdf))

    filtered_df = gdf.copy()

    # `is not None` also works for array-like thresholds, whose truth value is ambiguous.
    if self.threshold is not None:
        lower, upper = self.threshold[0], self.threshold[1]
        values = filtered_df[self.variable]
        filtered_df = filtered_df[(values >= lower) & (values <= upper)]

    filtered_df = gpd.GeoDataFrame(filtered_df, geometry='geometry')

    # Default to EPSG:4326 only when no CRS is known. Overwriting an existing CRS
    # would relabel the coordinates without reprojecting them.
    if filtered_df.crs is None:
        filtered_df = filtered_df.set_crs("EPSG:4326")

    self.filtered_df = filtered_df

    if return_original and return_filtered:
        return gdf, filtered_df
    if return_filtered:
        return filtered_df
    if return_original:
        return gdf
    return None

form_simplicial_complex(self, return_simplicial_complex=False)

Construct a simplicial complex using adjacency relationships.

Source code in spatial_tda/adjacency_simplex.py
def form_simplicial_complex(self, return_simplicial_complex=False, max_dimension=3):
    """
    Construct a simplicial complex using adjacency relationships.

    Delegates to invr.incremental_vr with the adjacency dict, using its keys
    (ascending sortedID) as the vertex order. The result is stored in
    self.simplicial_complex.

    Parameters:
    - return_simplicial_complex: If True, also return the complex.
    - max_dimension: Passed unchanged to invr.incremental_vr (default 3, the
      previously hard-coded value). compute_persistence() and
      plot_simplicial_complex() only use simplices with at most 3 vertices.

    Raises:
    - ValueError: if calculate_adjacent_countries() has not been run.
    """
    adjacency = getattr(self, 'adjacent_counties_dict', None)
    # __init__ always defines this attribute (as None), so hasattr() could never
    # detect the missing step; check the value instead.
    if adjacency is None:
        raise ValueError("Run calculate_adjacent_countries() before calling this method.")

    simplicial_complex = invr.incremental_vr([], adjacency, max_dimension, list(adjacency.keys()))

    self.simplicial_complex = simplicial_complex

    if return_simplicial_complex:
        return simplicial_complex
    return None

plot_simplicial_complex(self, save_dir=None)

Plot the simplicial complex and create a GIF animation showing its incremental construction.

For each frame, the base map (self.gdf) is plotted along with labels, then all simplices up to that frame are drawn.

  • save_dir: Directory path to save the GIF. If None, saves in the current directory.
Source code in spatial_tda/adjacency_simplex.py
def plot_simplicial_complex(self, save_dir=None):
    """
    Plot the simplicial complex and create a GIF animation showing its incremental construction.

    The base map (self.gdf) and its value labels are drawn once. Then each
    simplex of self.simplicial_complex is added in order to the same axes, and
    a frame is captured after each one: vertices as orange region fills, edges
    as red centroid-to-centroid lines, and triangles as green patches.

    Parameters:
    - save_dir: Directory path to save the GIF. If None, saves in the current directory.

    Raises:
    - ValueError: if form_simplicial_complex() has not been run, or if the
      complex contains no simplices with 1 to 3 vertices.
    """
    if self.simplicial_complex is None:
        raise ValueError("Run form_simplicial_complex() before calling this method.")

    # Precompute geometries and centroids keyed by sortedID for O(1) lookups.
    geometries = {}
    city_coordinates = {}
    for _, row in self.filtered_df.iterrows():
        geom = row['geometry']
        geometries[row['sortedID']] = geom
        city_coordinates[row['sortedID']] = np.array((geom.centroid.x, geom.centroid.y))

    fig, ax = plt.subplots(figsize=(10, 10))
    frames = []
    try:
        ax.set_axis_off()

        # Base map without any filtration.
        self.gdf.plot(ax=ax, edgecolor='black', linewidth=0.3, color="white")

        # Label each region's centroid with its variable value.
        for _, row in self.gdf.iterrows():
            centroid = row['geometry'].centroid
            ax.text(centroid.x, centroid.y, f"{row[self.variable]:.2f}",
                    fontsize=7, ha='center', color="black")

        for simplex in self.simplicial_complex:
            if len(simplex) == 1:
                geometry = geometries[simplex[0]]
                # MultiPolygons have no .exterior; fill each part separately.
                parts = geometry.geoms if geometry.geom_type == 'MultiPolygon' else [geometry]
                for part in parts:
                    ax.add_patch(plt.Polygon(np.array(part.exterior.coords), closed=True,
                                             color='orange', alpha=0.3))
            elif len(simplex) == 2:
                ax.plot(*zip(*[city_coordinates[v] for v in simplex]), color='red', linewidth=2)
            elif len(simplex) == 3:
                ax.add_patch(plt.Polygon([city_coordinates[v] for v in simplex],
                                         color='green', alpha=0.2))
            else:
                continue
            frames.append(self.fig2img(fig))
    finally:
        # Close exactly once, after all frames are captured (or on error).
        plt.close(fig)

    if not frames:
        raise ValueError("The simplicial complex contains no simplices to plot.")

    gif_filename = f'adj_simplex_{self.variable}_{self.filter_method}.gif'
    if save_dir:
        gif_filename = f'{save_dir}/{gif_filename}'

    frames[0].save(gif_filename, save_all=True, append_images=frames[1:],
                   optimize=False, duration=600, loop=0)
    print(f"GIF created and saved as {gif_filename}.")