smarter context assembly implementation

This commit is contained in:
Storme-bit
2026-04-27 21:41:32 -07:00
parent b58a4e4692
commit e4908193bd
7 changed files with 156 additions and 22 deletions

View File

@@ -64,4 +64,14 @@ function getEntityNeighbors(entityIds) {
return { nodes, edges };
}
module.exports = { getNeighborhood, getEntityNeighbors };
// Returns episode IDs linked to any of the given entity IDs via entity_episodes
function getEpisodeIdsByEntities(entityIds) {
if (!entityIds.length) return [];
const db = getDB();
const ph = entityIds.map(() => '?').join(',');
return db.prepare(
`SELECT DISTINCT episode_id FROM entity_episodes WHERE entity_id IN (${ph})`
).all(...entityIds).map(r => r.episode_id);
}
module.exports = { getNeighborhood, getEntityNeighbors, getEpisodeIdsByEntities };

View File

@@ -251,6 +251,14 @@ app.post('/graph/neighbors', (req, res) => {
res.json(graph.getEntityNeighbors(entityIds.map(Number)));
});
app.post('/episodes/by-entities', (req, res) => {
const { entityIds } = req.body;
if (!Array.isArray(entityIds) || entityIds.length === 0) {
return res.status(400).json({ error: 'entityIds array is required' });
}
res.json({ episodeIds: graph.getEpisodeIdsByEntities(entityIds.map(Number)) });
});
/*********************************** */
/********** Project Routes ********** */
/*********************************** */

View File

