smarter context assembly implementation

This commit is contained in:
Storme-bit
2026-04-27 21:41:32 -07:00
parent b58a4e4692
commit e4908193bd
7 changed files with 156 additions and 22 deletions

View File

@@ -64,4 +64,14 @@ function getEntityNeighbors(entityIds) {
return { nodes, edges }; return { nodes, edges };
} }
module.exports = { getNeighborhood, getEntityNeighbors }; // Returns episode IDs linked to any of the given entity IDs via entity_episodes
/**
 * Look up the episodes linked to a set of entities.
 *
 * Queries the `entity_episodes` table and returns every matching
 * `episode_id` once, regardless of how many of the given entities point at it.
 *
 * @param {Array} entityIds - Values matched against `entity_episodes.entity_id`.
 * @returns {Array} Distinct episode IDs; an empty array when no IDs are given.
 */
function getEpisodeIdsByEntities(entityIds) {
  // A missing or non-array argument is treated like an empty list instead of
  // throwing on `.length`.
  if (!Array.isArray(entityIds) || entityIds.length === 0) return [];

  // Repeated IDs cannot change a DISTINCT result, so drop them to keep the
  // number of bound parameters as small as possible.
  const uniqueIds = [...new Set(entityIds)];
  const placeholders = uniqueIds.map(() => '?').join(',');

  const rows = getDB()
    .prepare(`SELECT DISTINCT episode_id FROM entity_episodes WHERE entity_id IN (${placeholders})`)
    .all(...uniqueIds);

  return rows.map((row) => row.episode_id);
}
module.exports = { getNeighborhood, getEntityNeighbors, getEpisodeIdsByEntities };

View File

@@ -251,6 +251,14 @@ app.post('/graph/neighbors', (req, res) => {
res.json(graph.getEntityNeighbors(entityIds.map(Number))); res.json(graph.getEntityNeighbors(entityIds.map(Number)));
}); });
// POST /episodes/by-entities
// Body: { entityIds: [...] } — a non-empty array of entity IDs.
// Responds 200 with { episodeIds: [...] } from graph.getEpisodeIdsByEntities,
// or 400 with { error } when the body is missing, empty, or holds a bad ID.
app.post('/episodes/by-entities', (req, res) => {
  const { entityIds } = req.body;
  if (!Array.isArray(entityIds) || entityIds.length === 0) {
    return res.status(400).json({ error: 'entityIds array is required' });
  }

  // Number() maps junk such as "abc" to NaN and "1.5" to a fraction; reject
  // those up front instead of handing them to the query as bound parameters.
  const numericIds = entityIds.map(Number);
  if (!numericIds.every((id) => Number.isInteger(id))) {
    return res.status(400).json({ error: 'entityIds must contain only integer IDs' });
  }

  res.json({ episodeIds: graph.getEpisodeIdsByEntities(numericIds) });
});
/*********************************** */ /*********************************** */
/********** Project Routes ********** */ /********** Project Routes ********** */
/*********************************** */ /*********************************** */

View File

