// VectorDB.cs
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
using System.Numerics.Tensors;
namespace ToolSelection.VectorDb;
/// <summary>
/// A single item held by the vector store: an identifier, an optional caller-defined payload,
/// and the vector it is scored by.
/// </summary>
/// <remarks>
/// <see cref="Vector"/> is an array, so the compiler-generated record equality compares it by
/// reference, not element by element.
/// </remarks>
/// <param name="Id">Identifier of the entry; the store sorts and looks up entries by this value using ordinal comparison.</param>
/// <param name="Metadata">Arbitrary data attached to the entry; may be <c>null</c>.</param>
/// <param name="Vector">The vector passed to the distance metric when the entry is scored.</param>
public record Entry(string Id, object? Metadata, float[] Vector);
/// <summary>One hit produced by a query: an entry together with the score it received.</summary>
/// <param name="Score">Value returned by the distance metric for the query vector and <paramref name="Entry"/>'s vector.</param>
/// <param name="Entry">The entry that was scored.</param>
public record QueryResult(float Score, Entry Entry);
/// <summary>Options that control which entries a query considers and how many results it keeps.</summary>
/// <param name="TopK">Requested number of best-scoring results; defaults to 10.</param>
/// <param name="MinimumScore">Entries whose score is strictly below this value are dropped; defaults to 0.</param>
/// <param name="Predicate">Optional filter; when set, entries for which it returns <c>false</c> are skipped before scoring.</param>
public record QueryOptions(int TopK = 10, float MinimumScore = 0.0f, Func<Entry, bool>? Predicate = null);
/// <summary>
/// Scores a pair of vectors. Whether a larger or a smaller value means "closer" is reported by
/// <see cref="BiggerIsCloser"/>, which the store uses to order query results.
/// </summary>
public interface IDistanceMetric
{
    /// <summary>Computes the score of vector <paramref name="a"/> against vector <paramref name="b"/>.</summary>
    /// <param name="a">First vector.</param>
    /// <param name="b">Second vector.</param>
    /// <returns>The metric's value for the two vectors.</returns>
    float Distance(float[] a, float[] b);

    /// <summary>
    /// <c>true</c> when a larger <see cref="Distance"/> value means the vectors are closer
    /// (results are ordered descending); <c>false</c> when a smaller value means closer.
    /// </summary>
    bool BiggerIsCloser { get; }
}
/// <summary>
/// <see cref="IDistanceMetric"/> that scores two vectors by their cosine similarity.
/// </summary>
public class CosineSimilarity : IDistanceMetric
{
    /// <summary>Computes the cosine similarity of <paramref name="a"/> and <paramref name="b"/>.</summary>
    /// <param name="a">First vector.</param>
    /// <param name="b">Second vector; must have the same length as <paramref name="a"/>.</param>
    /// <returns>The cosine similarity as computed by <c>TensorPrimitives.CosineSimilarity</c>.</returns>
    /// <exception cref="ArgumentException">The two vectors have different lengths.</exception>
    public float Distance(float[] a, float[] b)
    {
        bool sameLength = a.Length == b.Length;
        if (!sameLength)
        {
            throw new ArgumentException("Vector lengths must match");
        }

        Span<float> first = a.AsSpan();
        Span<float> second = b.AsSpan();
        return TensorPrimitives.CosineSimilarity(first, second);
    }

    /// <summary>
    /// Always <c>true</c>. Cosine similarity: 1 = most similar, -1 = least similar.
    /// </summary>
    public bool BiggerIsCloser
    {
        get { return true; }
    }
}
/// <summary>
/// <see cref="IDistanceMetric"/> that scores two vectors by their dot product.
/// </summary>
public class DotProduct : IDistanceMetric
{
    /// <summary>Computes the dot product of <paramref name="a"/> and <paramref name="b"/>.</summary>
    /// <param name="a">First vector.</param>
    /// <param name="b">Second vector; must have the same length as <paramref name="a"/>.</param>
    /// <returns>The sum of the element-wise products, accumulated from index 0 upward; 0 for empty vectors.</returns>
    /// <exception cref="ArgumentException">The two vectors have different lengths.</exception>
    public float Distance(float[] a, float[] b)
    {
        int length = a.Length;
        if (length != b.Length)
        {
            throw new ArgumentException("Vector lengths must match");
        }

        float sum = 0.0f;
        int index = 0;
        while (index < length)
        {
            sum += a[index] * b[index];
            index++;
        }

        return sum;
    }

