using System.Net.Http.Headers;
using System.Text;
using System.Text.Json;
using Microsoft.Extensions.Logging;
using MarketAlly.AIPlugin.Context.Configuration;
namespace MarketAlly.AIPlugin.Context.Search
{
/// <summary>
/// Provides semantic search capabilities using OpenAI embeddings.
/// </summary>
/// <remarks>
/// Embeddings are requested from the OpenAI embeddings endpoint and kept in a bounded,
/// first-in-first-out in-memory cache. Up to five similarity calculations may run at the
/// same time, so every cache access is serialized through a private lock.
/// </remarks>
public class SemanticSearchEnhancer
{
    private const string EmbeddingsEndpoint = "https://api.openai.com/v1/embeddings";
    private const string UserAgentProduct = "MarketAlly-Context-Plugin";
    private const string UserAgentVersion = "1.0";
    private const int MaxConcurrentApiCalls = 5;
    private const int MaxInputLength = 8000;
    private const int MaxCacheEntries = 1000;

    private static readonly TimeSpan RequestTimeout = TimeSpan.FromSeconds(30);

    // The embeddings API answers with lower-case property names ("data", "embedding"),
    // while System.Text.Json matches property names case-sensitively unless told otherwise.
    private static readonly JsonSerializerOptions ResponseJsonOptions = new()
    {
        PropertyNameCaseInsensitive = true
    };

    private readonly ContextConfiguration _configuration;
    private readonly ILogger _logger;
    private readonly HttpClient _httpClient;
    private readonly Dictionary<string, float[]> _embeddingCache;
    private readonly Queue<string> _cacheInsertionOrder;
    private readonly object _cacheLock = new();
    private readonly SemaphoreSlim _rateLimitSemaphore;

    // Lookup counters, only touched while holding _cacheLock.
    private long _cacheHits;
    private long _cacheMisses;

    /// <summary>
    /// Creates the enhancer and applies the OpenAI-specific settings to <paramref name="httpClient"/>.
    /// </summary>
    /// <param name="configuration">Plugin configuration; its <c>Search</c> section supplies the API key and embedding model.</param>
    /// <param name="logger">Logger used for diagnostics.</param>
    /// <param name="httpClient">HTTP client used for the embeddings requests.</param>
    /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
    public SemanticSearchEnhancer(ContextConfiguration configuration, ILogger logger, HttpClient httpClient)
    {
        _configuration = configuration ?? throw new ArgumentNullException(nameof(configuration));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
        _httpClient = httpClient ?? throw new ArgumentNullException(nameof(httpClient));
        _embeddingCache = new Dictionary<string, float[]>(StringComparer.Ordinal);
        _cacheInsertionOrder = new Queue<string>();
        _rateLimitSemaphore = new SemaphoreSlim(MaxConcurrentApiCalls, MaxConcurrentApiCalls); // Limit concurrent API calls
        ConfigureHttpClient();
    }

    /// <summary>
    /// Calculates semantic similarity between query and content using embeddings.
    /// </summary>
    /// <param name="query">Search query text.</param>
    /// <param name="content">Content to compare against the query.</param>
    /// <param name="cancellationToken">Token used to abandon the wait for a free slot and the HTTP calls.</param>
    /// <returns>
    /// A score in the range 0..1, or 0.0 when no API key is configured, either input is
    /// null or empty, the operation is cancelled, or an embedding could not be obtained.
    /// This method does not throw for those cases.
    /// </returns>
    public async Task<double> CalculateSemanticSimilarityAsync(string query, string content, CancellationToken cancellationToken = default)
    {
        if (string.IsNullOrEmpty(_configuration.Search.OpenAIApiKey))
        {
            _logger.LogWarning("OpenAI API key not configured, semantic search disabled");
            return 0.0;
        }

        // Nothing to embed: skip the round trip (and the error the API would report for empty input).
        if (string.IsNullOrEmpty(query) || string.IsNullOrEmpty(content))
        {
            return 0.0;
        }

        // Release must only happen after a successful wait; a cancelled wait never took a slot.
        var slotAcquired = false;
        try
        {
            await _rateLimitSemaphore.WaitAsync(cancellationToken);
            slotAcquired = true;

            var queryEmbedding = await GetEmbeddingAsync(query, cancellationToken);
            if (queryEmbedding is null)
            {
                return 0.0;
            }

            var contentEmbedding = await GetEmbeddingAsync(content, cancellationToken);
            if (contentEmbedding is null)
            {
                return 0.0;
            }

            var similarity = CalculateCosineSimilarity(queryEmbedding, contentEmbedding);
            _logger.LogDebug("Calculated semantic similarity: {Similarity} for query length {QueryLength} and content length {ContentLength}",
                similarity, query.Length, content.Length);
            return similarity;
        }
        catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested)
        {
            // Caller-requested cancellation is not a failure; keep the "returns 0.0" contract.
            _logger.LogDebug("Semantic similarity calculation was cancelled");
            return 0.0;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to calculate semantic similarity");
            return 0.0;
        }
        finally
        {
            if (slotAcquired)
            {
                _rateLimitSemaphore.Release();
            }
        }
    }