@@ -7,7 +7,7 @@ const appSettings = require("../config/settings");
const {triggerSummary} = require('../services/summarization') const {triggerSummary} = require('../services/summarization')
const graph = require('../services/graph'); const graph = require('../services/graph');
function buildPrompt(recentEpisodes, semanticEpisodes, neighborhood, userMessage, systemPrompt) { function buildPrompt(guaranteed, selected, neighborhood, userMessage, systemPrompt) {
const parts = [systemPrompt ?? ORCHESTRATION.SYSTEM_PROMPT]; const parts = [systemPrompt ?? ORCHESTRATION.SYSTEM_PROMPT];
const graphText = formatGraphContext(neighborhood.nodes ?? [], neighborhood.edges ?? []); const graphText = formatGraphContext(neighborhood.nodes ?? [], neighborhood.edges ?? []);
@@ -17,17 +17,17 @@ function buildPrompt(recentEpisodes, semanticEpisodes, neighborhood, userMessage
parts.push("---"); parts.push("---");
} }
if (semanticEpisodes.length > 0) { if (selected.length > 0) {
parts.push("Long-term memory (semantically relevant to this message):"); parts.push("Relevant memories from earlier conversations:");
for (const ep of semanticEpisodes) { for (const ep of selected) {
parts.push(`User: ${ep.user_message}\nAssistant: ${ep.ai_response}`); parts.push(`User: ${ep.user_message}\nAssistant: ${ep.ai_response}`);
} }
parts.push("---"); parts.push("---");
} }
if (recentEpisodes.length > 0) { if (guaranteed.length > 0) {
parts.push("Recent conversation history (most recent exchanges):"); parts.push("Recent conversation history (most recent exchanges):");
for (const ep of recentEpisodes) { for (const ep of guaranteed) {
parts.push(`User: ${ep.user_message}\nAssistant: ${ep.ai_response}`); parts.push(`User: ${ep.user_message}\nAssistant: ${ep.ai_response}`);
} }
parts.push("--- End of recent memories ---\n"); parts.push("--- End of recent memories ---\n");
@@ -152,6 +152,7 @@ async function getFTSResults(userMessage, { limit, sessionIds }) {
} }
} }
// Returns {episode, score}[] — scores needed for buildScoredPool downstream
function fuseEpisodeResults(semanticEps, keywordEps, { semanticWeight, keywordWeight, limit }) { function fuseEpisodeResults(semanticEps, keywordEps, { semanticWeight, keywordWeight, limit }) {
const k = RETRIEVAL.RRF_K; const k = RETRIEVAL.RRF_K;
const scores = new Map(); const scores = new Map();
@@ -164,22 +165,80 @@ function fuseEpisodeResults(semanticEps, keywordEps, { semanticWeight, keywordWe
const contrib = keywordWeight / (k + i + 1); const contrib = keywordWeight / (k + i + 1);
if (scores.has(ep.id)) { if (scores.has(ep.id)) {
scores.get(ep.id).score += contrib; scores.get(ep.id).score += contrib;
} else { } else if (contrib > 0) {
scores.set(ep.id, { episode: ep, score: contrib }); scores.set(ep.id, { episode: ep, score: contrib });
} }
}); });
return [...scores.values()] return [...scores.values()]
.sort((a, b) => b.score - a.score) .sort((a, b) => b.score - a.score)
.slice(0, limit) .slice(0, limit);
.map(({ episode }) => episode);
} }
/**
 * Estimate how many tokens an episode costs when added to the prompt.
 *
 * A stored `token_count` (anything other than null/undefined) is returned
 * as-is. Otherwise the cost is approximated as one token per four characters
 * of the combined user message and AI response, rounded up.
 *
 * @param {object} episode - Episode with an optional `token_count`, plus
 *   `user_message` and `ai_response` text.
 * @returns {number} Stored or approximated token count.
 */
function estimateTokens(episode) {
  if (episode.token_count != null) return episode.token_count;

  // A missing message counts as zero characters rather than raising a
  // TypeError in the middle of context assembly.
  const userLength = episode.user_message?.length ?? 0;
  const responseLength = episode.ai_response?.length ?? 0;
  return Math.ceil((userLength + responseLength) / 4);
}
/**
 * Merge fused retrieval hits, recent episodes and entity links into a single
 * ranked candidate list.
 *
 * Scoring, applied in this order:
 *   1. every fused hit starts with the score it already carries;
 *   2. each recent episode gains 1 / (RRF_K + rank + 1), where rank is its
 *      index in `recentEpisodes`; it is added to the pool if not yet present;
 *   3. every pooled episode whose id is in `entityBoostedIds` gains
 *      `entityWeight`. Ids that are not in the pool are ignored.
 *
 * @param {Array<{episode: object, score: number}>} fusedWithScores - Scored retrieval hits.
 * @param {Array<object>} recentEpisodes - Recent episodes; index 0 gets the largest recency score.
 * @param {Iterable} entityBoostedIds - Episode ids that earn the entity bonus.
 * @param {{entityWeight: number}} options - Flat bonus added per entity-linked episode.
 * @returns {Array<{episode: object, score: number}>} Pool entries, highest score first.
 */
