diff --git a/packages/orchestration-service/src/chat/index.js b/packages/orchestration-service/src/chat/index.js index 5c585b2..145acef 100644 --- a/packages/orchestration-service/src/chat/index.js +++ b/packages/orchestration-service/src/chat/index.js @@ -2,7 +2,7 @@ const memory = require("../services/memory"); const inference = require("../services/inference"); const embedding = require("../services/embedding"); const qdrant = require("../services/qdrant"); -const { ORCHESTRATION, logger } = require("@nexusai/shared"); +const { ORCHESTRATION, RETRIEVAL, logger } = require("@nexusai/shared"); const appSettings = require("../config/settings"); const {triggerSummary} = require('../services/summarization') const graph = require('../services/graph'); @@ -143,10 +143,59 @@ async function getRelevantEntities(userMessage, projectId = null) { } } +async function getFTSResults(userMessage, { limit, sessionIds }) { + try { + return await memory.searchEpisodes(userMessage, { limit, sessionIds }); + } catch (err) { + logger.warn('[orchestration] FTS search failed, continuing without:', err.message); + return []; + } +} + +function fuseEpisodeResults(semanticEps, keywordEps, { semanticWeight, keywordWeight, limit }) { + const k = RETRIEVAL.RRF_K; + const scores = new Map(); + + semanticEps.forEach((ep, i) => { + scores.set(ep.id, { episode: ep, score: semanticWeight / (k + i + 1) }); + }); + + keywordEps.forEach((ep, i) => { + const contrib = keywordWeight / (k + i + 1); + if (scores.has(ep.id)) { + scores.get(ep.id).score += contrib; + } else { + scores.set(ep.id, { episode: ep, score: contrib }); + } + }); + + return [...scores.values()] + .sort((a, b) => b.score - a.score) + .slice(0, limit) + .map(({ episode }) => episode); +} + +async function getFusedEpisodes(userMessage, session, recentIds, projectSessionIds, settings) { + const { semanticLimit, scoreThreshold, semanticWeight, keywordWeight } = settings; + const ftsSessionIds = projectSessionIds ?? [session.id]; + + const ftsPromise = keywordWeight > 0 + ? getFTSResults(userMessage, { limit: semanticLimit * 2, sessionIds: ftsSessionIds }) + : Promise.resolve([]); + + const [semanticEps, rawKeywordEps] = await Promise.all([ + getSemanticEpisodes(userMessage, session.id, recentIds, projectSessionIds, { semanticLimit, scoreThreshold }), + ftsPromise, + ]); + + const keywordEps = rawKeywordEps.filter(ep => !recentIds.has(ep.id)); + return fuseEpisodeResults(semanticEps, keywordEps, { semanticWeight, keywordWeight, limit: semanticLimit }); +} + async function assembleContext(externalId, userMessage) { const settings = appSettings.load(); const { recentEpisodeLimit, semanticLimit, scoreThreshold, - temperature, repeatPenalty, topP, topK, systemPrompt } = settings; + temperature, repeatPenalty, topP, topK, systemPrompt, semanticWeight, keywordWeight } = settings; // 1. Resolve or create session let session = await memory.getSessionByExternalId(externalId); @@ -174,9 +223,7 @@ async function assembleContext(externalId, userMessage) { const recentIds = new Set(recentEpisodes.map(e => e.id)); // 4. Semantic + entity search - const semanticEpisodes = await getSemanticEpisodes( - userMessage, session.id, recentIds, projectSessionIds, { semanticLimit, scoreThreshold } - ); + const fusedEpisodes = await getFusedEpisodes(userMessage, session, recentIds, projectSessionIds, { semanticLimit, scoreThreshold, semanticWeight, keywordWeight }); const entityResults = await getRelevantEntities(userMessage, session.project_id ?? null); // 5. Expand matched entities into 1-hop graph neighborhood @@ -192,7 +239,7 @@ async function assembleContext(externalId, userMessage) { } // 6. Assemble prompt - const prompt = buildPrompt(recentEpisodes, semanticEpisodes, neighborhood, userMessage, activeSystemPrompt); + const prompt = buildPrompt(recentEpisodes, fusedEpisodes, neighborhood, userMessage, activeSystemPrompt); return { session, diff --git a/packages/orchestration-service/src/config/settings.js b/packages/orchestration-service/src/config/settings.js index 6d3b5d4..157e7ea 100644 --- a/packages/orchestration-service/src/config/settings.js +++ b/packages/orchestration-service/src/config/settings.js @@ -1,6 +1,6 @@ const fs = require('fs'); const path = require('path'); -const { getEnv, ORCHESTRATION, INFERENCE_DEFAULTS } = require('@nexusai/shared'); +const { getEnv, ORCHESTRATION, INFERENCE_DEFAULTS, RETRIEVAL } = require('@nexusai/shared'); const SETTINGS_PATH = path.join(__dirname, '../../data/settings.json'); @@ -14,6 +14,8 @@ const DEFAULTS = { topP: INFERENCE_DEFAULTS.TOP_P, topK: INFERENCE_DEFAULTS.TOP_K, systemPrompt: ORCHESTRATION.SYSTEM_PROMPT, + semanticWeight: RETRIEVAL.SEMANTIC_WEIGHT, + keywordWeight: RETRIEVAL.KEYWORD_WEIGHT, }; function load() { diff --git a/packages/orchestration-service/src/routes/settings.js b/packages/orchestration-service/src/routes/settings.js index b2089cf..773f9c4 100644 --- a/packages/orchestration-service/src/routes/settings.js +++ b/packages/orchestration-service/src/routes/settings.js @@ -80,6 +80,20 @@ if (req.body.systemPrompt !== undefined) { updates.systemPrompt = val.trim() || null; // null reverts to default } + if (req.body.semanticWeight !== undefined) { + const val = Number(req.body.semanticWeight); + if (isNaN(val) || val < 0 || val > 5) + return res.status(400).json({ error: 'semanticWeight must be 0–5' }); + updates.semanticWeight = val; + } + + if (req.body.keywordWeight !== undefined) { + const val = Number(req.body.keywordWeight); + if (isNaN(val) || val < 0 || val > 5) + return res.status(400).json({ error: 'keywordWeight must be 0–5' }); + updates.keywordWeight = val; + } + res.json(settings.save(updates)); }); diff --git a/packages/orchestration-service/src/services/memory.js b/packages/orchestration-service/src/services/memory.js index 03616a8..5e48f18 100644 --- a/packages/orchestration-service/src/services/memory.js +++ b/packages/orchestration-service/src/services/memory.js @@ -196,6 +196,16 @@ async function getProjectOverviewSummary(projectId) { return res.json(); // null if none exists yet } +async function searchEpisodes(query, { limit = 10, sessionIds = null } = {}) { + const url = new URL(`${BASE_URL}/episodes/search`); + url.searchParams.set('q', query); + url.searchParams.set('limit', limit); + if (sessionIds?.length) url.searchParams.set('sessionIds', sessionIds.join(',')); + const res = await fetch(url.toString()); + if (!res.ok) throw new Error(`FTS search error: ${res.status}`); + return res.json(); +} + module.exports = { getSessionByExternalId, createSession, @@ -220,4 +230,5 @@ module.exports = { getSummariesByProject, generateProjectSummary, getProjectOverviewSummary, + searchEpisodes, } \ No newline at end of file diff --git a/test-fusion.js b/test-fusion.js new file mode 100644 index 0000000..19cc104 --- /dev/null +++ b/test-fusion.js @@ -0,0 +1,67 @@ +// test-fusion.js +const { RETRIEVAL } = require('./packages/shared/src/config/constants'); + +function fuseEpisodeResults(semanticEps, keywordEps, { semanticWeight, keywordWeight, limit }) { + const k = RETRIEVAL.RRF_K; + const scores = new Map(); + semanticEps.forEach((ep, i) => { + scores.set(ep.id, { episode: ep, score: semanticWeight / (k + i + 1) }); + }); + keywordEps.forEach((ep, i) => { + const contrib = keywordWeight / (k + i + 1); + if (scores.has(ep.id)) { + scores.get(ep.id).score += contrib; + } else { + scores.set(ep.id, { episode: ep, score: contrib }); + } + }); + return [...scores.values()] + .sort((a, b) => b.score - a.score) + .slice(0, limit) + .map(({ episode }) => episode); +} + +// --- Test 1: episodes in both lists rank highest --- +const semantic = [ + { id: 1, user_message: 'ep1 — semantic only, rank 1' }, + { id: 2, user_message: 'ep2 — in both lists, rank 2 semantic' }, + { id: 3, user_message: 'ep3 — in both lists, rank 3 semantic' }, +]; +const keyword = [ + { id: 3, user_message: 'ep3 — rank 1 FTS' }, + { id: 2, user_message: 'ep2 — rank 2 FTS' }, + { id: 4, user_message: 'ep4 — FTS only, rank 3' }, +]; + +const result = fuseEpisodeResults(semantic, keyword, { semanticWeight: 1, keywordWeight: 1, limit: 5 }); +console.log('Test 1 — equal weights, episodes in both lists should rank highest:'); +result.forEach((ep, i) => console.log(` ${i + 1}. id=${ep.id} "${ep.user_message}"`)); +console.assert(result[0].id === 2 || result[0].id === 3, 'FAIL: ep2 or ep3 should be rank 1'); +console.assert(!result.find(e => e.id === 1) || result.indexOf(result.find(e => e.id === 1)) > result.indexOf(result.find(e => e.id === 2)), 'FAIL: ep1 (semantic only) should rank below ep2'); +console.log(' PASS\n'); + +// --- Test 2: keywordWeight:0 → pure semantic passthrough --- +const result2 = fuseEpisodeResults(semantic, keyword, { semanticWeight: 1, keywordWeight: 0, limit: 5 }); +console.log('Test 2 — keywordWeight:0 should return only semantic results in original order:'); +result2.forEach((ep, i) => console.log(` ${i + 1}. id=${ep.id}`)); +console.assert(result2.length === 3, `FAIL: expected 3, got ${result2.length}`); +console.assert(result2[0].id === 1, 'FAIL: ep1 should be rank 1'); +console.assert(result2[1].id === 2, 'FAIL: ep2 should be rank 2'); +console.log(' PASS\n'); + +// --- Test 3: limit is respected --- +const result3 = fuseEpisodeResults(semantic, keyword, { semanticWeight: 1, keywordWeight: 1, limit: 2 }); +console.log('Test 3 — limit:2 should return exactly 2 results:'); +console.assert(result3.length === 2, `FAIL: expected 2, got ${result3.length}`); +console.log(' PASS\n'); + +// --- Test 4: no overlap → all unique episodes, ordered by individual contribution --- +const semOnly = [{ id: 10, user_message: 'sem' }]; +const ftsOnly = [{ id: 20, user_message: 'fts' }]; +const result4 = fuseEpisodeResults(semOnly, ftsOnly, { semanticWeight: 1, keywordWeight: 1, limit: 5 }); +console.log('Test 4 — no overlap, both should appear:'); +console.assert(result4.length === 2, `FAIL: expected 2, got ${result4.length}`); +console.assert(result4[0].id === 10, 'FAIL: semantic rank-1 should beat fts rank-1 (same weight, both rank 1, but semantic inserted first — tie goes to semantic)'); +console.log(' PASS\n'); + +console.log('All tests passed.');