    /// <summary>Always <c>true</c>: a larger dot product is treated as closer.</summary>
    public bool BiggerIsCloser
    {
        get { return true; }
    }
}
/// <summary>
/// A small in-memory vector store. Entries are kept in a list sorted by <see cref="Entry.Id"/>
/// (ordinal comparison) so that lookups by id can use binary search; queries score every stored
/// entry with the supplied <see cref="IDistanceMetric"/> and return the best matches.
/// </summary>
/// <remarks>
/// Access to the entry list is guarded by a <see cref="ReaderWriterLockSlim"/>.
/// The initial <paramref name="entries"/> are sorted by id but not de-duplicated; lookups assume
/// ids are unique.
/// </remarks>
/// <param name="distanceMetric">Metric used to score entries against a query vector.</param>
/// <param name="entries">Optional initial content of the store.</param>
/// <exception cref="ArgumentNullException"><paramref name="distanceMetric"/> is <c>null</c>.</exception>
public class VectorDB(IDistanceMetric distanceMetric, IEnumerable<Entry>? entries = null) : IDisposable
{
    /// <summary>Slices with more entries than this are split in two and scored in parallel.</summary>
    private const int ParallelThreshold = 100;

    /// <summary>Orders entries by id with ordinal string comparison; shared by every id lookup.</summary>
    private static readonly Comparer<Entry> s_byId =
        Comparer<Entry>.Create((a, b) => string.Compare(a.Id, b.Id, StringComparison.Ordinal));

    // Validated first so a null metric fails before anything else is allocated.
    private readonly IDistanceMetric _distanceMetric =
        distanceMetric ?? throw new ArgumentNullException(nameof(distanceMetric));
    private readonly ReaderWriterLockSlim _lock = new();
    private readonly List<Entry> _entries = entries?.OrderBy(e => e.Id, StringComparer.Ordinal).ToList() ?? new();
    private bool _disposed;

    ~VectorDB()
    {
        Dispose(false);
    }

    /// <summary>Releases the lock owned by this instance.</summary>
    public void Dispose()
    {
        Dispose(true);
        GC.SuppressFinalize(this);
    }

    /// <summary>Disposes the internal lock when called from <see cref="Dispose()"/>; safe to call repeatedly.</summary>
    /// <param name="disposing"><c>true</c> when called from <see cref="Dispose()"/>, <c>false</c> from the finalizer.</param>
    protected virtual void Dispose(bool disposing)
    {
        if (_disposed)
        {
            return;
        }

        if (disposing)
        {
            _lock.Dispose();
        }

        _disposed = true;
    }

    /// <summary>Number of entries currently stored.</summary>
    public int Count
    {
        get
        {
            _lock.EnterReadLock();
            try
            {
                return _entries.Count;
            }
            finally
            {
                _lock.ExitReadLock();
            }
        }
    }

    /// <summary>
    /// Binary-searches the sorted entry list for <paramref name="id"/>. Returns the index when
    /// found, otherwise the bitwise complement of the index at which it would be inserted.
    /// The caller must hold the lock.
    /// </summary>
    private int BinarySearch(string id)
    {
        // Only the Id of the probe is looked at by the comparer.
        return _entries.BinarySearch(new Entry(id, null, Array.Empty<float>()), s_byId);
    }

    /// <summary>Inserts <paramref name="entry"/>, or replaces the stored entry that has the same id.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="entry"/> is <c>null</c>.</exception>
    public void Upsert(Entry entry)
    {
        ArgumentNullException.ThrowIfNull(entry);

        _lock.EnterWriteLock();
        try
        {
            int index = BinarySearch(entry.Id);
            if (index >= 0)
            {
                _entries[index] = entry;
            }
            else
            {
                // ~index is the position that keeps the list sorted.
                _entries.Insert(~index, entry);
            }
        }
        finally
        {
            _lock.ExitWriteLock();
        }
    }

    /// <summary>Returns the entry with the given id, or <c>null</c> when there is none.</summary>
    public Entry? Get(string id)
    {
        _lock.EnterReadLock();
        try
        {
            int index = BinarySearch(id);
            return index >= 0 ? _entries[index] : null;
        }
        finally
        {
            _lock.ExitReadLock();
        }
    }

    /// <summary>Removes the entry with the given id.</summary>
    /// <returns><c>true</c> when an entry was removed; <c>false</c> when the id was not found.</returns>
    public bool Delete(string id)
    {
        _lock.EnterWriteLock();
        try
        {
            int index = BinarySearch(id);
            if (index < 0)
            {
                return false;
            }

            _entries.RemoveAt(index);
            return true;
        }
        finally
        {
            _lock.ExitWriteLock();
        }
    }