    /// <summary>
    /// Gets embedding for text, with caching to reduce API calls.
    /// </summary>
    /// <param name="text">Text to embed; input longer than 8000 characters is truncated.</param>
    /// <param name="cancellationToken">Token forwarded to the HTTP request.</param>
    /// <returns>
    /// The embedding vector, or <c>null</c> when the API reports an error, the response
    /// carries no embedding, or a network/JSON error occurs.
    /// </returns>
    private async Task<float[]?> GetEmbeddingAsync(string text, CancellationToken cancellationToken)
    {
        // Truncate very long text to avoid API limits
        if (text.Length > MaxInputLength)
        {
            // Do not cut a surrogate pair in half; a lone surrogate is not valid text.
            var cut = char.IsHighSurrogate(text[MaxInputLength - 1]) ? MaxInputLength - 1 : MaxInputLength;
            text = text.Substring(0, cut);
        }

        // Check cache first
        var cacheKey = GenerateCacheKey(text);
        lock (_cacheLock)
        {
            if (_embeddingCache.TryGetValue(cacheKey, out var cachedEmbedding))
            {
                _cacheHits++;
                return cachedEmbedding;
            }

            _cacheMisses++;
        }

        try
        {
            var requestBody = new
            {
                input = text,
                model = _configuration.Search.OpenAIEmbeddingModel
            };
            var jsonContent = JsonSerializer.Serialize(requestBody);

            using var requestContent = new StringContent(jsonContent, Encoding.UTF8, "application/json");
            using var response = await _httpClient.PostAsync(EmbeddingsEndpoint, requestContent, cancellationToken);

            if (!response.IsSuccessStatusCode)
            {
                var errorContent = await response.Content.ReadAsStringAsync(cancellationToken);
                _logger.LogError("OpenAI API error: {StatusCode} - {Error}", response.StatusCode, errorContent);
                return null;
            }

            var responseJson = await response.Content.ReadAsStringAsync(cancellationToken);
            var embeddingResponse = JsonSerializer.Deserialize<OpenAIEmbeddingResponse>(responseJson, ResponseJsonOptions);

            var data = embeddingResponse?.Data;
            var embedding = data is { Length: > 0 } ? data[0]?.Embedding : null;
            if (embedding is not { Length: > 0 })
            {
                _logger.LogWarning("OpenAI API response did not contain an embedding");
                return null;
            }

            StoreInCache(cacheKey, embedding);
            return embedding;
        }
        catch (HttpRequestException ex)
        {
            _logger.LogError(ex, "Network error while getting embedding");
            return null;
        }
        catch (JsonException ex)
        {
            _logger.LogError(ex, "JSON parsing error while processing embedding response");
            return null;
        }
    }

    /// <summary>
    /// Adds an embedding to the cache, evicting the oldest entries once the size limit is reached.
    /// </summary>
    private void StoreInCache(string cacheKey, float[] embedding)
    {
        lock (_cacheLock)
        {
            // A concurrent caller may have stored the same key already; then only the value is refreshed.
            if (!_embeddingCache.ContainsKey(cacheKey))
            {
                while (_embeddingCache.Count >= MaxCacheEntries && _cacheInsertionOrder.Count > 0)
                {
                    _embeddingCache.Remove(_cacheInsertionOrder.Dequeue());
                }

                _cacheInsertionOrder.Enqueue(cacheKey);
            }

            _embeddingCache[cacheKey] = embedding;
        }
    }

    /// <summary>
    /// Calculates cosine similarity between two embedding vectors.
    /// </summary>
    /// <returns>
    /// The cosine similarity mapped from -1..1 onto 0..1, or 0.0 when the vectors differ
    /// in length or either has zero magnitude.
    /// </returns>
    private double CalculateCosineSimilarity(float[] vectorA, float[] vectorB)
    {
        if (vectorA.Length != vectorB.Length)
        {
            _logger.LogWarning("Vector length mismatch: {LengthA} vs {LengthB}", vectorA.Length, vectorB.Length);
            return 0.0;
        }

        double dotProduct = 0.0;
        double magnitudeA = 0.0;
        double magnitudeB = 0.0;
        for (int i = 0; i < vectorA.Length; i++)
        {
            // Widen before multiplying so the products are not rounded to single precision first.
            double a = vectorA[i];
            double b = vectorB[i];
            dotProduct += a * b;
            magnitudeA += a * a;
            magnitudeB += b * b;
        }

        magnitudeA = Math.Sqrt(magnitudeA);
        magnitudeB = Math.Sqrt(magnitudeB);
        if (magnitudeA == 0.0 || magnitudeB == 0.0)
        {
            return 0.0;
        }

        // Rounding can push the quotient marginally outside -1..1; clamp so the result stays in 0..1.
        var similarity = Math.Clamp(dotProduct / (magnitudeA * magnitudeB), -1.0, 1.0);

        // Normalize to 0-1 range (cosine similarity can be -1 to 1)
        return (similarity + 1.0) / 2.0;
    }

