diff --git a/packages/memory-service/src/graph/index.js b/packages/memory-service/src/graph/index.js index 32e05f4..f2bce64 100644 --- a/packages/memory-service/src/graph/index.js +++ b/packages/memory-service/src/graph/index.js @@ -64,4 +64,14 @@ function getEntityNeighbors(entityIds) { return { nodes, edges }; } -module.exports = { getNeighborhood, getEntityNeighbors }; +// Returns episode IDs linked to any of the given entity IDs via entity_episodes +function getEpisodeIdsByEntities(entityIds) { + if (!entityIds.length) return []; + const db = getDB(); + const ph = entityIds.map(() => '?').join(','); + return db.prepare( + `SELECT DISTINCT episode_id FROM entity_episodes WHERE entity_id IN (${ph})` + ).all(...entityIds).map(r => r.episode_id); +} + +module.exports = { getNeighborhood, getEntityNeighbors, getEpisodeIdsByEntities }; diff --git a/packages/memory-service/src/index.js b/packages/memory-service/src/index.js index 9b6f054..f572900 100644 --- a/packages/memory-service/src/index.js +++ b/packages/memory-service/src/index.js @@ -251,6 +251,14 @@ app.post('/graph/neighbors', (req, res) => { res.json(graph.getEntityNeighbors(entityIds.map(Number))); }); +app.post('/episodes/by-entities', (req, res) => { + const { entityIds } = req.body; + if (!Array.isArray(entityIds) || entityIds.length === 0) { + return res.status(400).json({ error: 'entityIds array is required' }); + } + res.json({ episodeIds: graph.getEpisodeIdsByEntities(entityIds.map(Number)) }); +}); + /*********************************** */ /********** Project Routes ********** */ /*********************************** */ diff --git a/packages/orchestration-service/src/chat/index.js b/packages/orchestration-service/src/chat/index.js index 145acef..d3096de 100644 --- a/packages/orchestration-service/src/chat/index.js +++ b/packages/orchestration-service/src/chat/index.js @@ -7,7 +7,7 @@ const appSettings = require("../config/settings"); const {triggerSummary} = require('../services/summarization') const graph = require('../services/graph'); -function buildPrompt(recentEpisodes, semanticEpisodes, neighborhood, userMessage, systemPrompt) { +function buildPrompt(guaranteed, selected, neighborhood, userMessage, systemPrompt) { const parts = [systemPrompt ?? ORCHESTRATION.SYSTEM_PROMPT]; const graphText = formatGraphContext(neighborhood.nodes ?? [], neighborhood.edges ?? []); @@ -17,17 +17,17 @@ function buildPrompt(recentEpisodes, semanticEpisodes, neighborhood, userMessage parts.push("---"); } - if (semanticEpisodes.length > 0) { - parts.push("Long-term memory (semantically relevant to this message):"); - for (const ep of semanticEpisodes) { + if (selected.length > 0) { + parts.push("Relevant memories from earlier conversations:"); + for (const ep of selected) { parts.push(`User: ${ep.user_message}\nAssistant: ${ep.ai_response}`); } parts.push("---"); } - if (recentEpisodes.length > 0) { + if (guaranteed.length > 0) { parts.push("Recent conversation history (most recent exchanges):"); - for (const ep of recentEpisodes) { + for (const ep of guaranteed) { parts.push(`User: ${ep.user_message}\nAssistant: ${ep.ai_response}`); } parts.push("--- End of recent memories ---\n"); @@ -152,6 +152,7 @@ async function getFTSResults(userMessage, { limit, sessionIds }) { } } +// Returns {episode, score}[] — scores needed for buildScoredPool downstream function fuseEpisodeResults(semanticEps, keywordEps, { semanticWeight, keywordWeight, limit }) { const k = RETRIEVAL.RRF_K; const scores = new Map(); @@ -164,22 +165,80 @@ function fuseEpisodeResults(semanticEps, keywordEps, { semanticWeight, keywordWe const contrib = keywordWeight / (k + i + 1); if (scores.has(ep.id)) { scores.get(ep.id).score += contrib; - } else { + } else if (contrib > 0) { scores.set(ep.id, { episode: ep, score: contrib }); } }); return [...scores.values()] .sort((a, b) => b.score - a.score) - .slice(0, limit) - .map(({ episode }) => episode); + .slice(0, limit); + } +function estimateTokens(episode) { + return episode.token_count + ?? Math.ceil((episode.user_message.length + episode.ai_response.length) / 4); +} + +function buildScoredPool(fusedWithScores, recentEpisodes, entityBoostedIds, { entityWeight }) { + const k = RETRIEVAL.RRF_K; + const pool = new Map(); // episode.id → {episode, score} + + for (const { episode, score } of fusedWithScores) { + pool.set(episode.id, { episode, score }); + } + + recentEpisodes.forEach((ep, i) => { + const recencyScore = 1.0 / (k + i + 1); + if (pool.has(ep.id)) { + pool.get(ep.id).score += recencyScore; + } else { + pool.set(ep.id, { episode: ep, score: recencyScore }); + } + }); + + for (const id of entityBoostedIds) { + if (pool.has(id)) pool.get(id).score += entityWeight; + } + + return [...pool.values()].sort((a, b) => b.score - a.score); +} + +function selectWithinBudget(scoredPool, contextBudget, minRecentEpisodes, recentEpisodes) { + let budget = contextBudget; + const sortByTime = (a, b) => a.created_at - b.created_at; + + // Guarantee floor: always include the N most recent episodes + const guaranteed = recentEpisodes.slice(0, minRecentEpisodes); + const guaranteedIds = new Set(guaranteed.map(ep => ep.id)); + for (const ep of guaranteed) budget -= estimateTokens(ep); + + // Fill remaining budget from scored pool, highest-priority first + const selected = []; + for (const { episode } of scoredPool) { + if (guaranteedIds.has(episode.id)) continue; + const cost = estimateTokens(episode); + + // // Break rather than skip — lower-priority episodes aren't worth fitting over higher-priority ones + if (budget - cost < 0) break; + selected.push(episode); + budget -= cost; + } + + return { + guaranteed: [...guaranteed].sort(sortByTime), + selected: selected.sort(sortByTime), + }; +} + + async function getFusedEpisodes(userMessage, session, recentIds, projectSessionIds, settings) { const { semanticLimit, scoreThreshold, semanticWeight, keywordWeight } = settings; const ftsSessionIds = projectSessionIds ?? [session.id]; const ftsPromise = keywordWeight > 0 + // FTS and semantic may have significant overlap, so fetching more from FTS gives the fusion step more to work with before deduplication. ? getFTSResults(userMessage, { limit: semanticLimit * 2, sessionIds: ftsSessionIds }) : Promise.resolve([]); @@ -194,8 +253,10 @@ async function getFusedEpisodes(userMessage, session, recentIds, projectSessionI async function assembleContext(externalId, userMessage) { const settings = appSettings.load(); - const { recentEpisodeLimit, semanticLimit, scoreThreshold, - temperature, repeatPenalty, topP, topK, systemPrompt, semanticWeight, keywordWeight } = settings; + const { recentEpisodeLimit, semanticLimit, scoreThreshold, + temperature, repeatPenalty, topP, topK, systemPrompt, + semanticWeight, keywordWeight, + contextBudget, entityWeight, minRecentEpisodes } = settings; // 1. Resolve or create session let session = await memory.getSessionByExternalId(externalId); @@ -222,24 +283,41 @@ async function assembleContext(externalId, userMessage) { const isFirstMessage = recentEpisodes.length === 0; const recentIds = new Set(recentEpisodes.map(e => e.id)); - // 4. Semantic + entity search - const fusedEpisodes = await getFusedEpisodes(userMessage, session, recentIds, projectSessionIds, { semanticLimit, scoreThreshold, semanticWeight, keywordWeight }); - const entityResults = await getRelevantEntities(userMessage, session.project_id ?? null); + // 4. Fused retrieval + entity search in parallel (both are independent) + const [fusedWithScores, entityResults] = await Promise.all([ + getFusedEpisodes(userMessage, session, recentIds, projectSessionIds, { semanticLimit, scoreThreshold, semanticWeight, keywordWeight }), + getRelevantEntities(userMessage, session.project_id ?? null), + ]); - // 5. Expand matched entities into 1-hop graph neighborhood - let neighborhood = { nodes: [], edges: [] }; - if (entityResults.length > 0) { + // 5. Entity-linked episode IDs for scoring bonus + const entityIds = entityResults.map(e => e.id); + let entityBoostedIds = new Set(); + if (entityIds.length > 0) { try { - neighborhood = await graph.getNeighbors(entityResults.map(e => e.id)); + const result = await memory.getEpisodesByEntities(entityIds); + entityBoostedIds = new Set(result.episodeIds); + } catch (err) { + logger.debug('[orchestration] Entity-episode lookup failed, skipping bonus:', err.message); + } + } + + // 6. Build unified scored pool and select within token budget + const scoredPool = buildScoredPool(fusedWithScores, recentEpisodes, entityBoostedIds, { entityWeight }); + const { guaranteed, selected } = selectWithinBudget(scoredPool, contextBudget, minRecentEpisodes, recentEpisodes); + + // 7. Graph neighborhood expansion + let neighborhood = { nodes: [], edges: [] }; + if (entityIds.length > 0) { + try { + neighborhood = await graph.getNeighbors(entityIds); } catch (err) { logger.warn('[orchestration] Graph neighborhood fetch failed, falling back to flat entities:', err.message); - // Graceful fallback: use Qdrant payload data as flat nodes, no edges neighborhood = { nodes: entityResults, edges: [] }; } } - // 6. Assemble prompt - const prompt = buildPrompt(recentEpisodes, fusedEpisodes, neighborhood, userMessage, activeSystemPrompt); + // 8. Assemble prompt + const prompt = buildPrompt(guaranteed, selected, neighborhood, userMessage, activeSystemPrompt); return { session, diff --git a/packages/orchestration-service/src/config/settings.js b/packages/orchestration-service/src/config/settings.js index 157e7ea..5beb18f 100644 --- a/packages/orchestration-service/src/config/settings.js +++ b/packages/orchestration-service/src/config/settings.js @@ -16,6 +16,9 @@ const DEFAULTS = { systemPrompt: ORCHESTRATION.SYSTEM_PROMPT, semanticWeight: RETRIEVAL.SEMANTIC_WEIGHT, keywordWeight: RETRIEVAL.KEYWORD_WEIGHT, + contextBudget: ORCHESTRATION.CONTEXT_BUDGET, + entityWeight: ORCHESTRATION.ENTITY_WEIGHT, + minRecentEpisodes: ORCHESTRATION.MIN_RECENT_EPISODES, }; function load() { diff --git a/packages/orchestration-service/src/routes/settings.js b/packages/orchestration-service/src/routes/settings.js index 773f9c4..2bd6a7c 100644 --- a/packages/orchestration-service/src/routes/settings.js +++ b/packages/orchestration-service/src/routes/settings.js @@ -94,6 +94,27 @@ if (req.body.systemPrompt !== undefined) { updates.keywordWeight = val; } + if (req.body.contextBudget !== undefined) { + const val = Number(req.body.contextBudget); + if (!Number.isInteger(val) || val < 512 || val > 32768) + return res.status(400).json({ error: 'contextBudget must be 512–32768' }); + updates.contextBudget = val; + } + + if (req.body.entityWeight !== undefined) { + const val = Number(req.body.entityWeight); + if (isNaN(val) || val < 0 || val > 2) + return res.status(400).json({ error: 'entityWeight must be 0–2' }); + updates.entityWeight = val; + } + + if (req.body.minRecentEpisodes !== undefined) { + const val = Number(req.body.minRecentEpisodes); + if (!Number.isInteger(val) || val < 0 || val > 10) + return res.status(400).json({ error: 'minRecentEpisodes must be 0–10' }); + updates.minRecentEpisodes = val; + } + res.json(settings.save(updates)); }); diff --git a/packages/orchestration-service/src/services/memory.js b/packages/orchestration-service/src/services/memory.js index 5e48f18..e6317d3 100644 --- a/packages/orchestration-service/src/services/memory.js +++ b/packages/orchestration-service/src/services/memory.js @@ -206,6 +206,16 @@ async function searchEpisodes(query, { limit = 10, sessionIds = null } = {}) { return res.json(); } +async function getEpisodesByEntities(entityIds) { + const res = await fetch(`${BASE_URL}/episodes/by-entities`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ entityIds }), + }); + if (!res.ok) throw new Error(`Episodes-by-entities error: ${res.status}`); + return res.json(); // { episodeIds: [...] } +} + module.exports = { getSessionByExternalId, createSession, @@ -231,4 +241,5 @@ module.exports = { generateProjectSummary, getProjectOverviewSummary, searchEpisodes, + getEpisodesByEntities, } \ No newline at end of file diff --git a/packages/shared/src/config/constants.js b/packages/shared/src/config/constants.js index b71112d..3a8f9c0 100644 --- a/packages/shared/src/config/constants.js +++ b/packages/shared/src/config/constants.js @@ -28,6 +28,9 @@ const ORCHESTRATION = { ENTITIES_LIMIT: 5, ENTITIES_THRESHOLD: 0.55, TEMPERATURE: 0.7, + CONTEXT_BUDGET: 4096, + ENTITY_WEIGHT: 0.5, + MIN_RECENT_EPISODES: 2, CORS_ORIGIN: 'http://localhost:5173', SYSTEM_PROMPT: `You are a helpful, context-aware AI assistant. You have access to memories of past conversations with the user. Use them to provide consistent, personalised responses.` }