@@ -7,7 +7,7 @@ const appSettings = require("../config/settings");
const {triggerSummary} = require('../services/summarization')
const graph = require('../services/graph');
function buildPrompt(recentEpisodes, semanticEpisodes, neighborhood, userMessage, systemPrompt) {
function buildPrompt(guaranteed, selected, neighborhood, userMessage, systemPrompt) {
const parts = [systemPrompt ?? ORCHESTRATION.SYSTEM_PROMPT];
const graphText = formatGraphContext(neighborhood.nodes ?? [], neighborhood.edges ?? []);
@@ -17,17 +17,17 @@ function buildPrompt(recentEpisodes, semanticEpisodes, neighborhood, userMessage
parts.push("---");
}
if (semanticEpisodes.length > 0) {
parts.push("Long-term memory (semantically relevant to this message):");
for (const ep of semanticEpisodes) {
if (selected.length > 0) {
parts.push("Relevant memories from earlier conversations:");
for (const ep of selected) {
parts.push(`User: ${ep.user_message}\nAssistant: ${ep.ai_response}`);
}
parts.push("---");
}
if (recentEpisodes.length > 0) {
if (guaranteed.length > 0) {
parts.push("Recent conversation history (most recent exchanges):");
for (const ep of recentEpisodes) {
for (const ep of guaranteed) {
parts.push(`User: ${ep.user_message}\nAssistant: ${ep.ai_response}`);
}
parts.push("--- End of recent memories ---\n");
@@ -152,6 +152,7 @@ async function getFTSResults(userMessage, { limit, sessionIds }) {
}
}
// Returns {episode, score}[] — scores needed for buildScoredPool downstream
function fuseEpisodeResults(semanticEps, keywordEps, { semanticWeight, keywordWeight, limit }) {
const k = RETRIEVAL.RRF_K;
const scores = new Map();
@@ -164,22 +165,80 @@ function fuseEpisodeResults(semanticEps, keywordEps, { semanticWeight, keywordWe
const contrib = keywordWeight / (k + i + 1);
if (scores.has(ep.id)) {
scores.get(ep.id).score += contrib;
} else {
} else if (contrib > 0) {
scores.set(ep.id, { episode: ep, score: contrib });
}
});
return [...scores.values()]
.sort((a, b) => b.score - a.score)
.slice(0, limit)
.map(({ episode }) => episode);
.slice(0, limit);
}
function estimateTokens(episode) {
return episode.token_count
?? Math.ceil((episode.user_message.length + episode.ai_response.length) / 4);
}
function buildScoredPool(fusedWithScores, recentEpisodes, entityBoostedIds, { entityWeight }) {
const k = RETRIEVAL.RRF_K;
const pool = new Map(); // episode.id → {episode, score}
for (const { episode, score } of fusedWithScores) {
pool.set(episode.id, { episode, score });
}
recentEpisodes.forEach((ep, i) => {
const recencyScore = 1.0 / (k + i + 1);
if (pool.has(ep.id)) {
pool.get(ep.id).score += recencyScore;
} else {
pool.set(ep.id, { episode: ep, score: recencyScore });
}
});
for (const id of entityBoostedIds) {
if (pool.has(id)) pool.get(id).score += entityWeight;
}
return [...pool.values()].sort((a, b) => b.score - a.score);
}
function selectWithinBudget(scoredPool, contextBudget, minRecentEpisodes, recentEpisodes) {
let budget = contextBudget;
const sortByTime = (a, b) => a.created_at - b.created_at;
// Guarantee floor: always include the N most recent episodes
const guaranteed = recentEpisodes.slice(0, minRecentEpisodes);
const guaranteedIds = new Set(guaranteed.map(ep => ep.id));
for (const ep of guaranteed) budget -= estimateTokens(ep);
// Fill remaining budget from scored pool, highest-priority first
const selected = [];
for (const { episode } of scoredPool) {
if (guaranteedIds.has(episode.id)) continue;
const cost = estimateTokens(episode);
// // Break rather than skip — lower-priority episodes aren't worth fitting over higher-priority ones
if (budget - cost < 0) break;
selected.push(episode);
budget -= cost;
}
return {
guaranteed: [...guaranteed].sort(sortByTime),
selected: selected.sort(sortByTime),
};
}
async function getFusedEpisodes(userMessage, session, recentIds, projectSessionIds, settings) {
const { semanticLimit, scoreThreshold, semanticWeight, keywordWeight } = settings;
const ftsSessionIds = projectSessionIds ?? [session.id];
const ftsPromise = keywordWeight > 0
// FTS and semantic may have significant overlap, so fetching more from FTS gives the fusion step more to work with before deduplication.
? getFTSResults(userMessage, { limit: semanticLimit * 2, sessionIds: ftsSessionIds })
: Promise.resolve([]);
@@ -195,7 +254,9 @@ async function getFusedEpisodes(userMessage, session, recentIds, projectSessionI
async function assembleContext(externalId, userMessage) {
const settings = appSettings.load();
const { recentEpisodeLimit, semanticLimit, scoreThreshold,
temperature, repeatPenalty, topP, topK, systemPrompt, semanticWeight, keywordWeight } = settings;
temperature, repeatPenalty, topP, topK, systemPrompt,
semanticWeight, keywordWeight,
contextBudget, entityWeight, minRecentEpisodes } = settings;
// 1. Resolve or create session
let session = await memory.getSessionByExternalId(externalId);
@@ -222,24 +283,41 @@ async function assembleContext(externalId, userMessage) {
const isFirstMessage = recentEpisodes.length === 0;
const recentIds = new Set(recentEpisodes.map(e => e.id));
// 4. Semantic + entity search
const fusedEpisodes = await getFusedEpisodes(userMessage, session, recentIds, projectSessionIds, { semanticLimit, scoreThreshold, semanticWeight, keywordWeight });
const entityResults = await getRelevantEntities(userMessage, session.project_id ?? null);
// 4. Fused retrieval + entity search in parallel (both are independent)
const [fusedWithScores, entityResults] = await Promise.all([
getFusedEpisodes(userMessage, session, recentIds, projectSessionIds, { semanticLimit, scoreThreshold, semanticWeight, keywordWeight }),
getRelevantEntities(userMessage, session.project_id ?? null),
]);
// 5. Expand matched entities into 1-hop graph neighborhood
let neighborhood = { nodes: [], edges: [] };
if (entityResults.length > 0) {
// 5. Entity-linked episode IDs for scoring bonus
const entityIds = entityResults.map(e => e.id);
let entityBoostedIds = new Set();
if (entityIds.length > 0) {
try {
neighborhood = await graph.getNeighbors(entityResults.map(e => e.id));
const result = await memory.getEpisodesByEntities(entityIds);
entityBoostedIds = new Set(result.episodeIds);
} catch (err) {
logger.debug('[orchestration] Entity-episode lookup failed, skipping bonus:', err.message);
}
}
// 6. Build unified scored pool and select within token budget
const scoredPool = buildScoredPool(fusedWithScores, recentEpisodes, entityBoostedIds, { entityWeight });
const { guaranteed, selected } = selectWithinBudget(scoredPool, contextBudget, minRecentEpisodes, recentEpisodes);
// 7. Graph neighborhood expansion
let neighborhood = { nodes: [], edges: [] };
if (entityIds.length > 0) {
try {
neighborhood = await graph.getNeighbors(entityIds);
} catch (err) {
logger.warn('[orchestration] Graph neighborhood fetch failed, falling back to flat entities:', err.message);
// Graceful fallback: use Qdrant payload data as flat nodes, no edges
neighborhood = { nodes: entityResults, edges: [] };
}
}
// 6. Assemble prompt
const prompt = buildPrompt(recentEpisodes, fusedEpisodes, neighborhood, userMessage, activeSystemPrompt);
// 8. Assemble prompt
const prompt = buildPrompt(guaranteed, selected, neighborhood, userMessage, activeSystemPrompt);
return {
session,

View File

@@ -16,6 +16,9 @@ const DEFAULTS = {
systemPrompt: ORCHESTRATION.SYSTEM_PROMPT,
semanticWeight: RETRIEVAL.SEMANTIC_WEIGHT,
keywordWeight: RETRIEVAL.KEYWORD_WEIGHT,
contextBudget: ORCHESTRATION.CONTEXT_BUDGET,
entityWeight: ORCHESTRATION.ENTITY_WEIGHT,
minRecentEpisodes: ORCHESTRATION.MIN_RECENT_EPISODES,
};
function load() {

View File

@@ -94,6 +94,27 @@ if (req.body.systemPrompt !== undefined) {
updates.keywordWeight = val;
}
if (req.body.contextBudget !== undefined) {
const val = Number(req.body.contextBudget);
if (!Number.isInteger(val) || val < 512 || val > 32768)
return res.status(400).json({ error: 'contextBudget must be 51232768' });
updates.contextBudget = val;
}
if (req.body.entityWeight !== undefined) {
const val = Number(req.body.entityWeight);
if (isNaN(val) || val < 0 || val > 2)
return res.status(400).json({ error: 'entityWeight must be 02' });
updates.entityWeight = val;
}
if (req.body.minRecentEpisodes !== undefined) {
const val = Number(req.body.minRecentEpisodes);
if (!Number.isInteger(val) || val < 0 || val > 10)
return res.status(400).json({ error: 'minRecentEpisodes must be 010' });
updates.minRecentEpisodes = val;
}
res.json(settings.save(updates));
});

View File

@@ -206,6 +206,16 @@ async function searchEpisodes(query, { limit = 10, sessionIds = null } = {}) {
return res.json();
}
async function getEpisodesByEntities(entityIds) {
const res = await fetch(`${BASE_URL}/episodes/by-entities`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ entityIds }),
});
if (!res.ok) throw new Error(`Episodes-by-entities error: ${res.status}`);
return res.json(); // { episodeIds: [...] }
}
module.exports = {
getSessionByExternalId,
createSession,
@@ -231,4 +241,5 @@ module.exports = {
generateProjectSummary,
getProjectOverviewSummary,
searchEpisodes,
getEpisodesByEntities,
}

View File

@@ -28,6 +28,9 @@ const ORCHESTRATION = {
ENTITIES_LIMIT: 5,
ENTITIES_THRESHOLD: 0.55,
TEMPERATURE: 0.7,
CONTEXT_BUDGET: 4096,
ENTITY_WEIGHT: 0.5,
MIN_RECENT_EPISODES: 2,
CORS_ORIGIN: 'http://localhost:5173',
SYSTEM_PROMPT: `You are a helpful, context-aware AI assistant. You have access to memories of past conversations with the user. Use them to provide consistent, personalised responses.`
}