# Source code for kluster_fudge.gpu

from __future__ import annotations
import numpy as np
import numpy.typing as npt
import warnings

try:
    import torch
except ImportError:
    torch = None

from kluster_fudge.core import KModes


class KModesGPU(KModes):
    """GPU-accelerated K-Modes clustering.

    Mirrors the interface of :class:`KModes` but runs the distance
    computations and centroid updates as vectorized torch ops on the
    selected device ("cuda", "mps", or "cpu").
    """

    def __init__(
        self,
        n_clusters: int = 8,
        n_init: int = 10,
        max_iter: int = 100,
        init_method: str = "cao",
        dist_metric: str = "hamming",
        random_state: int = 42,
        device: str | None = None,
    ) -> None:
        """Configure the clusterer and pick a compute device.

        Args:
            n_clusters: Number of clusters (K).
            n_init: Number of independent initializations; best run wins.
            max_iter: Maximum iterations per initialization.
            init_method: Centroid initialization strategy (e.g. "cao").
            dist_metric: One of "hamming", "jaccard", "ng".
            random_state: Base seed; init i uses random_state + i.
            device: Torch device string; auto-detected when None.

        Raises:
            ImportError: If PyTorch is not installed.
        """
        super().__init__(
            n_clusters=n_clusters,
            n_init=n_init,
            max_iter=max_iter,
            init_method=init_method,
            dist_metric=dist_metric,
            random_state=random_state,
        )
        if torch is None:
            raise ImportError(
                "PyTorch is required for KModesGPU. Please install it with `pip install torch`."
            )
        self.device = device
        if self.device is None:
            if torch.cuda.is_available():
                self.device = "cuda"
            # guard: some torch builds ship without the mps backend module,
            # so probing torch.backends.mps directly would raise AttributeError
            elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
                self.device = "mps"
            else:
                self.device = "cpu"
                warnings.warn(
                    "No GPU detected. KModesGPU will run on CPU, which might be slower than KModes (Numba).",
                    stacklevel=2,
                )

    def fit(self, X: npt.ArrayLike) -> None:
        """
        Fit the model to the input data using GPU acceleration.

        Args:
            X: (npt.ArrayLike) Input data, array-like

        Returns:
            None
        """
        # hoisted out of the n_init loop: project-local, loop-invariant
        from kluster_fudge.init import init_centroids

        # check if X is a pandas dataframe
        if hasattr(X, "values"):
            X = X.values
            self.is_df = True
        else:
            X = np.asarray(X)

        # encode on CPU, then move the integer-encoded matrix to the device
        X_encoded_cpu = self._encode(X)
        X_gpu = torch.from_numpy(X_encoded_cpu).to(self.device)
        n_samples, n_features = X_gpu.shape

        if self.n_init < 1:
            raise ValueError(f"n_init must be at least 1, got {self.n_init}")
        if self.random_state is not None:
            torch.manual_seed(self.random_state)

        # normalize the metric once (it may be an Enum) and validate up
        # front instead of failing inside the first clustering iteration
        metric = self.dist_metric
        if hasattr(metric, "value"):
            metric = metric.value
        if metric not in ("hamming", "jaccard", "ng"):
            raise ValueError(f"Unsupported metric: {self.dist_metric}")

        # value range of the encoded data is loop-invariant: compute once
        max_val = int(X_gpu.max().item())
        if max_val < 0:
            max_val = 0
        val_offset = max_val + 1

        best_cost = float("inf")
        best_centroids = None
        best_labels = None

        for init_idx in range(self.n_init):
            current_seed = (
                self.random_state + init_idx if self.random_state is not None else None
            )
            # initialization runs on the CPU numpy array
            centroids_cpu = init_centroids(
                X_encoded_cpu,
                self.n_clusters,
                self.init_method,
                random_state=current_seed,
            )
            centroids = torch.from_numpy(centroids_cpu).to(self.device)
            labels = torch.zeros(n_samples, dtype=torch.long, device=self.device)

            for i in range(self.max_iter):
                # "ng" needs labels from a previous assignment, so its
                # first iteration falls back to hamming
                if metric == "hamming" or (metric == "ng" and i == 0):
                    dists = self._hamming(X_gpu, centroids)
                elif metric == "jaccard":
                    dists = self._jaccard(X_gpu, centroids)
                else:  # "ng", i > 0
                    dists = self._ng(X_gpu, centroids, labels, n_features)

                new_labels = torch.argmin(dists, dim=1)
                # converged when assignments are stable (skip the i == 0
                # comparison against the all-zeros placeholder labels)
                if i > 0 and torch.equal(labels, new_labels):
                    break
                labels = new_labels

                # centroid update: per-(feature, cluster) value frequencies,
                # mode taken via argmax over the value axis
                counts = self._compute_counts(
                    X_gpu, labels, self.n_clusters, n_features, val_offset
                )
                counts_reshaped = counts.view(n_features, self.n_clusters, val_offset)
                new_centroids = counts_reshaped.argmax(dim=2).t()  # (K, F)

                # empty clusters keep their previous centroid
                cluster_counts = torch.bincount(labels, minlength=self.n_clusters)
                empty_clusters = cluster_counts == 0
                if empty_clusters.any():
                    new_centroids[empty_clusters] = centroids[empty_clusters]
                centroids = new_centroids

            # final cost for this initialization, under the chosen metric
            if metric == "hamming":
                final_dists = self._hamming(X_gpu, centroids)
            elif metric == "jaccard":
                final_dists = self._jaccard(X_gpu, centroids)
            else:
                final_dists = self._ng(X_gpu, centroids, labels, n_features)

            row_idx = torch.arange(n_samples, device=self.device)
            cost = final_dists[row_idx, labels].sum().item()

            if cost < best_cost:
                best_cost = cost
                best_centroids = centroids.clone()
                best_labels = labels.clone()

        self.centroids = best_centroids.cpu().numpy()
        self.labels = best_labels.cpu().numpy()
        self.cost_ = best_cost
        self.decoded_centroids = self._decode(self.centroids)

    def _compute_counts(
        self,
        X: torch.Tensor,
        labels: torch.Tensor,
        n_clusters: int,
        n_features: int,
        val_offset: int,
    ) -> torch.Tensor:
        """
        Frequency counts per (feature, cluster, value) via one bincount.

        Args:
            X: (torch.Tensor) Input data (n_samples, n_features)
            labels: (torch.Tensor) Cluster labels (n_samples)
            n_clusters: (int) Number of clusters
            n_features: (int) Number of features
            val_offset: (int) Value offset for flat indexing

        Returns:
            (torch.Tensor) Flattened counts (n_features * n_clusters * val_offset)
        """
        n_samples = X.shape[0]
        # labels: (N) -> (N, F) so each cell pairs with its row's cluster
        labels_expanded = labels.unsqueeze(1).expand(-1, n_features)
        # feature index for every cell: (N, F)
        feature_indices = (
            torch.arange(n_features, device=self.device)
            .unsqueeze(0)
            .expand(n_samples, -1)
        )
        # flatten (feature, cluster, value) to a single bin index
        flat_indices = (
            feature_indices * (n_clusters * val_offset)
            + labels_expanded * val_offset
            + X
        ).view(-1)
        num_bins = n_features * n_clusters * val_offset
        return torch.bincount(flat_indices, minlength=num_bins)

    def _hamming(self, X: torch.Tensor, centroids: torch.Tensor) -> torch.Tensor:
        """
        Hamming distance (count of mismatching features).

        Args:
            X: (torch.Tensor) Input data (n_samples, n_features)
            centroids: (torch.Tensor) Centroids (n_clusters, n_features)

        Returns:
            (torch.Tensor) Distance matrix (n_samples, n_clusters)
        """
        # broadcast (N, 1, F) vs (1, K, F) -> (N, K)
        return (X.unsqueeze(1) != centroids.unsqueeze(0)).sum(dim=2).float()

    def _jaccard(self, X: torch.Tensor, centroids: torch.Tensor) -> torch.Tensor:
        """
        Jaccard distance derived from the hamming mismatch count.

        Args:
            X: (torch.Tensor) Input data (n_samples, n_features)
            centroids: (torch.Tensor) Centroids (n_clusters, n_features)

        Returns:
            (torch.Tensor) Distance matrix (n_samples, n_clusters)
        """
        n_features = X.shape[1]
        hamming = self._hamming(X, centroids)
        intersection = n_features - hamming
        # union = matches + 2 * mismatches; > 0 whenever n_features > 0
        union = 2 * n_features - intersection
        return 1.0 - (intersection / union)

    def _ng(
        self,
        X: torch.Tensor,
        centroids: torch.Tensor,
        labels: torch.Tensor,
        n_features: int,
    ) -> torch.Tensor:
        """
        Ng dissimilarity (within-cluster value-frequency based).

        Args:
            X: (torch.Tensor) Input data (n_samples, n_features)
            centroids: (torch.Tensor) Centroids (n_clusters, n_features)
            labels: (torch.Tensor) Cluster labels (n_samples)
            n_features: (int) Number of features

        Returns:
            (torch.Tensor) Distance matrix (n_samples, n_clusters)
        """
        n_samples = X.shape[0]
        n_clusters = centroids.shape[0]

        # per-(feature, cluster, value) frequencies and cluster sizes
        max_val = int(X.max().item())
        if max_val < 0:
            max_val = 0
        val_offset = max_val + 1
        counts_flat = self._compute_counts(
            X, labels, n_clusters, n_features, val_offset
        )
        counts = counts_flat.view(n_features, n_clusters, val_offset)

        cluster_sizes = torch.bincount(labels, minlength=n_clusters).float()
        # avoid division by zero for empty clusters
        cluster_sizes[cluster_sizes == 0] = 1.0

        # relative frequency of each value within each cluster: (F, K, V)
        probs = counts.float() / cluster_sizes.view(1, n_clusters, 1)

        # vectorized lookup: rearrange to (F*V, K) so a single index_select
        # gathers every sample/feature probability without Python loops
        probs_flat = (
            probs.permute(0, 2, 1)
            .contiguous()
            .view(n_features * val_offset, n_clusters)
        )
        feature_offsets = (
            torch.arange(n_features, device=self.device) * val_offset
        ).unsqueeze(0)
        lookup_indices = (X + feature_offsets).view(-1).long()
        gathered_probs = probs_flat.index_select(0, lookup_indices)

        # sum matched-value frequencies per sample; distance is the deficit
        gathered_probs = gathered_probs.view(n_samples, n_features, n_clusters)
        sum_probs = gathered_probs.sum(dim=1)  # (N, K)
        return float(n_features) - sum_probs