    /// <summary>
    /// Scores every stored entry against <paramref name="vector"/> and returns the best matches,
    /// best first.
    /// </summary>
    /// <remarks>
    /// Entries rejected by <see cref="QueryOptions.Predicate"/> or scoring below
    /// <see cref="QueryOptions.MinimumScore"/> are dropped. At most <see cref="QueryOptions.TopK"/>
    /// results are returned; a negative <c>TopK</c> means "no limit".
    /// The entry list is copied under the read lock and scored outside of it, so writers are not
    /// blocked while the predicate and the metric run. When the store holds more than
    /// <c>100</c> entries the work is split across thread-pool threads, so the predicate and the
    /// metric may be invoked concurrently.
    /// </remarks>
    /// <param name="vector">The query vector.</param>
    /// <param name="options">Filtering and result-size options.</param>
    /// <returns>A new list of results ordered from closest to farthest.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="vector"/> or <paramref name="options"/> is <c>null</c>.</exception>
    public List<QueryResult> Query(float[] vector, QueryOptions options)
    {
        ArgumentNullException.ThrowIfNull(vector);
        ArgumentNullException.ThrowIfNull(options);

        Entry[] snapshot;
        _lock.EnterReadLock();
        try
        {
            // One copy per query; every slice below indexes into it instead of re-copying.
            snapshot = _entries.ToArray();
        }
        finally
        {
            _lock.ExitReadLock();
        }

        // A negative TopK never caps the result lists, i.e. it means "return everything".
        int limit = options.TopK < 0 ? int.MaxValue : options.TopK;

        // Single ordering used for both insertion and merging: best score compares lowest.
        bool biggerIsCloser = _distanceMetric.BiggerIsCloser;
        Comparer<QueryResult> bestFirst = Comparer<QueryResult>.Create((x, y) =>
        {
            int order = x.Score.CompareTo(y.Score);
            return biggerIsCloser ? -order : order;
        });

        return QuerySlice(snapshot, 0, snapshot.Length, vector, options, limit, bestFirst);
    }

    /// <summary>
    /// Returns the best <paramref name="limit"/> results for the <paramref name="count"/> entries
    /// starting at <paramref name="start"/>. Large slices are halved; the first half runs on the
    /// thread pool while the second half runs on the current thread, then both are merged.
    /// </summary>
    private List<QueryResult> QuerySlice(
        Entry[] snapshot, int start, int count,
        float[] vector, QueryOptions options, int limit, IComparer<QueryResult> bestFirst)
    {
        if (count <= ParallelThreshold)
        {
            return ScoreSlice(snapshot, start, count, vector, options, limit, bestFirst);
        }

        int half = count / 2;
        Task<List<QueryResult>> leftTask = Task.Run(
            () => QuerySlice(snapshot, start, half, vector, options, limit, bestFirst));
        List<QueryResult> right =
            QuerySlice(snapshot, start + half, count - half, vector, options, limit, bestFirst);

        // GetResult rethrows the original exception (not an AggregateException), so failures
        // look the same whichever half they happened in.
        List<QueryResult> left = leftTask.GetAwaiter().GetResult();

        return MergeBestFirst(left, right, limit, bestFirst);
    }

    /// <summary>
    /// Merges two best-first lists into one best-first list holding at most
    /// <paramref name="limit"/> results. Ties are resolved in favour of <paramref name="left"/>.
    /// </summary>
    private static List<QueryResult> MergeBestFirst(
        List<QueryResult> left, List<QueryResult> right, int limit, IComparer<QueryResult> bestFirst)
    {
        var merged = new List<QueryResult>(Math.Min(left.Count + right.Count, limit));
        int l = 0, r = 0;

        // Stop as soon as the limit is reached: each input already holds the best of its half,
        // so the first `limit` merged items are the best overall.
        while (merged.Count < limit && (l < left.Count || r < right.Count))
        {
            bool takeLeft = r >= right.Count
                || (l < left.Count && bestFirst.Compare(left[l], right[r]) <= 0);
            merged.Add(takeLeft ? left[l++] : right[r++]);
        }

        return merged;
    }

    /// <summary>
    /// Sequentially scores one slice of the snapshot, keeping a best-first list of at most
    /// <paramref name="limit"/> results.
    /// </summary>
    private List<QueryResult> ScoreSlice(
        Entry[] snapshot, int start, int count,
        float[] vector, QueryOptions options, int limit, IComparer<QueryResult> bestFirst)
    {
        var results = new List<QueryResult>();
        int end = start + count;

        for (int i = start; i < end; i++)
        {
            Entry entry = snapshot[i];
            if (options.Predicate is not null && !options.Predicate(entry))
            {
                continue;
            }

            float score = _distanceMetric.Distance(vector, entry.Vector);
            if (score < options.MinimumScore)
            {
                continue;
            }

            var candidate = new QueryResult(score, entry);

            // Position that keeps the list ordered best-first.
            int insertIndex = results.BinarySearch(candidate, bestFirst);
            if (insertIndex < 0)
            {
                insertIndex = ~insertIndex;
            }

            if (insertIndex == limit)
            {
                // The list is full and the candidate is no better than its worst element.
                continue;
            }

            if (results.Count == limit)
            {
                // Make room by dropping the current worst result.
                results.RemoveAt(results.Count - 1);
            }

            results.Insert(insertIndex, candidate);
        }

        return results;
    }
}