function buildScoredPool(fusedWithScores, recentEpisodes, entityBoostedIds, { entityWeight }) {
  const rrfK = RETRIEVAL.RRF_K;
  const entriesById = new Map();

  // Seed with fresh entry objects so the caller's fused results are never mutated.
  for (const hit of fusedWithScores) {
    entriesById.set(hit.episode.id, { episode: hit.episode, score: hit.score });
  }

  // Recency contribution, keyed on position in the recent list.
  for (let rank = 0; rank < recentEpisodes.length; rank += 1) {
    const recent = recentEpisodes[rank];
    const recencyScore = 1.0 / (rrfK + rank + 1);
    const existing = entriesById.get(recent.id);
    if (existing === undefined) {
      entriesById.set(recent.id, { episode: recent, score: recencyScore });
    } else {
      existing.score += recencyScore;
    }
  }

  // Entity bonus only lifts episodes that are already candidates.
  for (const boostedId of entityBoostedIds) {
    const entry = entriesById.get(boostedId);
    if (entry !== undefined) {
      entry.score += entityWeight;
    }
  }

  const ranked = Array.from(entriesById.values());
  ranked.sort((left, right) => right.score - left.score);
  return ranked;
}
/**
 * Choose which episodes go into the prompt without exceeding a token budget.
 *
 * The first `minRecentEpisodes` entries of `recentEpisodes` are always
 * included and charged against the budget first, even if that drives it
 * negative. The remaining budget is filled from `scoredPool` in the order
 * given, skipping guaranteed episodes and stopping at the first episode that
 * does not fit.
 *
 * @param {Array<{episode: object}>} scoredPool - Candidates in priority order.
 * @param {number} contextBudget - Total budget, in the units returned by estimateTokens.
 * @param {number} minRecentEpisodes - How many leading recent episodes are guaranteed.
 * @param {Array<object>} recentEpisodes - Recent episodes; presumably most recent
 *   first, since the guaranteed set is taken from the front — confirm against caller.
 * @returns {{guaranteed: Array<object>, selected: Array<object>}} Both lists
 *   sorted by ascending `created_at`.
 */
