artlib.supervised.SimpleARTMAP
Simple ARTMAP [13].
Classes
SimpleARTMAP for Classification.
Module Contents
- class artlib.supervised.SimpleARTMAP.SimpleARTMAP(module_a: artlib.common.BaseART.BaseART)
Bases: artlib.common.BaseARTMAP.BaseARTMAP
SimpleARTMAP for Classification.
This module implements SimpleARTMAP as first published in: [13].
SimpleARTMAP is a special case of ARTMAP specifically for classification. It allows the clustering of data samples while enforcing a many-to-one mapping from sample clusters to labels. It accepts an instantiated BaseART module and dynamically adapts the vigilance function to prevent resonance when the many-to-one mapping is violated. This enables SimpleARTMAP to identify discrete clusters belonging to each category label.
- module_a
- match_reset_func(i: numpy.ndarray, w: numpy.ndarray, cluster_a, params: dict, extra: dict, cache: dict | None = None) bool
Permits external factors to influence cluster creation.
- Parameters:
- Returns:
True if the match is permitted, False otherwise.
- Return type:
bool
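The check that match_reset_func performs can be pictured as a lookup in the cluster-to-label map. The following is a minimal sketch, not the library's implementation. The function name `match_reset_sketch`, the `cluster_to_label` dictionary, and the `"cluster_b"` key in `extra` are all assumptions made for illustration.

```python
def match_reset_sketch(cluster_a, extra, cluster_to_label):
    """Return True if cluster_a may resonate with the label in extra.

    Assumptions (not confirmed by the documentation above): the candidate
    label arrives in extra["cluster_b"], and cluster_to_label is a plain
    dict holding the many-to-one map learned so far.
    """
    label = extra["cluster_b"]
    # An unmapped cluster is free to take any label.
    if cluster_a not in cluster_to_label:
        return True
    # A mapped cluster may only match samples carrying its own label.
    return cluster_to_label[cluster_a] == label


# Clusters 0 and 1 both map to label 1 (many-to-one); cluster 2 maps to label 0.
learned = {0: 1, 1: 1, 2: 0}
same_label = match_reset_sketch(0, {"cluster_b": 1}, learned)
conflict = match_reset_sketch(0, {"cluster_b": 0}, learned)
new_cluster = match_reset_sketch(5, {"cluster_b": 0}, learned)
```

Here `same_label` and `new_cluster` are permitted, while `conflict` is rejected because cluster 0 is already bound to a different label.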
- validate_data(X: numpy.ndarray, y: numpy.ndarray) tuple[numpy.ndarray, numpy.ndarray]
Validate data prior to clustering.
- Parameters:
X (np.ndarray) – Data set A.
y (np.ndarray) – Data set B.
- Returns:
The validated datasets X and y.
- Return type:
tuple[np.ndarray, np.ndarray]
- prepare_data(X: numpy.ndarray, y: numpy.ndarray | None = None) numpy.ndarray | Tuple[numpy.ndarray, numpy.ndarray]
Prepare data for clustering.
- Parameters:
X (np.ndarray) – Data set.
y (Optional[np.ndarray]) – Data set B. Not used in SimpleARTMAP.
- Returns:
Prepared data.
- Return type:
np.ndarray
- restore_data(X: numpy.ndarray, y: numpy.ndarray | None = None) numpy.ndarray | Tuple[numpy.ndarray, numpy.ndarray]
Restore data to state prior to preparation.
- Parameters:
X (np.ndarray) – Data set.
- Returns:
Restored data.
- Return type:
np.ndarray
- step_fit(x: numpy.ndarray, c_b: int, match_tracking: Literal['MT+', 'MT-', 'MT0', 'MT1', 'MT~'] = 'MT+', epsilon: float = 1e-10) int
Fit the model to a single sample.
- fit(X: numpy.ndarray, y: numpy.ndarray, max_iter=1, match_tracking: Literal['MT+', 'MT-', 'MT0', 'MT1', 'MT~'] = 'MT+', epsilon: float = 1e-10, verbose: bool = False, leave_progress_bar: bool = True)
Fit the model to the data.
- Parameters:
X (np.ndarray) – Data set A.
y (np.ndarray) – Data set B.
max_iter (int, default=1) – Number of iterations to fit the model on the same data set.
match_tracking (Literal, default="MT+") – Method to reset the match.
epsilon (float, default=1e-10) – Small value to adjust the vigilance.
verbose (bool, default=False) – If True, displays a progress bar during training.
leave_progress_bar (bool, default=True) – If True, leaves the progress bar displayed after fitting completes. Only used when verbose=True.
- Returns:
self – The fitted model.
- Return type:
SimpleARTMAP
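How match_tracking and epsilon interact during fitting can be illustrated with a small search loop. This is a hedged sketch of the "MT+" idea only, assuming that a label conflict raises the vigilance to just above the conflicting cluster's match value. The function `search_with_match_tracking` and its arguments are hypothetical and do not mirror the library's internals; the other match-tracking modes are not covered.

```python
def search_with_match_tracking(candidates, label, cluster_to_label, rho, epsilon=1e-10):
    """Pick a cluster for one labelled sample, MT+ style (illustrative only).

    candidates: list of (cluster_id, match_value) pairs, assumed to be
    sorted by descending activation already.
    Returns the chosen cluster id, or None when no existing cluster is
    acceptable and a new one would have to be created.
    """
    for cluster_id, match in candidates:
        if match < rho:
            continue  # fails the (possibly raised) vigilance test
        if cluster_to_label.get(cluster_id, label) == label:
            return cluster_id  # vigilance met and label agrees
        # Label conflict: raise vigilance just above this match value so
        # that only strictly better-matching clusters remain eligible.
        rho = match + epsilon
    return None


cluster_to_label = {0: "a", 1: "b", 2: "b"}
candidates = [(0, 0.8), (1, 0.7), (2, 0.9)]
chosen = search_with_match_tracking(candidates, "b", cluster_to_label, rho=0.5)
unmatched = search_with_match_tracking(candidates, "c", cluster_to_label, rho=0.5)
```

In this toy run, cluster 0 conflicts with label "b", which lifts the vigilance above 0.8. Cluster 1 then fails the vigilance test and cluster 2 is selected. For label "c" every cluster conflicts, so the sketch returns None.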
- fit_predict(X: numpy.ndarray, y: numpy.ndarray, max_iter=1, match_tracking: Literal['MT+', 'MT-', 'MT0', 'MT1', 'MT~'] = 'MT+', epsilon: float = 1e-10, verbose: bool = False, leave_progress_bar: bool = True)
Fit the model to the data and return the labels. This method is defined explicitly because the implementation inherited from ClusterMixin could otherwise cause issues.
- Parameters:
X (np.ndarray) – Data set A.
y (np.ndarray) – Data set B.
max_iter (int, default=1) – Number of iterations to fit the model on the same data set.
match_tracking (Literal, default="MT+") – Method to reset the match.
epsilon (float, default=1e-10) – Small value to adjust the vigilance.
verbose (bool, default=False) – If True, displays a progress bar during training.
leave_progress_bar (bool, default=True) – If True, leaves the progress bar displayed after fitting completes. Only used when verbose=True.
- Returns:
The labels (same as y).
- Return type:
np.ndarray
- fit_gif(X: numpy.ndarray, y: numpy.ndarray, match_tracking: Literal['MT+', 'MT-', 'MT0', 'MT1', 'MT~'] = 'MT+', epsilon: float = 1e-10, verbose: bool = False, leave_progress_bar: bool = True, ax: matplotlib.axes.Axes | None = None, filename: str | None = None, fps: int = 5, final_hold_secs: float = 0.0, colors: artlib.common.utils.IndexableOrKeyable | None = None, n_class_estimate: int | None = None, max_iter: int = 1, **kwargs)
Fit the model while recording the learning process as an animated GIF.
The routine iterates over the training data, calling step_fit() for each sample, and captures intermediate plots by repeatedly invoking visualize(). All frames are written to a GIF file (via matplotlib.animation.PillowWriter), allowing an intuitive, frame-by-frame view of how clusters form and adjust over time.
- Parameters:
X (np.ndarray) – Independent-channel samples (side A), shape (n_samples, n_features).
y (np.ndarray) – Target labels (side B), shape (n_samples,).
match_tracking ({"MT+", "MT-", "MT0", "MT1", "MT~"}, default="MT+") – Strategy used by the ART module to reset its match criterion.
epsilon (float, default=1e-10) – Small positive constant added to the vigilance when match_tracking triggers a reset.
verbose (bool, default=False) – If True, displays a tqdm progress bar for each epoch.
leave_progress_bar (bool, default=True) – If True, leaves the progress bar visible after completion (only relevant when verbose is True).
ax (matplotlib.axes.Axes, optional) – Existing axes on which to draw frames. If None, a new figure and axes are created.
filename (str, optional) – Output path for the GIF. Defaults to "fit_gif_supervised_<ClassName>.gif" if None.
fps (int, default=5) – Frames per second in the resulting GIF.
final_hold_secs (float, default=0.0) – Extra seconds to hold the final frame (duplicates the last plot ceil(final_hold_secs * fps) times).
colors (array-like, optional) – Sequence of colors to use for each class when plotting. If None, a rainbow colormap is generated.
n_class_estimate (int, optional) – Expected number of distinct classes. Only used when colors is None to size the autogenerated colormap.
max_iter (int, default=1) – Number of complete passes over (X, y).
**kwargs – Additional keyword arguments forwarded to visualize() (e.g., marker_size, linewidth).
- Returns:
self – The fitted estimator (identical object, returned for chaining).
- Return type:
SimpleARTMAP
Notes
Generates a GIF file as a side effect. The estimator itself is updated exactly as in fit(); only plotting calls and file I/O are added.
For reproducible colors across different runs, supply an explicit colors array rather than relying on the rainbow colormap.
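The arithmetic behind final_hold_secs and the default filename pattern can be worked through in a few lines. This is a sketch under stated assumptions: `hold_frame_count` and `default_gif_name` are hypothetical helpers, and `<ClassName>` is assumed to be the estimator's class name. Only the ceil(final_hold_secs * fps) formula and the filename template come from the parameter descriptions above.

```python
import math


def hold_frame_count(final_hold_secs, fps):
    """Number of times the last frame is duplicated: ceil(final_hold_secs * fps)."""
    return math.ceil(final_hold_secs * fps)


def default_gif_name(estimator):
    """Fill the documented template, assuming <ClassName> is the estimator's class name."""
    return f"fit_gif_supervised_{type(estimator).__name__}.gif"


class DemoEstimator:
    """Stand-in object used only to show the filename pattern."""


no_hold = hold_frame_count(0.0, 5)       # default: no extra frames
short_hold = hold_frame_count(1.5, 5)    # 7.5 rounds up to 8 frames
name = default_gif_name(DemoEstimator())
```

With the defaults (final_hold_secs=0.0, fps=5) no frames are duplicated. Holding for 1.5 seconds at 5 frames per second yields 8 extra frames because the product 7.5 is rounded up.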
- partial_fit(X: numpy.ndarray, y: numpy.ndarray, match_tracking: Literal['MT+', 'MT-', 'MT0', 'MT1', 'MT~'] = 'MT+', epsilon: float = 1e-10)
Partial fit the model to the data.
- Parameters:
X (np.ndarray) – Data set A.
y (np.ndarray) – Data set B.
match_tracking (Literal, default="MT+") – Method to reset the match.
epsilon (float, default=1e-10) – Small value to adjust the vigilance.
- Returns:
self – The partially fitted model.
- Return type:
SimpleARTMAP
- property labels_a: numpy.ndarray
Get labels from side A (module A).
- Returns:
Labels from module A.
- Return type:
np.ndarray
- property labels_b: numpy.ndarray
Get labels from side B.
- Returns:
Labels from side B.
- Return type:
np.ndarray
- property labels_ab: Dict[str, numpy.ndarray]
Get labels from both A-side and B-side.
- Returns:
A dictionary with keys “A” and “B” containing labels from sides A and B, respectively.
- Return type:
Dict[str, np.ndarray]
- property n_clusters: int
Get the number of clusters in side A.
- Returns:
Number of clusters.
- Return type:
int
- property n_clusters_a: int
Get the number of clusters in side A.
- Returns:
Number of clusters in side A.
- Return type:
int
- property n_clusters_b: int
Get the number of clusters in side B.
- Returns:
Number of clusters in side B.
- Return type:
int
- predict(X: numpy.ndarray, clip: bool = False) numpy.ndarray
Predict labels for the data.
- Parameters:
X (np.ndarray) – Data set A.
clip (bool, default=False) – If True, clips the input values to the limits of the previously seen data.
- Returns:
B labels for the data.
- Return type:
np.ndarray
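The effect of the clip option can be pictured as a per-feature clamp. The sketch below assumes that the "previously seen data limits" are per-feature minimum and maximum values. The helper `clip_to_bounds` and the `lower` and `upper` arguments are hypothetical, and plain lists are used instead of NumPy arrays so the example has no dependencies.

```python
def clip_to_bounds(X, lower, upper):
    """Clamp each feature of each sample into [lower[j], upper[j]]."""
    return [
        [min(max(value, lo), hi) for value, lo, hi in zip(row, lower, upper)]
        for row in X
    ]


# Assumed limits observed during training, one pair per feature.
lower = [0.0, 0.0]
upper = [1.0, 1.0]

# The first sample falls below the range in feature 0;
# the second exceeds it in both features.
X_new = [[-1.0, 0.5], [2.0, 3.0]]
X_clipped = clip_to_bounds(X_new, lower, upper)
```

Values already inside the limits pass through unchanged, so clipping only affects samples that fall outside the range seen during training.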
- predict_ab(X: numpy.ndarray, clip: bool = False) tuple[numpy.ndarray, numpy.ndarray]
Predict labels for the data, both A-side and B-side.
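The relationship between the A-side and B-side predictions can be sketched as a lookup through the many-to-one map. This is an illustration rather than the library's code: `predict_ab_sketch` and `cluster_to_label` are hypothetical names, and the (A, B) ordering of the returned tuple is an assumption based on the method name.

```python
def predict_ab_sketch(labels_a, cluster_to_label):
    """Return (A-side cluster labels, B-side class labels) for each sample.

    labels_a: cluster ids already assigned by the A-side module.
    cluster_to_label: many-to-one map from cluster id to class label.
    """
    labels_b = [cluster_to_label[a] for a in labels_a]
    return list(labels_a), labels_b


# Three clusters collapse onto two class labels.
cluster_to_label = {0: 1, 1: 1, 2: 0}
a_side, b_side = predict_ab_sketch([2, 0, 1, 0], cluster_to_label)
```

Because several clusters may share a label, the A-side output is at least as fine-grained as the B-side output.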
- plot_cluster_bounds(ax: matplotlib.axes.Axes, colors: artlib.common.utils.IndexableOrKeyable, linewidth: int = 1)
Visualize the cluster boundaries.
- Parameters:
ax (Axes) – Figure axes.
colors (IndexableOrKeyable) – Colors to use for each cluster.
linewidth (int, default=1) – Width of boundary lines.