247 lines
7.0 KiB
C#
Executable File
247 lines
7.0 KiB
C#
Executable File
using System.Net.Http.Headers;
using System.Security.Cryptography;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using MarketAlly.AIPlugin.Context.Configuration;
using Microsoft.Extensions.Logging;
|
|
|
|
namespace MarketAlly.AIPlugin.Context.Search
|
|
{
|
|
/// <summary>
/// Provides semantic search capabilities using OpenAI embeddings.
/// </summary>
/// <remarks>
/// Embeddings are requested from the OpenAI embeddings endpoint, kept in a bounded
/// in-memory cache, and compared with cosine similarity. At most
/// <see cref="MaxConcurrentRequests"/> similarity calculations run at the same time.
/// </remarks>
public class SemanticSearchEnhancer
{
    private const string EmbeddingsEndpoint = "https://api.openai.com/v1/embeddings";
    private const string UserAgent = "MarketAlly-Context-Plugin/1.0";
    private const int MaxInputLength = 8000;
    private const int MaxCacheEntries = 1000;
    private const int MaxConcurrentRequests = 5;

    private static readonly TimeSpan RequestTimeout = TimeSpan.FromSeconds(30);

    // The API answers with lowercase property names ("data", "embedding") while the DTOs
    // use PascalCase; System.Text.Json matches names case-sensitively unless told otherwise.
    private static readonly JsonSerializerOptions ResponseJsonOptions = new()
    {
        PropertyNameCaseInsensitive = true
    };

    private readonly ContextConfiguration _configuration;
    private readonly ILogger<SemanticSearchEnhancer> _logger;
    private readonly HttpClient _httpClient;
    private readonly SemaphoreSlim _rateLimitSemaphore;

    // Cache state. Every access goes through _cacheLock because up to
    // MaxConcurrentRequests callers can reach the cache at the same time.
    private readonly object _cacheLock = new();
    private readonly Dictionary<string, float[]> _embeddingCache;
    private readonly Queue<string> _cacheInsertionOrder;
    private long _cacheHits;
    private long _cacheMisses;

    /// <summary>
    /// Creates a new enhancer.
    /// </summary>
    /// <param name="configuration">Configuration providing the OpenAI API key and embedding model.</param>
    /// <param name="logger">Logger used for diagnostics.</param>
    /// <param name="httpClient">
    /// HTTP client used for API calls. It is not modified: authorization, user agent and
    /// timeout are applied per request, so a shared or already-used client is safe to pass.
    /// </param>
    /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
    public SemanticSearchEnhancer(ContextConfiguration configuration, ILogger<SemanticSearchEnhancer> logger, HttpClient httpClient)
    {
        _configuration = configuration ?? throw new ArgumentNullException(nameof(configuration));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
        _httpClient = httpClient ?? throw new ArgumentNullException(nameof(httpClient));
        _embeddingCache = new Dictionary<string, float[]>();
        _cacheInsertionOrder = new Queue<string>();
        _rateLimitSemaphore = new SemaphoreSlim(MaxConcurrentRequests, MaxConcurrentRequests); // Limit concurrent API calls
    }

    /// <summary>
    /// Calculates semantic similarity between query and content using embeddings.
    /// </summary>
    /// <param name="query">The search query.</param>
    /// <param name="content">The content to compare against the query.</param>
    /// <param name="cancellationToken">Token used to cancel waiting and the API calls.</param>
    /// <returns>
    /// A score in the range 0..1. Returns 0.0 when no API key is configured, when either
    /// input is null or empty, when an embedding could not be obtained, or when the
    /// calculation fails (failures are logged rather than rethrown).
    /// </returns>
    public async Task<double> CalculateSemanticSimilarityAsync(string query, string content, CancellationToken cancellationToken = default)
    {
        if (string.IsNullOrEmpty(_configuration.Search.OpenAIApiKey))
        {
            _logger.LogWarning("OpenAI API key not configured, semantic search disabled");
            return 0.0;
        }

        // Nothing to embed; skip the API round trip.
        if (string.IsNullOrEmpty(query) || string.IsNullOrEmpty(content))
        {
            return 0.0;
        }

        // Only release the semaphore if it was actually acquired; a cancelled wait
        // must not hand back a slot that was never taken.
        var slotAcquired = false;

        try
        {
            await _rateLimitSemaphore.WaitAsync(cancellationToken);
            slotAcquired = true;

            var queryEmbedding = await GetEmbeddingAsync(query, cancellationToken);
            if (queryEmbedding is null)
            {
                return 0.0;
            }

            var contentEmbedding = await GetEmbeddingAsync(content, cancellationToken);
            if (contentEmbedding is null)
            {
                return 0.0;
            }

            var similarity = CalculateCosineSimilarity(queryEmbedding, contentEmbedding);
            _logger.LogDebug("Calculated semantic similarity: {Similarity} for query length {QueryLength} and content length {ContentLength}",
                similarity, query.Length, content.Length);
            return similarity;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to calculate semantic similarity");
            return 0.0;
        }
        finally
        {
            if (slotAcquired)
            {
                _rateLimitSemaphore.Release();
            }
        }
    }

    /// <summary>
    /// Gets embedding for text, with caching to reduce API calls.
    /// </summary>
    /// <param name="text">Text to embed; truncated to <see cref="MaxInputLength"/> characters.</param>
    /// <param name="cancellationToken">Token used to cancel the API call.</param>
    /// <returns>
    /// The embedding vector, or <c>null</c> when the API reports an error, the response
    /// contains no usable embedding, or a network/JSON error occurs.
    /// </returns>
    private async Task<float[]?> GetEmbeddingAsync(string text, CancellationToken cancellationToken)
    {
        // Truncate very long text to avoid API limits
        text = TruncateInput(text);

        // Check cache first
        var cacheKey = GenerateCacheKey(text);
        lock (_cacheLock)
        {
            if (_embeddingCache.TryGetValue(cacheKey, out var cachedEmbedding))
            {
                _cacheHits++;
                return cachedEmbedding;
            }

            _cacheMisses++;
        }

        try
        {
            var requestBody = new
            {
                input = text,
                model = _configuration.Search.OpenAIEmbeddingModel
            };

            // Headers are set on the request instead of the client's defaults so that the
            // injected HttpClient is never mutated. Disposing the request disposes its content.
            using var request = new HttpRequestMessage(HttpMethod.Post, EmbeddingsEndpoint)
            {
                Content = new StringContent(JsonSerializer.Serialize(requestBody), Encoding.UTF8, "application/json")
            };
            request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _configuration.Search.OpenAIApiKey);
            request.Headers.UserAgent.ParseAdd(UserAgent);

            // Per-request timeout, linked to the caller's token.
            using var timeoutSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
            timeoutSource.CancelAfter(RequestTimeout);

            using var response = await _httpClient.SendAsync(request, timeoutSource.Token);

            if (!response.IsSuccessStatusCode)
            {
                var errorContent = await response.Content.ReadAsStringAsync(timeoutSource.Token);
                _logger.LogError("OpenAI API error: {StatusCode} - {Error}", response.StatusCode, errorContent);
                return null;
            }

            var responseJson = await response.Content.ReadAsStringAsync(timeoutSource.Token);
            var embeddingResponse = JsonSerializer.Deserialize<OpenAIEmbeddingResponse>(responseJson, ResponseJsonOptions);

            var embedding = embeddingResponse?.Data is { Length: > 0 } data ? data[0]?.Embedding : null;
            if (embedding is null || embedding.Length == 0)
            {
                // An empty vector carries no information; do not cache it.
                return null;
            }

            CacheEmbedding(cacheKey, embedding);
            return embedding;
        }
        catch (HttpRequestException ex)
        {
            _logger.LogError(ex, "Network error while getting embedding");
            return null;
        }
        catch (JsonException ex)
        {
            _logger.LogError(ex, "JSON parsing error while processing embedding response");
            return null;
        }
    }

    /// <summary>
    /// Stores an embedding in the cache, evicting the oldest entries once the cache is full.
    /// </summary>
    private void CacheEmbedding(string cacheKey, float[] embedding)
    {
        lock (_cacheLock)
        {
            if (_embeddingCache.ContainsKey(cacheKey))
            {
                // Another caller cached the same text first; keep its position in the queue.
                _embeddingCache[cacheKey] = embedding;
                return;
            }

            while (_embeddingCache.Count >= MaxCacheEntries && _cacheInsertionOrder.Count > 0)
            {
                _embeddingCache.Remove(_cacheInsertionOrder.Dequeue());
            }

            _embeddingCache[cacheKey] = embedding;
            _cacheInsertionOrder.Enqueue(cacheKey);
        }
    }

    /// <summary>
    /// Limits text to <see cref="MaxInputLength"/> characters without cutting a surrogate pair in half.
    /// </summary>
    private static string TruncateInput(string text)
    {
        if (text.Length <= MaxInputLength)
        {
            return text;
        }

        var length = MaxInputLength;
        if (char.IsHighSurrogate(text[length - 1]))
        {
            length--;
        }

        return text.Substring(0, length);
    }

    /// <summary>
    /// Calculates cosine similarity between two embedding vectors.
    /// </summary>
    /// <returns>
    /// The cosine similarity mapped from -1..1 to 0..1, or 0.0 when the vectors differ in
    /// length or either vector has zero magnitude.
    /// </returns>
    private double CalculateCosineSimilarity(float[] vectorA, float[] vectorB)
    {
        if (vectorA.Length != vectorB.Length)
        {
            _logger.LogWarning("Vector length mismatch: {LengthA} vs {LengthB}", vectorA.Length, vectorB.Length);
            return 0.0;
        }

        double dotProduct = 0.0;
        double sumOfSquaresA = 0.0;
        double sumOfSquaresB = 0.0;

        for (int i = 0; i < vectorA.Length; i++)
        {
            // Widen before multiplying so the products are computed in double precision.
            double a = vectorA[i];
            double b = vectorB[i];

            dotProduct += a * b;
            sumOfSquaresA += a * a;
            sumOfSquaresB += b * b;
        }

        var magnitudeA = Math.Sqrt(sumOfSquaresA);
        var magnitudeB = Math.Sqrt(sumOfSquaresB);

        if (magnitudeA == 0.0 || magnitudeB == 0.0)
        {
            return 0.0;
        }

        var similarity = dotProduct / (magnitudeA * magnitudeB);

        // Normalize to 0-1 range (cosine similarity can be -1 to 1); clamp away rounding overshoot.
        return Math.Clamp((similarity + 1.0) / 2.0, 0.0, 1.0);
    }

    /// <summary>
    /// Generates a cache key for text content.
    /// </summary>
    /// <remarks>
    /// Uses a SHA-256 digest of the UTF-8 text rather than <see cref="string.GetHashCode()"/>,
    /// whose 32-bit result can collide and make the cache return another text's embedding.
    /// </remarks>
    private static string GenerateCacheKey(string text)
    {
        return Convert.ToHexString(SHA256.HashData(Encoding.UTF8.GetBytes(text)));
    }

    /// <summary>
    /// Gets embedding cache statistics.
    /// </summary>
    /// <returns>
    /// The current entry count, the hit ratio over all cache lookups made by this instance
    /// (0.0 before the first lookup), and whether an API key is configured.
    /// </returns>
    public EmbeddingCacheStats GetCacheStats()
    {
        lock (_cacheLock)
        {
            var lookups = _cacheHits + _cacheMisses;

            return new EmbeddingCacheStats
            {
                CachedEmbeddings = _embeddingCache.Count,
                CacheHitRatio = lookups == 0 ? 0.0 : (double)_cacheHits / lookups,
                IsEnabled = !string.IsNullOrEmpty(_configuration.Search.OpenAIApiKey)
            };
        }
    }

    /// <summary>
    /// Clears the embedding cache. Hit/miss counters are kept.
    /// </summary>
    public void ClearCache()
    {
        lock (_cacheLock)
        {
            _embeddingCache.Clear();
            _cacheInsertionOrder.Clear();
        }

        _logger.LogInformation("Embedding cache cleared");
    }
}
|
|
|
|
/// <summary>
/// Response structure for OpenAI embeddings API.
/// </summary>
/// <remarks>
/// The property is mapped explicitly to the lowercase JSON name because System.Text.Json
/// matches property names case-sensitively by default; without the mapping the payload's
/// <c>data</c> array would never bind to <see cref="Data"/>.
/// </remarks>
public class OpenAIEmbeddingResponse
{
    /// <summary>Embedding results carried in the response's <c>data</c> array.</summary>
    [JsonPropertyName("data")]
    public OpenAIEmbeddingData[] Data { get; set; } = Array.Empty<OpenAIEmbeddingData>();
}

/// <summary>
/// A single entry of the embeddings response's <c>data</c> array.
/// </summary>
public class OpenAIEmbeddingData
{
    /// <summary>The embedding vector carried in the entry's <c>embedding</c> field.</summary>
    [JsonPropertyName("embedding")]
    public float[] Embedding { get; set; } = Array.Empty<float>();
}
|
|
|
|
/// <summary>
/// Snapshot of the embedding cache's state and effectiveness.
/// </summary>
public class EmbeddingCacheStats
{
    /// <summary>Number of embeddings held in the cache.</summary>
    public int CachedEmbeddings { get; set; }

    /// <summary>Ratio of cache hits to cache lookups.</summary>
    public double CacheHitRatio { get; set; }

    /// <summary>Whether semantic search is enabled.</summary>
    public bool IsEnabled { get; set; }
}
|
|
} |