model inference settings
This commit is contained in:
@@ -264,6 +264,31 @@ function ModelsSection({ onNavigate }) {
|
|||||||
onSave={val => saveSetting('temperature', val)}
|
onSave={val => saveSetting('temperature', val)}
|
||||||
saving={saving}
|
saving={saving}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
|
<NumberSetting
|
||||||
|
label="Repeat Penalty"
|
||||||
|
description="Penalises repeated tokens — higher reduces repetition (1–2)"
|
||||||
|
value={settings?.repeatPenalty}
|
||||||
|
min={1} max={2} step={0.05}
|
||||||
|
onSave={val => saveSetting('repeatPenalty', val)}
|
||||||
|
saving={saving}
|
||||||
|
/>
|
||||||
|
<NumberSetting
|
||||||
|
label="Top-P"
|
||||||
|
description="Nucleus sampling — limits token pool by cumulative probability (0–1)"
|
||||||
|
value={settings?.topP}
|
||||||
|
min={0} max={1} step={0.05}
|
||||||
|
onSave={val => saveSetting('topP', val)}
|
||||||
|
saving={saving}
|
||||||
|
/>
|
||||||
|
<NumberSetting
|
||||||
|
label="Top-K"
|
||||||
|
description="Limits token pool to K most likely tokens per step (1–100)"
|
||||||
|
value={settings?.topK}
|
||||||
|
min={1} max={100} step={1}
|
||||||
|
onSave={val => saveSetting('topK', val)}
|
||||||
|
saving={saving}
|
||||||
|
/>
|
||||||
<SettingsRow
|
<SettingsRow
|
||||||
label="Active Model"
|
label="Active Model"
|
||||||
description="Model used for inference"
|
description="Model used for inference"
|
||||||
|
|||||||
@@ -5,14 +5,14 @@ const router = Router();
|
|||||||
|
|
||||||
// Standard completion endpoint - returns full response when done
|
// Standard completion endpoint - returns full response when done
|
||||||
router.post('/complete', async (req, res) => {
|
router.post('/complete', async (req, res) => {
|
||||||
const { prompt, model, temperature, maxTokens } = req.body;
|
const { prompt, model, temperature, maxTokens, topP, topK, repeatPenalty } = req.body;
|
||||||
|
|
||||||
if (!prompt) {
|
if (!prompt) {
|
||||||
return res.status(400).json({ error: 'prompt is required'});
|
return res.status(400).json({ error: 'prompt is required'});
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const result = await complete (prompt, {model, temperature, maxTokens});
|
const result = await complete (prompt, {model, temperature, maxTokens, topP, topK, repeatPenalty});
|
||||||
res.json(result);
|
res.json(result);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('[Inference] Completion error:', error.message);
|
console.error('[Inference] Completion error:', error.message);
|
||||||
@@ -22,7 +22,7 @@ router.post('/complete', async (req, res) => {
|
|||||||
|
|
||||||
// Streaming completion endpoint - sends partial responses as they arrive
|
// Streaming completion endpoint - sends partial responses as they arrive
|
||||||
router.post('/complete/stream', async (req, res) => {
|
router.post('/complete/stream', async (req, res) => {
|
||||||
const { prompt, model, temperature } = req.body;
|
const { prompt, model, temperature, topP, topK, repeatPenalty } = req.body;
|
||||||
|
|
||||||
if (!prompt) return res.status(400).json({ error: 'prompt is required' });
|
if (!prompt) return res.status(400).json({ error: 'prompt is required' });
|
||||||
|
|
||||||
@@ -34,7 +34,7 @@ router.post('/complete/stream', async (req, res) => {
|
|||||||
let lastModel = model;
|
let lastModel = model;
|
||||||
let tokenCount = 0;
|
let tokenCount = 0;
|
||||||
|
|
||||||
for await (const chunk of completeStream(prompt, { model, temperature })) {
|
for await (const chunk of completeStream(prompt, { model, temperature, topP, topK, repeatPenalty })) {
|
||||||
if (chunk.response) {
|
if (chunk.response) {
|
||||||
res.write(`data: ${JSON.stringify({ response: chunk.response })}\n\n`);
|
res.write(`data: ${JSON.stringify({ response: chunk.response })}\n\n`);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -126,7 +126,7 @@ async function getRelevantEntities(userMessage) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async function chat(externalId, userMessage, options = {}) {
|
async function chat(externalId, userMessage, options = {}) {
|
||||||
const { recentEpisodeLimit, semanticLimit, scoreThreshold, temperature} =
|
const { recentEpisodeLimit, semanticLimit, scoreThreshold, temperature, repeatPenalty, topP, topK} =
|
||||||
appSettings.load();
|
appSettings.load();
|
||||||
// 1. Resolve or create session
|
// 1. Resolve or create session
|
||||||
let session = await memory.getSessionByExternalId(externalId);
|
let session = await memory.getSessionByExternalId(externalId);
|
||||||
@@ -187,7 +187,7 @@ async function chat(externalId, userMessage, options = {}) {
|
|||||||
);
|
);
|
||||||
|
|
||||||
// 5. Run inference
|
// 5. Run inference
|
||||||
const result = await inference.complete(prompt, {...options, temperature});
|
const result = await inference.complete(prompt, {...options, temperature, repeatPenalty, topP, topK});
|
||||||
|
|
||||||
// 6. Write episode back to memory
|
// 6. Write episode back to memory
|
||||||
memory
|
memory
|
||||||
@@ -217,7 +217,7 @@ async function chat(externalId, userMessage, options = {}) {
|
|||||||
|
|
||||||
async function chatStream(externalId, userMessage, onChunk, options = {}) {
|
async function chatStream(externalId, userMessage, onChunk, options = {}) {
|
||||||
try {
|
try {
|
||||||
const { recentEpisodeLimit, semanticLimit, scoreThreshold, temperature } = appSettings.load();
|
const { recentEpisodeLimit, semanticLimit, scoreThreshold, temperature, repeatPenalty, topP, topK } = appSettings.load();
|
||||||
let session = await memory.getSessionByExternalId(externalId);
|
let session = await memory.getSessionByExternalId(externalId);
|
||||||
if (!session) session = await memory.createSession(externalId);
|
if (!session) session = await memory.createSession(externalId);
|
||||||
|
|
||||||
@@ -270,7 +270,7 @@ async function chatStream(externalId, userMessage, onChunk, options = {}) {
|
|||||||
entities,
|
entities,
|
||||||
userMessage,
|
userMessage,
|
||||||
);
|
);
|
||||||
const res = await inference.completeStream(prompt, {...options, temperature});
|
const res = await inference.completeStream(prompt, {...options, temperature, repeatPenalty, topP, topK});
|
||||||
|
|
||||||
let fullText = "";
|
let fullText = "";
|
||||||
let model = "";
|
let model = "";
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
const fs = require('fs');
|
const fs = require('fs');
|
||||||
const path = require('path');
|
const path = require('path');
|
||||||
const { getEnv, ORCHESTRATION } = require('@nexusai/shared');
|
const { getEnv, ORCHESTRATION, INFERENCE_DEFAULTS } = require('@nexusai/shared');
|
||||||
|
|
||||||
const SETTINGS_PATH = path.join(__dirname, '../../data/settings.json');
|
const SETTINGS_PATH = path.join(__dirname, '../../data/settings.json');
|
||||||
|
|
||||||
@@ -9,7 +9,10 @@ const DEFAULTS = {
|
|||||||
semanticLimit: ORCHESTRATION.SEMANTIC_LIMIT,
|
semanticLimit: ORCHESTRATION.SEMANTIC_LIMIT,
|
||||||
scoreThreshold: ORCHESTRATION.SCORE_THRESHOLD,
|
scoreThreshold: ORCHESTRATION.SCORE_THRESHOLD,
|
||||||
modelsFolderPath: getEnv('MODELS_MANIFEST_PATH', '/mnt/nexus-models'),
|
modelsFolderPath: getEnv('MODELS_MANIFEST_PATH', '/mnt/nexus-models'),
|
||||||
temperature: ORCHESTRATION.TEMPERATURE
|
temperature: INFERENCE_DEFAULTS.TEMPERATURE,
|
||||||
|
repeatPenalty: INFERENCE_DEFAULTS.REPEAT_PENALTY,
|
||||||
|
topP: INFERENCE_DEFAULTS.TOP_P,
|
||||||
|
topK: INFERENCE_DEFAULTS.TOP_K
|
||||||
};
|
};
|
||||||
|
|
||||||
function load() {
|
function load() {
|
||||||
|
|||||||
@@ -52,6 +52,27 @@ router.patch('/', (req, res) => {
|
|||||||
updates.temperature = val;
|
updates.temperature = val;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (req.body.repeatPenalty !== undefined) {
|
||||||
|
const val = Number(req.body.repeatPenalty);
|
||||||
|
if (isNaN(val) || val < 1 || val > 2)
|
||||||
|
return res.status(400).json({ error: 'repeatPenalty must be 1–2' });
|
||||||
|
updates.repeatPenalty = val;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (req.body.topP !== undefined) {
|
||||||
|
const val = Number(req.body.topP);
|
||||||
|
if (isNaN(val) || val < 0 || val > 1)
|
||||||
|
return res.status(400).json({ error: 'topP must be 0–1' });
|
||||||
|
updates.topP = val;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (req.body.topK !== undefined) {
|
||||||
|
const val = Number(req.body.topK);
|
||||||
|
if (!Number.isInteger(val) || val < 1 || val > 100)
|
||||||
|
return res.status(400).json({ error: 'topK must be 1–100' });
|
||||||
|
updates.topK = val;
|
||||||
|
}
|
||||||
|
|
||||||
res.json(settings.save(updates));
|
res.json(settings.save(updates));
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user