function selectWithinBudget(scoredPool, contextBudget, minRecentEpisodes, recentEpisodes) {
  let budget = contextBudget;

  // Numeric (or Date) timestamps are ordered by subtraction. Subtracting
  // textual timestamps yields NaN, which the sort treats as "equal" and so
  // leaves the list unordered; fall back to comparing them as strings.
  // NOTE(review): the fallback assumes a lexicographically sortable format
  // (e.g. ISO 8601 or "YYYY-MM-DD HH:MM:SS") — confirm against the schema.
  const sortByTime = (a, b) => {
    const diff = a.created_at - b.created_at;
    if (!Number.isNaN(diff)) return diff;
    const left = String(a.created_at);
    const right = String(b.created_at);
    if (left < right) return -1;
    return left > right ? 1 : 0;
  };

  // Guarantee floor: always include the N most recent episodes
  const guaranteed = recentEpisodes.slice(0, minRecentEpisodes);
  const guaranteedIds = new Set(guaranteed.map((ep) => ep.id));
  for (const ep of guaranteed) budget -= estimateTokens(ep);

  // Fill remaining budget from scored pool, highest-priority first
  const selected = [];
  for (const { episode } of scoredPool) {
    if (guaranteedIds.has(episode.id)) continue;
    const cost = estimateTokens(episode);
    // Break rather than skip — lower-priority episodes aren't worth fitting over higher-priority ones
    if (budget - cost < 0) break;
    selected.push(episode);
    budget -= cost;
  }

  return {
    // Copy before sorting so the caller's recentEpisodes order is untouched.
    guaranteed: [...guaranteed].sort(sortByTime),
    selected: selected.sort(sortByTime),
  };
}
async function getFusedEpisodes(userMessage, session, recentIds, projectSessionIds, settings) { async function getFusedEpisodes(userMessage, session, recentIds, projectSessionIds, settings) {
const { semanticLimit, scoreThreshold, semanticWeight, keywordWeight } = settings; const { semanticLimit, scoreThreshold, semanticWeight, keywordWeight } = settings;
const ftsSessionIds = projectSessionIds ?? [session.id]; const ftsSessionIds = projectSessionIds ?? [session.id];
const ftsPromise = keywordWeight > 0 const ftsPromise = keywordWeight > 0
// FTS and semantic may have significant overlap, so fetching more from FTS gives the fusion step more to work with before deduplication.
? getFTSResults(userMessage, { limit: semanticLimit * 2, sessionIds: ftsSessionIds }) ? getFTSResults(userMessage, { limit: semanticLimit * 2, sessionIds: ftsSessionIds })
: Promise.resolve([]); : Promise.resolve([]);
@@ -194,8 +253,10 @@ async function getFusedEpisodes(userMessage, session, recentIds, projectSessionI
async function assembleContext(externalId, userMessage) { async function assembleContext(externalId, userMessage) {
const settings = appSettings.load(); const settings = appSettings.load();
const { recentEpisodeLimit, semanticLimit, scoreThreshold, const { recentEpisodeLimit, semanticLimit, scoreThreshold,
temperature, repeatPenalty, topP, topK, systemPrompt, semanticWeight, keywordWeight } = settings; temperature, repeatPenalty, topP, topK, systemPrompt,
semanticWeight, keywordWeight,
contextBudget, entityWeight, minRecentEpisodes } = settings;
// 1. Resolve or create session // 1. Resolve or create session
let session = await memory.getSessionByExternalId(externalId); let session = await memory.getSessionByExternalId(externalId);
@@ -222,24 +283,41 @@ async function assembleContext(externalId, userMessage) {
const isFirstMessage = recentEpisodes.length === 0; const isFirstMessage = recentEpisodes.length === 0;
const recentIds = new Set(recentEpisodes.map(e => e.id)); const recentIds = new Set(recentEpisodes.map(e => e.id));
// 4. Semantic + entity search // 4. Fused retrieval + entity search in parallel (both are independent)
const fusedEpisodes = await getFusedEpisodes(userMessage, session, recentIds, projectSessionIds, { semanticLimit, scoreThreshold, semanticWeight, keywordWeight }); const [fusedWithScores, entityResults] = await Promise.all([
const entityResults = await getRelevantEntities(userMessage, session.project_id ?? null); getFusedEpisodes(userMessage, session, recentIds, projectSessionIds, { semanticLimit, scoreThreshold, semanticWeight, keywordWeight }),
getRelevantEntities(userMessage, session.project_id ?? null),
]);
// 5. Expand matched entities into 1-hop graph neighborhood // 5. Entity-linked episode IDs for scoring bonus
let neighborhood = { nodes: [], edges: [] }; const entityIds = entityResults.map(e => e.id);
if (entityResults.length > 0) { let entityBoostedIds = new Set();
if (entityIds.length > 0) {
try { try {
neighborhood = await graph.getNeighbors(entityResults.map(e => e.id)); const result = await memory.getEpisodesByEntities(entityIds);
entityBoostedIds = new Set(result.episodeIds);
} catch (err) {
logger.debug('[orchestration] Entity-episode lookup failed, skipping bonus:', err.message);
}
}
// 6. Build unified scored pool and select within token budget
const scoredPool = buildScoredPool(fusedWithScores, recentEpisodes, entityBoostedIds, { entityWeight });
const { guaranteed, selected } = selectWithinBudget(scoredPool, contextBudget, minRecentEpisodes, recentEpisodes);
// 7. Graph neighborhood expansion
let neighborhood = { nodes: [], edges: [] };
if (entityIds.length > 0) {
try {
neighborhood = await graph.getNeighbors(entityIds);
} catch (err) { } catch (err) {
logger.warn('[orchestration] Graph neighborhood fetch failed, falling back to flat entities:', err.message); logger.warn('[orchestration] Graph neighborhood fetch failed, falling back to flat entities:', err.message);
// Graceful fallback: use Qdrant payload data as flat nodes, no edges
neighborhood = { nodes: entityResults, edges: [] }; neighborhood = { nodes: entityResults, edges: [] };
} }
} }
// 6. Assemble prompt // 8. Assemble prompt
const prompt = buildPrompt(recentEpisodes, fusedEpisodes, neighborhood, userMessage, activeSystemPrompt); const prompt = buildPrompt(guaranteed, selected, neighborhood, userMessage, activeSystemPrompt);
return { return {
session, session,

View File

@@ -16,6 +16,9 @@ const DEFAULTS = {
systemPrompt: ORCHESTRATION.SYSTEM_PROMPT, systemPrompt: ORCHESTRATION.SYSTEM_PROMPT,
semanticWeight: RETRIEVAL.SEMANTIC_WEIGHT, semanticWeight: RETRIEVAL.SEMANTIC_WEIGHT,
keywordWeight: RETRIEVAL.KEYWORD_WEIGHT, keywordWeight: RETRIEVAL.KEYWORD_WEIGHT,
contextBudget: ORCHESTRATION.CONTEXT_BUDGET,
entityWeight: ORCHESTRATION.ENTITY_WEIGHT,
minRecentEpisodes: ORCHESTRATION.MIN_RECENT_EPISODES,
}; };
function load() { function load() {

View File

@@ -94,6 +94,27 @@ if (req.body.systemPrompt !== undefined) {
updates.keywordWeight = val; updates.keywordWeight = val;
} }
if (req.body.contextBudget !== undefined) {
const val = Number(req.body.contextBudget);
if (!Number.isInteger(val) || val < 512 || val > 32768)
return res.status(400).json({ error: 'contextBudget must be 51232768' });
updates.contextBudget = val;
}
if (req.body.entityWeight !== undefined) {
const val = Number(req.body.entityWeight);
if (isNaN(val) || val < 0 || val > 2)
return res.status(400).json({ error: 'entityWeight must be 02' });
updates.entityWeight = val;
}
if (req.body.minRecentEpisodes !== undefined) {
const val = Number(req.body.minRecentEpisodes);
if (!Number.isInteger(val) || val < 0 || val > 10)
return res.status(400).json({ error: 'minRecentEpisodes must be 010' });
updates.minRecentEpisodes = val;
}
res.json(settings.save(updates)); res.json(settings.save(updates));
}); });

View File

@@ -206,6 +206,16 @@ async function searchEpisodes(query, { limit = 10, sessionIds = null } = {}) {
return res.json(); return res.json();
} }
/**
 * Ask the memory service which episodes are linked to the given entities.
 *
 * Sends `{ entityIds }` as a JSON POST to `${BASE_URL}/episodes/by-entities`.
 *
 * @param {Array} entityIds - Entity IDs to send in the request body.
 * @returns {Promise<object>} The parsed JSON response body.
 * @throws {Error} When the service answers with a non-OK status.
 */
async function getEpisodesByEntities(entityIds) {
  const endpoint = `${BASE_URL}/episodes/by-entities`;
  const requestOptions = {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ entityIds }),
  };

  const response = await fetch(endpoint, requestOptions);
  if (response.ok) {
    return response.json(); // { episodeIds: [...] }
  }
  throw new Error(`Episodes-by-entities error: ${response.status}`);
}
module.exports = { module.exports = {
getSessionByExternalId, getSessionByExternalId,
createSession, createSession,
@@ -231,4 +241,5 @@ module.exports = {
generateProjectSummary, generateProjectSummary,
getProjectOverviewSummary, getProjectOverviewSummary,
searchEpisodes, searchEpisodes,
getEpisodesByEntities,
} }

View File

@@ -28,6 +28,9 @@ const ORCHESTRATION = {
ENTITIES_LIMIT: 5, ENTITIES_LIMIT: 5,
ENTITIES_THRESHOLD: 0.55, ENTITIES_THRESHOLD: 0.55,
TEMPERATURE: 0.7, TEMPERATURE: 0.7,
CONTEXT_BUDGET: 4096,
ENTITY_WEIGHT: 0.5,
MIN_RECENT_EPISODES: 2,
CORS_ORIGIN: 'http://localhost:5173', CORS_ORIGIN: 'http://localhost:5173',
SYSTEM_PROMPT: `You are a helpful, context-aware AI assistant. You have access to memories of past conversations with the user. Use them to provide consistent, personalised responses.` SYSTEM_PROMPT: `You are a helpful, context-aware AI assistant. You have access to memories of past conversations with the user. Use them to provide consistent, personalised responses.`
} }