retrieval fusion
This commit is contained in:
@@ -2,7 +2,7 @@ const memory = require("../services/memory");
|
|||||||
const inference = require("../services/inference");
|
const inference = require("../services/inference");
|
||||||
const embedding = require("../services/embedding");
|
const embedding = require("../services/embedding");
|
||||||
const qdrant = require("../services/qdrant");
|
const qdrant = require("../services/qdrant");
|
||||||
const { ORCHESTRATION, logger } = require("@nexusai/shared");
|
const { ORCHESTRATION, RETRIEVAL, logger } = require("@nexusai/shared");
|
||||||
const appSettings = require("../config/settings");
|
const appSettings = require("../config/settings");
|
||||||
const {triggerSummary} = require('../services/summarization')
|
const {triggerSummary} = require('../services/summarization')
|
||||||
const graph = require('../services/graph');
|
const graph = require('../services/graph');
|
||||||
@@ -143,10 +143,59 @@ async function getRelevantEntities(userMessage, projectId = null) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async function getFTSResults(userMessage, { limit, sessionIds }) {
|
||||||
|
try {
|
||||||
|
return await memory.searchEpisodes(userMessage, { limit, sessionIds });
|
||||||
|
} catch (err) {
|
||||||
|
logger.warn('[orchestration] FTS search failed, continuing without:', err.message);
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function fuseEpisodeResults(semanticEps, keywordEps, { semanticWeight, keywordWeight, limit }) {
|
||||||
|
const k = RETRIEVAL.RRF_K;
|
||||||
|
const scores = new Map();
|
||||||
|
|
||||||
|
semanticEps.forEach((ep, i) => {
|
||||||
|
scores.set(ep.id, { episode: ep, score: semanticWeight / (k + i + 1) });
|
||||||
|
});
|
||||||
|
|
||||||
|
keywordEps.forEach((ep, i) => {
|
||||||
|
const contrib = keywordWeight / (k + i + 1);
|
||||||
|
if (scores.has(ep.id)) {
|
||||||
|
scores.get(ep.id).score += contrib;
|
||||||
|
} else {
|
||||||
|
scores.set(ep.id, { episode: ep, score: contrib });
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
return [...scores.values()]
|
||||||
|
.sort((a, b) => b.score - a.score)
|
||||||
|
.slice(0, limit)
|
||||||
|
.map(({ episode }) => episode);
|
||||||
|
}
|
||||||
|
|
||||||
|
async function getFusedEpisodes(userMessage, session, recentIds, projectSessionIds, settings) {
|
||||||
|
const { semanticLimit, scoreThreshold, semanticWeight, keywordWeight } = settings;
|
||||||
|
const ftsSessionIds = projectSessionIds ?? [session.id];
|
||||||
|
|
||||||
|
const ftsPromise = keywordWeight > 0
|
||||||
|
? getFTSResults(userMessage, { limit: semanticLimit * 2, sessionIds: ftsSessionIds })
|
||||||
|
: Promise.resolve([]);
|
||||||
|
|
||||||
|
const [semanticEps, rawKeywordEps] = await Promise.all([
|
||||||
|
getSemanticEpisodes(userMessage, session.id, recentIds, projectSessionIds, { semanticLimit, scoreThreshold }),
|
||||||
|
ftsPromise,
|
||||||
|
]);
|
||||||
|
|
||||||
|
const keywordEps = rawKeywordEps.filter(ep => !recentIds.has(ep.id));
|
||||||
|
return fuseEpisodeResults(semanticEps, keywordEps, { semanticWeight, keywordWeight, limit: semanticLimit });
|
||||||
|
}
|
||||||
|
|
||||||
async function assembleContext(externalId, userMessage) {
|
async function assembleContext(externalId, userMessage) {
|
||||||
const settings = appSettings.load();
|
const settings = appSettings.load();
|
||||||
const { recentEpisodeLimit, semanticLimit, scoreThreshold,
|
const { recentEpisodeLimit, semanticLimit, scoreThreshold,
|
||||||
temperature, repeatPenalty, topP, topK, systemPrompt } = settings;
|
temperature, repeatPenalty, topP, topK, systemPrompt, semanticWeight, keywordWeight } = settings;
|
||||||
|
|
||||||
// 1. Resolve or create session
|
// 1. Resolve or create session
|
||||||
let session = await memory.getSessionByExternalId(externalId);
|
let session = await memory.getSessionByExternalId(externalId);
|
||||||
@@ -174,9 +223,7 @@ async function assembleContext(externalId, userMessage) {
|
|||||||
const recentIds = new Set(recentEpisodes.map(e => e.id));
|
const recentIds = new Set(recentEpisodes.map(e => e.id));
|
||||||
|
|
||||||
// 4. Semantic + entity search
|
// 4. Semantic + entity search
|
||||||
const semanticEpisodes = await getSemanticEpisodes(
|
const fusedEpisodes = await getFusedEpisodes(userMessage, session, recentIds, projectSessionIds, { semanticLimit, scoreThreshold, semanticWeight, keywordWeight });
|
||||||
userMessage, session.id, recentIds, projectSessionIds, { semanticLimit, scoreThreshold }
|
|
||||||
);
|
|
||||||
const entityResults = await getRelevantEntities(userMessage, session.project_id ?? null);
|
const entityResults = await getRelevantEntities(userMessage, session.project_id ?? null);
|
||||||
|
|
||||||
// 5. Expand matched entities into 1-hop graph neighborhood
|
// 5. Expand matched entities into 1-hop graph neighborhood
|
||||||
@@ -192,7 +239,7 @@ async function assembleContext(externalId, userMessage) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 6. Assemble prompt
|
// 6. Assemble prompt
|
||||||
const prompt = buildPrompt(recentEpisodes, semanticEpisodes, neighborhood, userMessage, activeSystemPrompt);
|
const prompt = buildPrompt(recentEpisodes, fusedEpisodes, neighborhood, userMessage, activeSystemPrompt);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
session,
|
session,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
const fs = require('fs');
|
const fs = require('fs');
|
||||||
const path = require('path');
|
const path = require('path');
|
||||||
const { getEnv, ORCHESTRATION, INFERENCE_DEFAULTS } = require('@nexusai/shared');
|
const { getEnv, ORCHESTRATION, INFERENCE_DEFAULTS, RETRIEVAL } = require('@nexusai/shared');
|
||||||
|
|
||||||
const SETTINGS_PATH = path.join(__dirname, '../../data/settings.json');
|
const SETTINGS_PATH = path.join(__dirname, '../../data/settings.json');
|
||||||
|
|
||||||
@@ -14,6 +14,8 @@ const DEFAULTS = {
|
|||||||
topP: INFERENCE_DEFAULTS.TOP_P,
|
topP: INFERENCE_DEFAULTS.TOP_P,
|
||||||
topK: INFERENCE_DEFAULTS.TOP_K,
|
topK: INFERENCE_DEFAULTS.TOP_K,
|
||||||
systemPrompt: ORCHESTRATION.SYSTEM_PROMPT,
|
systemPrompt: ORCHESTRATION.SYSTEM_PROMPT,
|
||||||
|
semanticWeight: RETRIEVAL.SEMANTIC_WEIGHT,
|
||||||
|
keywordWeight: RETRIEVAL.KEYWORD_WEIGHT,
|
||||||
};
|
};
|
||||||
|
|
||||||
function load() {
|
function load() {
|
||||||
|
|||||||
@@ -80,6 +80,20 @@ if (req.body.systemPrompt !== undefined) {
|
|||||||
updates.systemPrompt = val.trim() || null; // null reverts to default
|
updates.systemPrompt = val.trim() || null; // null reverts to default
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (req.body.semanticWeight !== undefined) {
|
||||||
|
const val = Number(req.body.semanticWeight);
|
||||||
|
if (isNaN(val) || val < 0 || val > 5)
|
||||||
|
return res.status(400).json({ error: 'semanticWeight must be 0–5' });
|
||||||
|
updates.semanticWeight = val;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (req.body.keywordWeight !== undefined) {
|
||||||
|
const val = Number(req.body.keywordWeight);
|
||||||
|
if (isNaN(val) || val < 0 || val > 5)
|
||||||
|
return res.status(400).json({ error: 'keywordWeight must be 0–5' });
|
||||||
|
updates.keywordWeight = val;
|
||||||
|
}
|
||||||
|
|
||||||
res.json(settings.save(updates));
|
res.json(settings.save(updates));
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -196,6 +196,16 @@ async function getProjectOverviewSummary(projectId) {
|
|||||||
return res.json(); // null if none exists yet
|
return res.json(); // null if none exists yet
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async function searchEpisodes(query, { limit = 10, sessionIds = null } = {}) {
|
||||||
|
const url = new URL(`${BASE_URL}/episodes/search`);
|
||||||
|
url.searchParams.set('q', query);
|
||||||
|
url.searchParams.set('limit', limit);
|
||||||
|
if (sessionIds?.length) url.searchParams.set('sessionIds', sessionIds.join(','));
|
||||||
|
const res = await fetch(url.toString());
|
||||||
|
if (!res.ok) throw new Error(`FTS search error: ${res.status}`);
|
||||||
|
return res.json();
|
||||||
|
}
|
||||||
|
|
||||||
module.exports = {
|
module.exports = {
|
||||||
getSessionByExternalId,
|
getSessionByExternalId,
|
||||||
createSession,
|
createSession,
|
||||||
@@ -220,4 +230,5 @@ module.exports = {
|
|||||||
getSummariesByProject,
|
getSummariesByProject,
|
||||||
generateProjectSummary,
|
generateProjectSummary,
|
||||||
getProjectOverviewSummary,
|
getProjectOverviewSummary,
|
||||||
|
searchEpisodes,
|
||||||
}
|
}
|
||||||
67
test-fusion.js
Normal file
67
test-fusion.js
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
// test-fusion.js
|
||||||
|
const { RETRIEVAL } = require('./packages/shared/src/config/constants');
|
||||||
|
|
||||||
|
function fuseEpisodeResults(semanticEps, keywordEps, { semanticWeight, keywordWeight, limit }) {
|
||||||
|
const k = RETRIEVAL.RRF_K;
|
||||||
|
const scores = new Map();
|
||||||
|
semanticEps.forEach((ep, i) => {
|
||||||
|
scores.set(ep.id, { episode: ep, score: semanticWeight / (k + i + 1) });
|
||||||
|
});
|
||||||
|
keywordEps.forEach((ep, i) => {
|
||||||
|
const contrib = keywordWeight / (k + i + 1);
|
||||||
|
if (scores.has(ep.id)) {
|
||||||
|
scores.get(ep.id).score += contrib;
|
||||||
|
} else {
|
||||||
|
scores.set(ep.id, { episode: ep, score: contrib });
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return [...scores.values()]
|
||||||
|
.sort((a, b) => b.score - a.score)
|
||||||
|
.slice(0, limit)
|
||||||
|
.map(({ episode }) => episode);
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Test 1: episodes in both lists rank highest ---
|
||||||
|
const semantic = [
|
||||||
|
{ id: 1, user_message: 'ep1 — semantic only, rank 1' },
|
||||||
|
{ id: 2, user_message: 'ep2 — in both lists, rank 2 semantic' },
|
||||||
|
{ id: 3, user_message: 'ep3 — in both lists, rank 3 semantic' },
|
||||||
|
];
|
||||||
|
const keyword = [
|
||||||
|
{ id: 3, user_message: 'ep3 — rank 1 FTS' },
|
||||||
|
{ id: 2, user_message: 'ep2 — rank 2 FTS' },
|
||||||
|
{ id: 4, user_message: 'ep4 — FTS only, rank 3' },
|
||||||
|
];
|
||||||
|
|
||||||
|
const result = fuseEpisodeResults(semantic, keyword, { semanticWeight: 1, keywordWeight: 1, limit: 5 });
|
||||||
|
console.log('Test 1 — equal weights, episodes in both lists should rank highest:');
|
||||||
|
result.forEach((ep, i) => console.log(` ${i + 1}. id=${ep.id} "${ep.user_message}"`));
|
||||||
|
console.assert(result[0].id === 2 || result[0].id === 3, 'FAIL: ep2 or ep3 should be rank 1');
|
||||||
|
console.assert(!result.find(e => e.id === 1) || result.indexOf(result.find(e => e.id === 1)) > result.indexOf(result.find(e => e.id === 2)), 'FAIL: ep1 (semantic only) should rank below ep2');
|
||||||
|
console.log(' PASS\n');
|
||||||
|
|
||||||
|
// --- Test 2: keywordWeight:0 → pure semantic passthrough ---
|
||||||
|
const result2 = fuseEpisodeResults(semantic, keyword, { semanticWeight: 1, keywordWeight: 0, limit: 5 });
|
||||||
|
console.log('Test 2 — keywordWeight:0 should return only semantic results in original order:');
|
||||||
|
result2.forEach((ep, i) => console.log(` ${i + 1}. id=${ep.id}`));
|
||||||
|
console.assert(result2.length === 3, `FAIL: expected 3, got ${result2.length}`);
|
||||||
|
console.assert(result2[0].id === 1, 'FAIL: ep1 should be rank 1');
|
||||||
|
console.assert(result2[1].id === 2, 'FAIL: ep2 should be rank 2');
|
||||||
|
console.log(' PASS\n');
|
||||||
|
|
||||||
|
// --- Test 3: limit is respected ---
|
||||||
|
const result3 = fuseEpisodeResults(semantic, keyword, { semanticWeight: 1, keywordWeight: 1, limit: 2 });
|
||||||
|
console.log('Test 3 — limit:2 should return exactly 2 results:');
|
||||||
|
console.assert(result3.length === 2, `FAIL: expected 2, got ${result3.length}`);
|
||||||
|
console.log(' PASS\n');
|
||||||
|
|
||||||
|
// --- Test 4: no overlap → all unique episodes, ordered by individual contribution ---
|
||||||
|
const semOnly = [{ id: 10, user_message: 'sem' }];
|
||||||
|
const ftsOnly = [{ id: 20, user_message: 'fts' }];
|
||||||
|
const result4 = fuseEpisodeResults(semOnly, ftsOnly, { semanticWeight: 1, keywordWeight: 1, limit: 5 });
|
||||||
|
console.log('Test 4 — no overlap, both should appear:');
|
||||||
|
console.assert(result4.length === 2, `FAIL: expected 2, got ${result4.length}`);
|
||||||
|
console.assert(result4[0].id === 10, 'FAIL: semantic rank-1 should beat fts rank-1 (same weight, both rank 1, but semantic inserted first — tie goes to semantic)');
|
||||||
|
console.log(' PASS\n');
|
||||||
|
|
||||||
|
console.log('All tests passed.');
|
||||||
Reference in New Issue
Block a user