// Package gpu provides GPU-accelerated k-means clustering for NornicDB.
//
// This file implements ClusterIndex, which extends EmbeddingIndex with
// k-means clustering capabilities for faster semantic search.
//
// Architecture:
//
// ClusterIndex
// ├── EmbeddingIndex (inherited) <- GPU vector search
// ├── centroids [][]float32 <- cluster centers
// ├── assignments []int <- embedding→cluster mapping
// └── clusterMap map[int][]int <- cluster→embeddings lookup
//
// Performance (M1/M2/M3 GPU):
// - 10K embeddings, 100 clusters: ~50-100ms
// - 100K embeddings, 500 clusters: ~500ms-1s
// - Search speedup: 10-50x vs brute-force
//
// Usage:
//
// index := gpu.NewClusterIndex(manager, nil, nil)
// for _, emb := range embeddings {
// index.Add(nodeID, emb)
// }
// index.Cluster() // Run k-means
// results := index.SearchWithClusters(query, 10, 3) // Search 3 clusters
package gpu
import (
"context"
"errors"
"math"
"math/rand"
"sync"
"sync/atomic"
"time"
)
// Errors for k-means clustering
var (
ErrNotClustered = errors.New("gpu: clustering not yet performed")
ErrTooFewEmbeddings = errors.New("gpu: too few embeddings for requested clusters")
ErrInvalidK = errors.New("gpu: invalid number of clusters")
)
// KMeansConfig configures k-means clustering behavior.
//
// Example:
//
// config := &gpu.KMeansConfig{
// NumClusters: 100, // Fixed K
// MaxIterations: 50, // Converge faster
// Tolerance: 0.001, // Stricter convergence
// InitMethod: "kmeans++",
// DriftThreshold: 0.05, // Recluster on 5% drift
// }
type KMeansConfig struct {
// NumClusters is the K value. If 0 and AutoK=true, auto-detected.
NumClusters int
// MaxIterations limits convergence iterations (default: 100)
MaxIterations int
// Tolerance is the convergence threshold (default: 0.0001)
// Clustering stops when centroid drift < tolerance
Tolerance float32
// InitMethod: "kmeans++" (better) or "random" (faster)
InitMethod string
// AutoK enables automatic cluster count selection
AutoK bool
// DriftThreshold triggers re-clustering when centroids drift > this (default: 0.1)
DriftThreshold float32
// MinClusterSize is the minimum embeddings per cluster (default: 10)
MinClusterSize int
}
// DefaultKMeansConfig returns sensible defaults.
// Dimensions are auto-detected from the first embedding added.
func DefaultKMeansConfig() *KMeansConfig {
return &KMeansConfig{
NumClusters: 0, // Auto-detect based on data size
MaxIterations: 100,
Tolerance: 0.0001,
InitMethod: "kmeans++",
AutoK: true,
DriftThreshold: 0.1,
MinClusterSize: 10,
}
}
// ClusterStats holds clustering statistics.
type ClusterStats struct {
EmbeddingCount int
NumClusters int
AvgClusterSize float64
MinClusterSize int
MaxClusterSize int
Iterations int
LastClusterTime time.Duration
CentroidDrift float32
Clustered bool
}
// ClusterIndex extends EmbeddingIndex with k-means clustering.
//
// Architecture:
//
// ┌─────────────────────────────────────────────────────┐
// │ ClusterIndex │
// ├─────────────────────────────────────────────────────┤
// │ EmbeddingIndex (embedded) │
// │ ├── cpuVectors []float32 <- all embeddings │
// │ ├── nodeIDs []string <- ID mapping │
// │ └── GPU buffers <- Metal/CUDA │
// ├─────────────────────────────────────────────────────┤
// │ Clustering State │
// │ ├── centroids [][]float32 <- K cluster centers │
// │ ├── assignments []int <- embedding→cluster │
// │ └── clusterMap map[int][]int <- cluster→indices │
// └─────────────────────────────────────────────────────┘
//
// Usage:
//
// index := gpu.NewClusterIndex(manager, embConfig, kmeansConfig)
//
// // Add embeddings
// for i, emb := range embeddings {
// index.Add(nodeIDs[i], emb)
// }
//
// // Run clustering
// if err := index.Cluster(); err != nil {
// log.Fatal(err)
// }
//
// // Fast cluster-based search
// results, _ := index.SearchWithClusters(query, 10, 3)
//
// Thread Safety: All methods are thread-safe.
type ClusterIndex struct {
*EmbeddingIndex
config *KMeansConfig
// Cluster state
centroids [][]float32 // [K][dimensions] centroid vectors
assignments []int // [N] cluster assignment per embedding
clusterMap map[int][]int // cluster_id -> embedding indices
// Real-time update tracking (Tier 1/2)
pendingUpdates []nodeUpdate
updatesSinceCluster int64
// State tracking
clustered bool
lastClusterTime time.Time
lastClusterDuration time.Duration
iterations int
// Stats
clusterIterations int64
centroidDrift float32
clusterMu sync.RWMutex // Separate mutex for cluster operations
}
// nodeUpdate tracks pending node reassignments for Tier 2 updates.
type nodeUpdate struct {
idx int
oldCluster int
newCluster int
}
// NewClusterIndex creates a clusterable embedding index.
//
// Parameters:
// - manager: GPU manager (can be nil for CPU-only mode)
// - embConfig: Embedding index config (nil uses defaults)
// - kmeansConfig: K-means config (nil uses defaults)
//
// Example:
//
// // Default configuration
// index := gpu.NewClusterIndex(manager, nil, nil)
//
// // Custom configuration
// embConfig := &gpu.EmbeddingIndexConfig{
// Dimensions: 1024,
// InitialCap: 100000,
// }
// kmeansConfig := &gpu.KMeansConfig{
// NumClusters: 500,
// MaxIterations: 50,
// }
// index = gpu.NewClusterIndex(manager, embConfig, kmeansConfig)
func NewClusterIndex(manager *Manager, embConfig *EmbeddingIndexConfig, kmeansConfig *KMeansConfig) *ClusterIndex {
if kmeansConfig == nil {
kmeansConfig = DefaultKMeansConfig()
}
return &ClusterIndex{
EmbeddingIndex: NewEmbeddingIndex(manager, embConfig),
config: kmeansConfig,
clusterMap: make(map[int][]int),
pendingUpdates: make([]nodeUpdate, 0, 1000),
}
}
// Cluster performs k-means clustering on current embeddings.
//
// This method:
// 1. Determines optimal K (if AutoK enabled)
// 2. Initializes centroids (k-means++ or random)
// 3. Iterates assignment/update steps until convergence
// 4. Builds cluster membership map for fast lookup
//
// Returns error if too few embeddings or invalid configuration.
//
// Example:
//
// if err := index.Cluster(); err != nil {
// log.Printf("Clustering failed: %v", err)
// }
//
// stats := index.ClusterStats()
// fmt.Printf("Created %d clusters in %v\n",
// stats.NumClusters, stats.LastClusterTime)
func (ci *ClusterIndex) Cluster() error {
ci.mu.Lock()
ci.clusterMu.Lock()
defer ci.mu.Unlock()
defer ci.clusterMu.Unlock()
n := len(ci.nodeIDs)
if n == 0 {
return nil
}
// Determine K
k := ci.config.NumClusters
if k <= 0 || ci.config.AutoK {
k = optimalK(n)
}
if k > n {
k = n
}
if k < 1 {
return ErrInvalidK
}
if n < k {
return ErrTooFewEmbeddings
}
start := time.Now()
// Initialize centroids
var err error
if ci.config.InitMethod == "kmeans++" {
ci.centroids, err = ci.initCentroidsKMeansPlusPlus(k)
} else {
ci.centroids, err = ci.initCentroidsRandom(k)
}
if err != nil {
return err
}
// Allocate assignments
ci.assignments = make([]int, n)
// Pre-allocate centroid update buffers to avoid allocations in hot loop
dims := ci.dimensions
centroidSums := make([][]float64, k)
centroidCounts := make([]int, k)
for c := 0; c < k; c++ {
centroidSums[c] = make([]float64, dims)
}
// Iterate until convergence
ci.iterations = 0
for iter := 0; iter < ci.config.MaxIterations; iter++ {
// Assignment step
changed := ci.assignToCentroids()
// Update centroids (using pre-allocated buffers)
ci.updateCentroidsWithBuffer(centroidSums, centroidCounts)
ci.iterations++
atomic.AddInt64(&ci.clusterIterations, 1)
// Check convergence
if changed == 0 {
break
}
}
// Build cluster map for fast lookup
ci.buildClusterMap()
ci.clustered = true
ci.lastClusterTime = time.Now()
ci.lastClusterDuration = time.Since(start)
ci.updatesSinceCluster = 0
return nil
}
// optimalK calculates optimal cluster count using sqrt(n/2) heuristic.
func optimalK(n int) int {
k := int(math.Sqrt(float64(n) / 2))
if k < 10 {
k = 10 // Minimum clusters
}
if k > 1000 {
k = 1000 // Maximum clusters
}
return k
}
// initCentroidsRandom initializes centroids by random selection.
func (ci *ClusterIndex) initCentroidsRandom(k int) ([][]float32, error) {
n := len(ci.nodeIDs)
dims := ci.dimensions
centroids := make([][]float32, k)
selected := make(map[int]bool)
for i := 0; i < k; i++ {
// Pick random unselected embedding
var idx int
for {
idx = rand.Intn(n)
if !selected[idx] {
selected[idx] = true
break
}
}
// Copy embedding as centroid
centroids[i] = make([]float32, dims)
start := idx * dims
copy(centroids[i], ci.cpuVectors[start:start+dims])
}
return centroids, nil
}
// initCentroidsKMeansPlusPlus initializes centroids using k-means++ algorithm.
// This produces better initial centroids than random selection.
func (ci *ClusterIndex) initCentroidsKMeansPlusPlus(k int) ([][]float32, error) {
n := len(ci.nodeIDs)
dims := ci.dimensions
centroids := make([][]float32, k)
// Step 1: Choose first centroid randomly
firstIdx := rand.Intn(n)
centroids[0] = make([]float32, dims)
start := firstIdx * dims
copy(centroids[0], ci.cpuVectors[start:start+dims])
// Distance to nearest chosen centroid (cached)
minDistances := make([]float64, n)
// Initialize distances to first centroid
for i := 0; i < n; i++ {
embStart := i * dims
emb := ci.cpuVectors[embStart : embStart+dims]
minDistances[i] = squaredEuclidean(emb, centroids[0])
}
// Step 2: Choose remaining centroids proportional to D(x)^2
for c := 1; c < k; c++ {
// Compute total weight and sample next centroid
totalWeight := 0.0
for i := 0; i < n; i++ {
totalWeight += minDistances[i]
}
// Sample next centroid weighted by distance^2
target := rand.Float64() * totalWeight
cumWeight := 0.0
selectedIdx := n - 1
for i := 0; i < n; i++ {
cumWeight += minDistances[i]
if cumWeight >= target {
selectedIdx = i
break
}
}
// Copy selected embedding as centroid
centroids[c] = make([]float32, dims)
start := selectedIdx * dims
copy(centroids[c], ci.cpuVectors[start:start+dims])
// Update minDistances: only compare with new centroid
// (if new centroid is closer, update cached distance)
newCentroid := centroids[c]
for i := 0; i < n; i++ {
embStart := i * dims
emb := ci.cpuVectors[embStart : embStart+dims]
distToNew := squaredEuclidean(emb, newCentroid)
if distToNew < minDistances[i] {
minDistances[i] = distToNew
}
}
}
return centroids, nil
}
// squaredEuclidean computes squared Euclidean distance.
// Uses 4-way loop unrolling for better instruction-level parallelism.
func squaredEuclidean(a []float32, b []float32) float64 {
n := len(a)
var sum0, sum1, sum2, sum3 float64
// Process 4 elements at a time
i := 0
for ; i <= n-4; i += 4 {
d0 := float64(a[i] - b[i])
d1 := float64(a[i+1] - b[i+1])
d2 := float64(a[i+2] - b[i+2])
d3 := float64(a[i+3] - b[i+3])
sum0 += d0 * d0
sum1 += d1 * d1
sum2 += d2 * d2
sum3 += d3 * d3
}
// Handle remaining elements
for ; i < n; i++ {
diff := float64(a[i] - b[i])
sum0 += diff * diff
}
return sum0 + sum1 + sum2 + sum3
}
// assignToCentroids assigns each embedding to its nearest centroid.
// Returns number of assignments that changed.
func (ci *ClusterIndex) assignToCentroids() int {
n := len(ci.nodeIDs)
dims := ci.dimensions
k := len(ci.centroids)
changed := 0
for i := 0; i < n; i++ {
embStart := i * dims
emb := ci.cpuVectors[embStart : embStart+dims]
// Find nearest centroid
minDist := math.MaxFloat64
nearest := 0
for c := 0; c < k; c++ {
dist := squaredEuclidean(emb, ci.centroids[c])
if dist < minDist {
minDist = dist
nearest = c
}
}
if ci.assignments[i] != nearest {
ci.assignments[i] = nearest
changed++
}
}
return changed
}
// updateCentroids recomputes centroids as mean of assigned embeddings.
func (ci *ClusterIndex) updateCentroids() {
dims := ci.dimensions
k := len(ci.centroids)
n := len(ci.nodeIDs)
// Accumulate sums and counts
sums := make([][]float64, k)
counts := make([]int, k)
for c := 0; c < k; c++ {
sums[c] = make([]float64, dims)
}
for i := 0; i < n; i++ {
cluster := ci.assignments[i]
embStart := i * dims
counts[cluster]++
for d := 0; d < dims; d++ {
sums[cluster][d] += float64(ci.cpuVectors[embStart+d])
}
}
// Compute new centroids (average)
for c := 0; c < k; c++ {
if counts[c] > 0 {
for d := 0; d < dims; d++ {
ci.centroids[c][d] = float32(sums[c][d] / float64(counts[c]))
}
}
// Empty clusters keep their previous position
}
}
// updateCentroidsWithBuffer recomputes centroids using pre-allocated buffers.
// This avoids allocations in the hot clustering loop.
func (ci *ClusterIndex) updateCentroidsWithBuffer(sums [][]float64, counts []int) {
dims := ci.dimensions
k := len(ci.centroids)
n := len(ci.nodeIDs)
// Zero out the buffers
for c := 0; c < k; c++ {
counts[c] = 0
for d := 0; d < dims; d++ {
sums[c][d] = 0
}
}
// Accumulate sums and counts
for i := 0; i < n; i++ {
cluster := ci.assignments[i]
embStart := i * dims
counts[cluster]++
for d := 0; d < dims; d++ {
sums[cluster][d] += float64(ci.cpuVectors[embStart+d])
}
}
// Compute new centroids (average)
for c := 0; c < k; c++ {
if counts[c] > 0 {
for d := 0; d < dims; d++ {
ci.centroids[c][d] = float32(sums[c][d] / float64(counts[c]))
}
}
// Empty clusters keep their previous position
}
}
// buildClusterMap creates the cluster → embedding indices mapping.
func (ci *ClusterIndex) buildClusterMap() {
ci.clusterMap = make(map[int][]int)
for i, cluster := range ci.assignments {
ci.clusterMap[cluster] = append(ci.clusterMap[cluster], i)
}
}
// IsClustered returns true if clustering has been performed.
func (ci *ClusterIndex) IsClustered() bool {
ci.clusterMu.RLock()
defer ci.clusterMu.RUnlock()
return ci.clustered
}
// NumClusters returns the number of clusters.
func (ci *ClusterIndex) NumClusters() int {
ci.clusterMu.RLock()
defer ci.clusterMu.RUnlock()
return len(ci.centroids)
}
// ClusterStats returns clustering statistics.
func (ci *ClusterIndex) ClusterStats() ClusterStats {
ci.clusterMu.RLock()
defer ci.clusterMu.RUnlock()
stats := ClusterStats{
EmbeddingCount: len(ci.nodeIDs),
NumClusters: len(ci.centroids),
Iterations: ci.iterations,
LastClusterTime: ci.lastClusterDuration,
CentroidDrift: ci.centroidDrift,
Clustered: ci.clustered,
}
if len(ci.clusterMap) > 0 {
totalSize := 0
stats.MinClusterSize = math.MaxInt32
stats.MaxClusterSize = 0
for _, members := range ci.clusterMap {
size := len(members)
totalSize += size
if size < stats.MinClusterSize {
stats.MinClusterSize = size
}
if size > stats.MaxClusterSize {
stats.MaxClusterSize = size
}
}
stats.AvgClusterSize = float64(totalSize) / float64(len(ci.clusterMap))
}
return stats
}
// FindNearestCentroid finds the cluster ID nearest to the given embedding.
func (ci *ClusterIndex) FindNearestCentroid(embedding []float32) int {
ci.clusterMu.RLock()
defer ci.clusterMu.RUnlock()
return ci.findNearestCentroidLocked(embedding)
}
// findNearestCentroidLocked finds nearest centroid without acquiring lock.
// Caller must hold clusterMu.
func (ci *ClusterIndex) findNearestCentroidLocked(embedding []float32) int {
if !ci.clustered || len(ci.centroids) == 0 {
return -1
}
minDist := math.MaxFloat64
nearest := 0
for c, centroid := range ci.centroids {
dist := squaredEuclidean(embedding, centroid)
if dist < minDist {
minDist = dist
nearest = c
}
}
return nearest
}
// FindNearestClusters finds the k nearest cluster IDs to the given embedding.
func (ci *ClusterIndex) FindNearestClusters(embedding []float32, k int) []int {
ci.clusterMu.RLock()
defer ci.clusterMu.RUnlock()
if !ci.clustered || len(ci.centroids) == 0 {
return nil
}
if k > len(ci.centroids) {
k = len(ci.centroids)
}
// Compute distances to all centroids
type distIdx struct {
dist float64
idx int
}
distances := make([]distIdx, len(ci.centroids))
for c, centroid := range ci.centroids {
distances[c] = distIdx{
dist: squaredEuclidean(embedding, centroid),
idx: c,
}
}
// Partial sort to find k nearest
for i := 0; i < k; i++ {
minIdx := i
for j := i + 1; j < len(distances); j++ {
if distances[j].dist < distances[minIdx].dist {
minIdx = j
}
}
distances[i], distances[minIdx] = distances[minIdx], distances[i]
}
result := make([]int, k)
for i := 0; i < k; i++ {
result[i] = distances[i].idx
}
return result
}
// GetClusterMembers returns the embedding indices belonging to the given clusters.
func (ci *ClusterIndex) GetClusterMembers(clusterIDs []int) []int {
ci.clusterMu.RLock()
defer ci.clusterMu.RUnlock()
if !ci.clustered {
return nil
}
var members []int
for _, cid := range clusterIDs {
if m, ok := ci.clusterMap[cid]; ok {
members = append(members, m...)
}
}
return members
}
// SearchWithClusters performs cluster-accelerated similarity search.
//
// This method:
// 1. Finds the k nearest clusters to the query
// 2. Gets all embeddings from those clusters as candidates
// 3. Performs exact similarity search on candidates only
//
// Parameters:
// - query: Query embedding vector
// - topK: Number of results to return
// - numClusters: Number of clusters to search (expansion factor)
//
// Returns: SearchResult slice sorted by similarity (descending)
//
// Example:
//
// // Search 3 nearest clusters for top 10 results
// results, err := index.SearchWithClusters(query, 10, 3)
func (ci *ClusterIndex) SearchWithClusters(query []float32, topK, numClusters int) ([]SearchResult, error) {
if !ci.IsClustered() {
// Fall back to brute-force search
return ci.Search(query, topK)
}
// Find nearest clusters
clusterIDs := ci.FindNearestClusters(query, numClusters)
if len(clusterIDs) == 0 {
return nil, nil
}
// Get candidate embedding indices
candidates := ci.GetClusterMembers(clusterIDs)
if len(candidates) == 0 {
return nil, nil
}
// Search among candidates only
return ci.SearchCandidates(context.Background(), query, candidates, topK)
}
// SearchCandidates performs similarity search on a subset of embeddings.
func (ci *ClusterIndex) SearchCandidates(ctx context.Context, query []float32, candidateIndices []int, topK int) ([]SearchResult, error) {
if len(query) != ci.dimensions {
return nil, ErrInvalidDimensions
}
ci.mu.RLock()
defer ci.mu.RUnlock()
if len(candidateIndices) == 0 {
return nil, nil
}
if topK > len(candidateIndices) {
topK = len(candidateIndices)
}
// Compute similarities for candidates only
type scoreIdx struct {
score float32
idx int
}
scores := make([]scoreIdx, len(candidateIndices))
for i, embIdx := range candidateIndices {
start := embIdx * ci.dimensions
end := start + ci.dimensions
emb := ci.cpuVectors[start:end]
scores[i] = scoreIdx{
score: cosineSimilarityFlat(query, emb),
idx: embIdx,
}
}
// Partial sort for top-k
for i := 0; i < topK; i++ {
maxIdx := i
for j := i + 1; j < len(scores); j++ {
if scores[j].score > scores[maxIdx].score {
maxIdx = j
}
}
scores[i], scores[maxIdx] = scores[maxIdx], scores[i]
}
// Build results
results := make([]SearchResult, topK)
for i := 0; i < topK; i++ {
embIdx := scores[i].idx
results[i] = SearchResult{
ID: ci.nodeIDs[embIdx],
Score: scores[i].score,
Distance: 1 - scores[i].score,
}
}
return results, nil
}
// OnNodeUpdate handles real-time embedding changes (Tier 1).
//
// This method:
// 1. Adds/updates the embedding in the index
// 2. If clustered, reassigns to nearest centroid
// 3. Tracks update for potential batch centroid recalculation
//
// Example:
//
// // Called when a node's embedding changes
// if err := index.OnNodeUpdate("node-123", newEmbedding); err != nil {
// log.Printf("Update failed: %v", err)
// }
func (ci *ClusterIndex) OnNodeUpdate(nodeID string, embedding []float32) error {
// Add/update in base index
if err := ci.Add(nodeID, embedding); err != nil {
return err
}
if !ci.IsClustered() {
return nil
}
ci.clusterMu.Lock()
defer ci.clusterMu.Unlock()
ci.mu.RLock()
idx, exists := ci.idToIndex[nodeID]
ci.mu.RUnlock()
if !exists {
return nil
}
// Find nearest centroid (use locked version since we hold clusterMu)
newCluster := ci.findNearestCentroidLocked(embedding)
// Track if assignment changed
if idx < len(ci.assignments) {
oldCluster := ci.assignments[idx]
if newCluster != oldCluster {
// Update cluster membership
ci.removeFromClusterMap(oldCluster, idx)
ci.addToClusterMap(newCluster, idx)
ci.assignments[idx] = newCluster
// Track for batch centroid update
ci.pendingUpdates = append(ci.pendingUpdates, nodeUpdate{
idx: idx,
oldCluster: oldCluster,
newCluster: newCluster,
})
}
} else {
// New embedding, assign to cluster
ci.assignments = append(ci.assignments, newCluster)
ci.addToClusterMap(newCluster, idx)
}
atomic.AddInt64(&ci.updatesSinceCluster, 1)
return nil
}
// removeFromClusterMap removes an embedding index from a cluster's member list.
func (ci *ClusterIndex) removeFromClusterMap(cluster, embIdx int) {
members := ci.clusterMap[cluster]
for i, idx := range members {
if idx == embIdx {
// Remove by swapping with last element
members[i] = members[len(members)-1]
ci.clusterMap[cluster] = members[:len(members)-1]
return
}
}
}
// addToClusterMap adds an embedding index to a cluster's member list.
func (ci *ClusterIndex) addToClusterMap(cluster, embIdx int) {
ci.clusterMap[cluster] = append(ci.clusterMap[cluster], embIdx)
}
// ShouldRecluster checks if re-clustering is needed based on thresholds.
func (ci *ClusterIndex) ShouldRecluster() bool {
ci.clusterMu.RLock()
defer ci.clusterMu.RUnlock()
if !ci.clustered {
return false
}
// Trigger if too many updates (>10% of dataset)
updateRatio := float64(atomic.LoadInt64(&ci.updatesSinceCluster)) / float64(len(ci.nodeIDs))
if updateRatio > 0.1 {
return true
}
// Trigger if centroid drift is high
if ci.centroidDrift > ci.config.DriftThreshold {
return true
}
// Trigger if too much time has passed (1 hour)
if time.Since(ci.lastClusterTime) > time.Hour {
return true
}
return false
}
// UpdateCentroidsBatch recomputes centroids for affected clusters (Tier 2).
// Call periodically to keep centroids accurate after node updates.
func (ci *ClusterIndex) UpdateCentroidsBatch() {
ci.clusterMu.Lock()
updates := ci.pendingUpdates
ci.pendingUpdates = make([]nodeUpdate, 0, 1000)
ci.clusterMu.Unlock()
if len(updates) == 0 {
return
}
// Collect affected clusters
affectedClusters := make(map[int]bool)
for _, u := range updates {
affectedClusters[u.oldCluster] = true
affectedClusters[u.newCluster] = true
}
ci.mu.RLock()
ci.clusterMu.Lock()
defer ci.mu.RUnlock()
defer ci.clusterMu.Unlock()
// Recompute centroids for affected clusters
dims := ci.dimensions
for cluster := range affectedClusters {
members, ok := ci.clusterMap[cluster]
if !ok || len(members) == 0 {
continue
}
// Compute new centroid as mean
newCentroid := make([]float64, dims)
for _, idx := range members {
start := idx * dims
for d := 0; d < dims; d++ {
newCentroid[d] += float64(ci.cpuVectors[start+d])
}
}
for d := 0; d < dims; d++ {
ci.centroids[cluster][d] = float32(newCentroid[d] / float64(len(members)))
}
}
}
// Dimensions returns the embedding dimensions.
func (ci *ClusterIndex) Dimensions() int {
return ci.dimensions
}
// GetConfig returns the k-means configuration.
func (ci *ClusterIndex) GetConfig() *KMeansConfig {
return ci.config
}