    /// <summary>
    /// Generates a cache key for text content.
    /// </summary>
    /// <remarks>
    /// Uses the SHA-256 digest of the UTF-8 text (upper-case hex) rather than the 32-bit
    /// <see cref="string.GetHashCode()"/>, whose collisions would hand back the embedding
    /// of a different text.
    /// </remarks>
    private static string GenerateCacheKey(string text)
    {
        var digest = System.Security.Cryptography.SHA256.HashData(Encoding.UTF8.GetBytes(text));
        return Convert.ToHexString(digest);
    }

    /// <summary>
    /// Configures the HTTP client for OpenAI API calls.
    /// </summary>
    private void ConfigureHttpClient()
    {
        if (!string.IsNullOrEmpty(_configuration.Search.OpenAIApiKey))
        {
            _httpClient.DefaultRequestHeaders.Authorization =
                new AuthenticationHeaderValue("Bearer", _configuration.Search.OpenAIApiKey);
        }

        // The client is injected and may already carry this product token; do not append it twice.
        var product = new ProductInfoHeaderValue(UserAgentProduct, UserAgentVersion);
        var userAgent = _httpClient.DefaultRequestHeaders.UserAgent;
        if (!userAgent.Contains(product))
        {
            userAgent.Add(product);
        }

        try
        {
            _httpClient.Timeout = RequestTimeout;
        }
        catch (InvalidOperationException ex)
        {
            // HttpClient rejects property changes once it has sent a request; keep its current timeout.
            _logger.LogWarning(ex, "Could not set HttpClient timeout; keeping existing timeout of {Timeout}", _httpClient.Timeout);
        }
    }

    /// <summary>
    /// Gets embedding cache statistics.
    /// </summary>
    /// <returns>
    /// The current entry count, the share of cache lookups answered from the cache since
    /// this instance was created (0.0 before the first lookup), and whether an API key is configured.
    /// </returns>
    public EmbeddingCacheStats GetCacheStats()
    {
        int cachedCount;
        long hits;
        long misses;
        lock (_cacheLock)
        {
            cachedCount = _embeddingCache.Count;
            hits = _cacheHits;
            misses = _cacheMisses;
        }

        var lookups = hits + misses;
        return new EmbeddingCacheStats
        {
            CachedEmbeddings = cachedCount,
            CacheHitRatio = lookups == 0 ? 0.0 : (double)hits / lookups,
            IsEnabled = !string.IsNullOrEmpty(_configuration.Search.OpenAIApiKey)
        };
    }

    /// <summary>
    /// Clears the embedding cache. The hit/miss counters are left untouched.
    /// </summary>
    public void ClearCache()
    {
        lock (_cacheLock)
        {
            _embeddingCache.Clear();
            _cacheInsertionOrder.Clear();
        }

        _logger.LogInformation("Embedding cache cleared");
    }
}
/// <summary>
/// Response structure for OpenAI embeddings API.
/// </summary>
/// <remarks>
/// The property is mapped explicitly to its lower-case JSON name so that the payload binds
/// even with the default, case-sensitive <c>System.Text.Json</c> options.
/// </remarks>
public class OpenAIEmbeddingResponse
{
    /// <summary>
    /// Entries of the response's <c>data</c> array; empty when the payload has none.
    /// </summary>
    [System.Text.Json.Serialization.JsonPropertyName("data")]
    public OpenAIEmbeddingData[] Data { get; set; } = Array.Empty<OpenAIEmbeddingData>();
}

/// <summary>
/// A single entry of the <c>data</c> array in an embeddings response.
/// </summary>
public class OpenAIEmbeddingData
{
    /// <summary>
    /// The embedding vector read from the entry's <c>embedding</c> property; empty when absent.
    /// </summary>
    [System.Text.Json.Serialization.JsonPropertyName("embedding")]
    public float[] Embedding { get; set; } = Array.Empty<float>();
}
/// <summary>
/// Statistics for embedding cache performance.
/// </summary>
public class EmbeddingCacheStats
{
    /// <summary>
    /// Number of embeddings held in the cache when the statistics were captured.
    /// </summary>
    public int CachedEmbeddings { get; set; }

    /// <summary>
    /// Cache hit ratio as reported by the component that produced the statistics.
    /// </summary>
    public double CacheHitRatio { get; set; }

    /// <summary>
    /// Whether the embedding feature was enabled when the statistics were captured.
    /// </summary>
    public bool IsEnabled { get; set; }
}
}