smarter context assembly implementation
This commit is contained in:
@@ -64,4 +64,14 @@ function getEntityNeighbors(entityIds) {
|
||||
return { nodes, edges };
|
||||
}
|
||||
|
||||
module.exports = { getNeighborhood, getEntityNeighbors };
|
||||
// Returns episode IDs linked to any of the given entity IDs via entity_episodes
|
||||
function getEpisodeIdsByEntities(entityIds) {
|
||||
if (!entityIds.length) return [];
|
||||
const db = getDB();
|
||||
const ph = entityIds.map(() => '?').join(',');
|
||||
return db.prepare(
|
||||
`SELECT DISTINCT episode_id FROM entity_episodes WHERE entity_id IN (${ph})`
|
||||
).all(...entityIds).map(r => r.episode_id);
|
||||
}
|
||||
|
||||
module.exports = { getNeighborhood, getEntityNeighbors, getEpisodeIdsByEntities };
|
||||
|
||||
@@ -251,6 +251,14 @@ app.post('/graph/neighbors', (req, res) => {
|
||||
res.json(graph.getEntityNeighbors(entityIds.map(Number)));
|
||||
});
|
||||
|
||||
app.post('/episodes/by-entities', (req, res) => {
|
||||
const { entityIds } = req.body;
|
||||
if (!Array.isArray(entityIds) || entityIds.length === 0) {
|
||||
return res.status(400).json({ error: 'entityIds array is required' });
|
||||
}
|
||||
res.json({ episodeIds: graph.getEpisodeIdsByEntities(entityIds.map(Number)) });
|
||||
});
|
||||
|
||||
/*********************************** */
|
||||
/********** Project Routes ********** */
|
||||
/*********************************** */
|
||||
|
||||
@@ -7,7 +7,7 @@ const appSettings = require("../config/settings");
|
||||
const {triggerSummary} = require('../services/summarization')
|
||||
const graph = require('../services/graph');
|
||||
|
||||
function buildPrompt(recentEpisodes, semanticEpisodes, neighborhood, userMessage, systemPrompt) {
|
||||
function buildPrompt(guaranteed, selected, neighborhood, userMessage, systemPrompt) {
|
||||
const parts = [systemPrompt ?? ORCHESTRATION.SYSTEM_PROMPT];
|
||||
|
||||
const graphText = formatGraphContext(neighborhood.nodes ?? [], neighborhood.edges ?? []);
|
||||
@@ -17,17 +17,17 @@ function buildPrompt(recentEpisodes, semanticEpisodes, neighborhood, userMessage
|
||||
parts.push("---");
|
||||
}
|
||||
|
||||
if (semanticEpisodes.length > 0) {
|
||||
parts.push("Long-term memory (semantically relevant to this message):");
|
||||
for (const ep of semanticEpisodes) {
|
||||
if (selected.length > 0) {
|
||||
parts.push("Relevant memories from earlier conversations:");
|
||||
for (const ep of selected) {
|
||||
parts.push(`User: ${ep.user_message}\nAssistant: ${ep.ai_response}`);
|
||||
}
|
||||
parts.push("---");
|
||||
}
|
||||
|
||||
if (recentEpisodes.length > 0) {
|
||||
if (guaranteed.length > 0) {
|
||||
parts.push("Recent conversation history (most recent exchanges):");
|
||||
for (const ep of recentEpisodes) {
|
||||
for (const ep of guaranteed) {
|
||||
parts.push(`User: ${ep.user_message}\nAssistant: ${ep.ai_response}`);
|
||||
}
|
||||
parts.push("--- End of recent memories ---\n");
|
||||
@@ -152,6 +152,7 @@ async function getFTSResults(userMessage, { limit, sessionIds }) {
|
||||
}
|
||||
}
|
||||
|
||||
// Returns {episode, score}[] — scores needed for buildScoredPool downstream
|
||||
function fuseEpisodeResults(semanticEps, keywordEps, { semanticWeight, keywordWeight, limit }) {
|
||||
const k = RETRIEVAL.RRF_K;
|
||||
const scores = new Map();
|
||||
@@ -164,22 +165,80 @@ function fuseEpisodeResults(semanticEps, keywordEps, { semanticWeight, keywordWe
|
||||
const contrib = keywordWeight / (k + i + 1);
|
||||
if (scores.has(ep.id)) {
|
||||
scores.get(ep.id).score += contrib;
|
||||
} else {
|
||||
} else if (contrib > 0) {
|
||||
scores.set(ep.id, { episode: ep, score: contrib });
|
||||
}
|
||||
});
|
||||
|
||||
return [...scores.values()]
|
||||
.sort((a, b) => b.score - a.score)
|
||||
.slice(0, limit)
|
||||
.map(({ episode }) => episode);
|
||||
.slice(0, limit);
|
||||
|
||||
}
|
||||
|
||||
function estimateTokens(episode) {
|
||||
return episode.token_count
|
||||
?? Math.ceil((episode.user_message.length + episode.ai_response.length) / 4);
|
||||
}
|
||||
|
||||
function buildScoredPool(fusedWithScores, recentEpisodes, entityBoostedIds, { entityWeight }) {
|
||||
const k = RETRIEVAL.RRF_K;
|
||||
const pool = new Map(); // episode.id → {episode, score}
|
||||
|
||||
for (const { episode, score } of fusedWithScores) {
|
||||
pool.set(episode.id, { episode, score });
|
||||
}
|
||||
|
||||
recentEpisodes.forEach((ep, i) => {
|
||||
const recencyScore = 1.0 / (k + i + 1);
|
||||
if (pool.has(ep.id)) {
|
||||
pool.get(ep.id).score += recencyScore;
|
||||
} else {
|
||||
pool.set(ep.id, { episode: ep, score: recencyScore });
|
||||
}
|
||||
});
|
||||
|
||||
for (const id of entityBoostedIds) {
|
||||
if (pool.has(id)) pool.get(id).score += entityWeight;
|
||||
}
|
||||
|
||||
return [...pool.values()].sort((a, b) => b.score - a.score);
|
||||
}
|
||||
|
||||
function selectWithinBudget(scoredPool, contextBudget, minRecentEpisodes, recentEpisodes) {
|
||||
let budget = contextBudget;
|
||||
const sortByTime = (a, b) => a.created_at - b.created_at;
|
||||
|
||||
// Guarantee floor: always include the N most recent episodes
|
||||
const guaranteed = recentEpisodes.slice(0, minRecentEpisodes);
|
||||
const guaranteedIds = new Set(guaranteed.map(ep => ep.id));
|
||||
for (const ep of guaranteed) budget -= estimateTokens(ep);
|
||||
|
||||
// Fill remaining budget from scored pool, highest-priority first
|
||||
const selected = [];
|
||||
for (const { episode } of scoredPool) {
|
||||
if (guaranteedIds.has(episode.id)) continue;
|
||||
const cost = estimateTokens(episode);
|
||||
|
||||
// // Break rather than skip — lower-priority episodes aren't worth fitting over higher-priority ones
|
||||
if (budget - cost < 0) break;
|
||||
selected.push(episode);
|
||||
budget -= cost;
|
||||
}
|
||||
|
||||
return {
|
||||
guaranteed: [...guaranteed].sort(sortByTime),
|
||||
selected: selected.sort(sortByTime),
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
async function getFusedEpisodes(userMessage, session, recentIds, projectSessionIds, settings) {
|
||||
const { semanticLimit, scoreThreshold, semanticWeight, keywordWeight } = settings;
|
||||
const ftsSessionIds = projectSessionIds ?? [session.id];
|
||||
|
||||
const ftsPromise = keywordWeight > 0
|
||||
// FTS and semantic may have significant overlap, so fetching more from FTS gives the fusion step more to work with before deduplication.
|
||||
? getFTSResults(userMessage, { limit: semanticLimit * 2, sessionIds: ftsSessionIds })
|
||||
: Promise.resolve([]);
|
||||
|
||||
@@ -195,7 +254,9 @@ async function getFusedEpisodes(userMessage, session, recentIds, projectSessionI
|
||||
async function assembleContext(externalId, userMessage) {
|
||||
const settings = appSettings.load();
|
||||
const { recentEpisodeLimit, semanticLimit, scoreThreshold,
|
||||
temperature, repeatPenalty, topP, topK, systemPrompt, semanticWeight, keywordWeight } = settings;
|
||||
temperature, repeatPenalty, topP, topK, systemPrompt,
|
||||
semanticWeight, keywordWeight,
|
||||
contextBudget, entityWeight, minRecentEpisodes } = settings;
|
||||
|
||||
// 1. Resolve or create session
|
||||
let session = await memory.getSessionByExternalId(externalId);
|
||||
@@ -222,24 +283,41 @@ async function assembleContext(externalId, userMessage) {
|
||||
const isFirstMessage = recentEpisodes.length === 0;
|
||||
const recentIds = new Set(recentEpisodes.map(e => e.id));
|
||||
|
||||
// 4. Semantic + entity search
|
||||
const fusedEpisodes = await getFusedEpisodes(userMessage, session, recentIds, projectSessionIds, { semanticLimit, scoreThreshold, semanticWeight, keywordWeight });
|
||||
const entityResults = await getRelevantEntities(userMessage, session.project_id ?? null);
|
||||
// 4. Fused retrieval + entity search in parallel (both are independent)
|
||||
const [fusedWithScores, entityResults] = await Promise.all([
|
||||
getFusedEpisodes(userMessage, session, recentIds, projectSessionIds, { semanticLimit, scoreThreshold, semanticWeight, keywordWeight }),
|
||||
getRelevantEntities(userMessage, session.project_id ?? null),
|
||||
]);
|
||||
|
||||
// 5. Expand matched entities into 1-hop graph neighborhood
|
||||
let neighborhood = { nodes: [], edges: [] };
|
||||
if (entityResults.length > 0) {
|
||||
// 5. Entity-linked episode IDs for scoring bonus
|
||||
const entityIds = entityResults.map(e => e.id);
|
||||
let entityBoostedIds = new Set();
|
||||
if (entityIds.length > 0) {
|
||||
try {
|
||||
neighborhood = await graph.getNeighbors(entityResults.map(e => e.id));
|
||||
const result = await memory.getEpisodesByEntities(entityIds);
|
||||
entityBoostedIds = new Set(result.episodeIds);
|
||||
} catch (err) {
|
||||
logger.debug('[orchestration] Entity-episode lookup failed, skipping bonus:', err.message);
|
||||
}
|
||||
}
|
||||
|
||||
// 6. Build unified scored pool and select within token budget
|
||||
const scoredPool = buildScoredPool(fusedWithScores, recentEpisodes, entityBoostedIds, { entityWeight });
|
||||
const { guaranteed, selected } = selectWithinBudget(scoredPool, contextBudget, minRecentEpisodes, recentEpisodes);
|
||||
|
||||
// 7. Graph neighborhood expansion
|
||||
let neighborhood = { nodes: [], edges: [] };
|
||||
if (entityIds.length > 0) {
|
||||
try {
|
||||
neighborhood = await graph.getNeighbors(entityIds);
|
||||
} catch (err) {
|
||||
logger.warn('[orchestration] Graph neighborhood fetch failed, falling back to flat entities:', err.message);
|
||||
// Graceful fallback: use Qdrant payload data as flat nodes, no edges
|
||||
neighborhood = { nodes: entityResults, edges: [] };
|
||||
}
|
||||
}
|
||||
|
||||
// 6. Assemble prompt
|
||||
const prompt = buildPrompt(recentEpisodes, fusedEpisodes, neighborhood, userMessage, activeSystemPrompt);
|
||||
// 8. Assemble prompt
|
||||
const prompt = buildPrompt(guaranteed, selected, neighborhood, userMessage, activeSystemPrompt);
|
||||
|
||||
return {
|
||||
session,
|
||||
|
||||
@@ -16,6 +16,9 @@ const DEFAULTS = {
|
||||
systemPrompt: ORCHESTRATION.SYSTEM_PROMPT,
|
||||
semanticWeight: RETRIEVAL.SEMANTIC_WEIGHT,
|
||||
keywordWeight: RETRIEVAL.KEYWORD_WEIGHT,
|
||||
contextBudget: ORCHESTRATION.CONTEXT_BUDGET,
|
||||
entityWeight: ORCHESTRATION.ENTITY_WEIGHT,
|
||||
minRecentEpisodes: ORCHESTRATION.MIN_RECENT_EPISODES,
|
||||
};
|
||||
|
||||
function load() {
|
||||
|
||||
@@ -94,6 +94,27 @@ if (req.body.systemPrompt !== undefined) {
|
||||
updates.keywordWeight = val;
|
||||
}
|
||||
|
||||
if (req.body.contextBudget !== undefined) {
|
||||
const val = Number(req.body.contextBudget);
|
||||
if (!Number.isInteger(val) || val < 512 || val > 32768)
|
||||
return res.status(400).json({ error: 'contextBudget must be 512–32768' });
|
||||
updates.contextBudget = val;
|
||||
}
|
||||
|
||||
if (req.body.entityWeight !== undefined) {
|
||||
const val = Number(req.body.entityWeight);
|
||||
if (isNaN(val) || val < 0 || val > 2)
|
||||
return res.status(400).json({ error: 'entityWeight must be 0–2' });
|
||||
updates.entityWeight = val;
|
||||
}
|
||||
|
||||
if (req.body.minRecentEpisodes !== undefined) {
|
||||
const val = Number(req.body.minRecentEpisodes);
|
||||
if (!Number.isInteger(val) || val < 0 || val > 10)
|
||||
return res.status(400).json({ error: 'minRecentEpisodes must be 0–10' });
|
||||
updates.minRecentEpisodes = val;
|
||||
}
|
||||
|
||||
res.json(settings.save(updates));
|
||||
});
|
||||
|
||||
|
||||
@@ -206,6 +206,16 @@ async function searchEpisodes(query, { limit = 10, sessionIds = null } = {}) {
|
||||
return res.json();
|
||||
}
|
||||
|
||||
async function getEpisodesByEntities(entityIds) {
|
||||
const res = await fetch(`${BASE_URL}/episodes/by-entities`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ entityIds }),
|
||||
});
|
||||
if (!res.ok) throw new Error(`Episodes-by-entities error: ${res.status}`);
|
||||
return res.json(); // { episodeIds: [...] }
|
||||
}
|
||||
|
||||
module.exports = {
|
||||
getSessionByExternalId,
|
||||
createSession,
|
||||
@@ -231,4 +241,5 @@ module.exports = {
|
||||
generateProjectSummary,
|
||||
getProjectOverviewSummary,
|
||||
searchEpisodes,
|
||||
getEpisodesByEntities,
|
||||
}
|
||||
@@ -28,6 +28,9 @@ const ORCHESTRATION = {
|
||||
ENTITIES_LIMIT: 5,
|
||||
ENTITIES_THRESHOLD: 0.55,
|
||||
TEMPERATURE: 0.7,
|
||||
CONTEXT_BUDGET: 4096,
|
||||
ENTITY_WEIGHT: 0.5,
|
||||
MIN_RECENT_EPISODES: 2,
|
||||
CORS_ORIGIN: 'http://localhost:5173',
|
||||
SYSTEM_PROMPT: `You are a helpful, context-aware AI assistant. You have access to memories of past conversations with the user. Use them to provide consistent, personalised responses.`
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user