package middleware import ( "context" "crypto/rand" "encoding/hex" "fmt" "hash/fnv" "strings" "time" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/auth" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/services/routing/router" "github.com/mudler/LocalAI/core/templates" "github.com/mudler/xlog" "gopkg.in/yaml.v3" ) // ScorerFactory returns a backend.Scorer bound to a named classifier // model. The score classifier uses it to compute joint log-prob of // every policy label against the routing prompt. type ScorerFactory func(modelName string) backend.Scorer // EmbedderFactory returns a backend.Embedder bound to a named model. // Used by the L2 embedding cache. Returning nil signals "model not // loadable" — the middleware then falls back to the uncached // classifier so routing still happens. type EmbedderFactory func(modelName string) backend.Embedder // VectorStoreFactory returns a backend.VectorStore bound to a named // collection. Each router model's cache lives in its own collection // so two routers can't poison each other's hits. type VectorStoreFactory func(storeName string) backend.VectorStore // RerankerFactory returns a backend.Reranker bound to a named model. // Used by the colbert classifier to score policy descriptions against // the prompt via LocalAI's rerankers backend. Returning nil signals // "model not loadable" — buildClassifier reports a config error. type RerankerFactory func(modelName string) backend.Reranker // ModelConfigLookup resolves a model name to its config, or nil when // unknown. Used by buildClassifier to confirm the classifier_model // declared the score usecase — the actual usecase-conflict check // lives in ModelConfig.Validate() and runs at config load/save time. type ModelConfigLookup func(modelName string) *config.ModelConfig // ClassifierDeps bundles the backend factories the router middleware // needs to build a classifier and its optional L2 cache. Bundled into // one struct because RouteModel already takes many positional // arguments — additions to the dependency surface go here instead of // growing the signature. // // Embedder and VectorStore are optional: when both are non-nil and the // router config declares an embedding_cache block, the score // classifier is wrapped in EmbeddingCacheClassifier. Otherwise the // score classifier runs unwrapped and the embedding-cache YAML is // ignored with a warning. type ClassifierDeps struct { Scorer ScorerFactory Embedder EmbedderFactory VectorStore VectorStoreFactory Reranker RerankerFactory // ModelLookup resolves the classifier_model name to its config so // buildClassifier can reject misconfigurations that would // otherwise crash the llama-cpp backend at request time. Optional // — when nil, the check is skipped (tests, embedded callers that // haven't wired the loader). ModelLookup ModelConfigLookup // Registry is the shared classifier cache. Both the OpenAI and // Anthropic routes pass the same registry so the admin stats // endpoint sees every live classifier. Nil falls back to a local // registry — tests that don't need cross-route stats use this. Registry *router.Registry // Evaluator renders the classifier model's chat template around // the routing system + user prompt. Optional — when nil, the // score classifier falls back to a built-in ChatML envelope, // which is correct for Arch-Router/Qwen but wrong for non-ChatML // routing models. Production wiring passes the app-wide // templates.Evaluator so any model the operator points at gets // its own chat template applied. Evaluator *templates.Evaluator } // ProbeExtractor pulls the prompt content out of a parsed request so // the classifier can inspect it without taking a dependency on the // schema package. One extractor per request shape — wired by the // route registration site (mirrors the piiadapter pattern). // // Returns ok=false when the parsed value isn't the expected type — the // middleware then passes through without engaging the router. type ProbeExtractor func(parsed any) (router.Probe, bool) // RouteModel runs after SetModelAndConfig and the schema-specific // SetXRequest, looks at the resolved model's Router config, and (when // present) reclassifies the request to one of the candidates. // // The middleware: // // 1. Loads MODEL_CONFIG from the echo context. If nil or HasRouter() // is false, passes through. // 2. Extracts the probe via the supplied ProbeExtractor. // 3. Invokes the classifier matching cfg.Router.Classifier // ("score" or "colbert"). If the classifier can't be built — // missing classifier_model, misconfigured policies, etc. — the // request fails with 503. cfg.Router.Fallback only catches // Classify-time errors and label-coverage misses, not config // bugs that would otherwise be silent. // 4. Resolves the chosen candidate to its model name. Reloads the // ModelConfig for that model and asserts depth-1 (the candidate // must NOT itself have a Router). Violation returns 500 — config // bug, not a request bug. // 5. Updates input.Model in place, replaces MODEL_CONFIG with the // candidate's config, and stamps RequestedModel/ServedModel on the // context so UsageMiddleware records the routing. // 6. Writes a DecisionRecord to the store for the admin page. // // store may be nil when --disable-stats turns off the routing log; // classification still runs. // // Composition with SmartRouter (distributed mode): this middleware // only does *model* selection. Node selection still happens in // SmartRouter.Route() downstream of this middleware. // RouteModel wires the router middleware. source is the value written to // DecisionRecord.Source (router.SourceChat / SourceAnthropic / ...) so // the admin page can split decisions by entry point. Pass // router.SourceChat for the OpenAI chat endpoint, router.SourceAnthropic // for the Anthropic messages endpoint. func RouteModel(loader *config.ModelConfigLoader, appConfig *config.ApplicationConfig, store router.DecisionStore, fallbackUser *auth.User, extractor ProbeExtractor, source string, deps ClassifierDeps) echo.MiddlewareFunc { registry := deps.Registry if registry == nil { registry = router.NewRegistry() } candidateLoader := func(name string) (*config.ModelConfig, error) { return loader.LoadModelConfigFileByNameDefaultOptions(name, appConfig) } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || cfg == nil || !cfg.HasRouter() { return next(c) } parsed := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST) if parsed == nil { return next(c) } probe, probeOK := extractor(parsed) if !probeOK { return next(c) } classifier, err := GetOrBuildClassifier(registry, cfg, deps) if err != nil { // Build-time failures are config bugs (missing // classifier_model, undeclared usecase, policy // validation, ...). Silently falling back would hide // them and make the router look "working" while the // classifier model is never invoked — surface as 503 // with the underlying reason so operators see it. xlog.Warn("router: classifier build failed", "router_model", cfg.Name, "classifier", cfg.Router.Classifier, "error", err) return echo.NewHTTPError(503, "router classifier unavailable: "+err.Error()) } result, err := router.Resolve(c.Request().Context(), cfg, classifier, candidateLoader, probe) if err != nil { xlog.Warn("router: resolve failed", "router_model", cfg.Name, "error", err) return echo.NewHTTPError(500, err.Error()) } if req, ok := parsed.(schema.LocalAIRequest); ok { chosen := result.ChosenModel req.ModelName(&chosen) } c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, result.ChosenConfig) c.Set(ContextKeyRequestedModel, result.RouterModel) c.Set(ContextKeyServedModel, result.ChosenModel) if store != nil { recordHTTPDecision(c, store, result, fallbackUser, source) } return next(c) } } } // recordHTTPDecision writes the resolved decision to the store with // HTTP-shaped audit metadata (correlation id from header, user from // auth middleware, fallback to the synthetic local user). Realtime // has its own recorder that supplies session-derived metadata // instead. func recordHTTPDecision(c echo.Context, store router.DecisionStore, result *router.ResolveResult, fallbackUser *auth.User, source string) { correlationID, _ := c.Get(ContextKeyCorrelationID).(string) if correlationID == "" { correlationID = c.Response().Header().Get("X-Correlation-ID") } userID := "" if u := auth.GetUser(c); u != nil { userID = u.ID } else if fallbackUser != nil { userID = fallbackUser.ID } _ = store.Record(context.Background(), result.ToDecisionRecord(newDecisionID(), correlationID, userID, source)) } // GetOrBuildClassifier looks up a built Classifier for the named router // model in the registry and builds it on miss. Exported so the // /api/router/decide decision-oracle endpoint can share the same // build-once cache that the in-band RouteModel middleware uses. func GetOrBuildClassifier(registry *router.Registry, cfg *config.ModelConfig, deps ClassifierDeps) (router.Classifier, error) { // Fingerprint folds the classifier model's renderer-affecting // fields (chat templates + stopwords) in alongside the router // config. Without this, hot-reloading the classifier model's // YAML (via ReloadModelsEndpoint, /import-model, or the MCP // reload_models tool) wouldn't rebuild the cached classifier — // the candidates slice and renderer closure are baked at build // time from those fields and would silently keep the stale // stop token / template until process restart. var classifierCfg *config.ModelConfig if deps.ModelLookup != nil { classifierCfg = deps.ModelLookup(cfg.Router.ClassifierModel) } fp := routerConfigFingerprint(cfg.Router, classifierCfg) if cached, ok := registry.Get(cfg.Name, fp); ok { return cached, nil } c, err := buildClassifier(cfg, deps) if err != nil { return nil, err } registry.Put(cfg.Name, fp, c) return c, nil } // routerConfigFingerprint is a stable cache key for the (router cfg, // classifier model cfg) tuple. FNV-64 over the YAML form of the // router block plus the renderer-affecting fields of the classifier // model — equality-only, not cryptographic. YAML-marshal picks up // any future RouterConfig field without this function needing to be // touched; for the classifier model we hash a narrow projection so // unrelated changes (parameters, files, ...) don't burst the cache. // Pass classifierCfg=nil when no lookup is wired — the fingerprint // degenerates to the router-only form, matching pre-refactor behaviour. func routerConfigFingerprint(rc config.RouterConfig, classifierCfg *config.ModelConfig) uint64 { bytes, err := yaml.Marshal(rc) if err != nil { // Marshalling a value type can't fail in practice; fall // back to a hash that varies per call so we don't quietly // share a cache entry across distinct configs. return uint64(time.Now().UnixNano()) } h := fnv.New64a() h.Write(bytes) if classifierCfg != nil { // Narrow projection: only the fields newTemplateRenderer and // firstStopWord actually read. Hashing the whole ModelConfig // would invalidate the cache on irrelevant parameter changes. h.Write([]byte{0}) // separator so empty fields don't collide h.Write([]byte(classifierCfg.TemplateConfig.Chat)) h.Write([]byte{0}) h.Write([]byte(classifierCfg.TemplateConfig.ChatMessage)) h.Write([]byte{0}) for _, sw := range classifierCfg.StopWords { h.Write([]byte(sw)) h.Write([]byte{0}) } } return h.Sum64() } func buildClassifier(cfg *config.ModelConfig, deps ClassifierDeps) (router.Classifier, error) { rc := cfg.Router name := rc.Classifier if name == "" { name = router.ClassifierScore } policies, err := validateRouterPolicies(name, rc) if err != nil { return nil, err } cacheCap := rc.ClassifierCacheSize if cacheCap == 0 { cacheCap = 1024 } var inner router.Classifier switch name { case router.ClassifierScore: if deps.Scorer == nil { return nil, fmt.Errorf("router classifier score unavailable: no scorer factory wired") } if err := assertClassifierDeclaresScore(rc.ClassifierModel, deps.ModelLookup); err != nil { return nil, err } scorer := deps.Scorer(rc.ClassifierModel) if scorer == nil { return nil, fmt.Errorf("router classifier score: classifier_model %q not loadable", rc.ClassifierModel) } opts := router.ScoreClassifierOptions{ CacheCap: cacheCap, ActivationThreshold: rc.ActivationThreshold, Normalization: rc.ScoreNormalization, SystemPromptTemplate: rc.ClassifierSystemTemplate, } // Build the prompt renderer + stop token from the classifier // model's own config when available. Without ModelLookup // (tests, embedded callers) the score classifier's built-in // ChatML defaults kick in, which is correct for Arch-Router. if deps.ModelLookup != nil { if classifierCfg := deps.ModelLookup(rc.ClassifierModel); classifierCfg != nil { if deps.Evaluator != nil { opts.PromptRenderer = newTemplateRenderer(deps.Evaluator, classifierCfg) } if st := pickAssistantTurnEnd(classifierCfg.StopWords, classifierCfg.TemplateConfig.ChatMessage); st != "" { opts.StopToken = st } } } inner = router.NewScoreClassifier(policies, scorer, opts) case router.ClassifierColbert: if deps.Reranker == nil { return nil, fmt.Errorf("router classifier colbert unavailable: no reranker factory wired") } reranker := deps.Reranker(rc.ClassifierModel) if reranker == nil { return nil, fmt.Errorf("router classifier colbert: classifier_model %q not loadable", rc.ClassifierModel) } inner = router.NewRerankClassifier(policies, reranker, cacheCap, rc.ActivationThreshold) default: return nil, fmt.Errorf("router: unknown classifier %q (supported: %s)", name, strings.Join([]string{router.ClassifierScore, router.ClassifierColbert}, ", ")) } if rc.EmbeddingCache == nil { return inner, nil } wrapped, err := wrapWithEmbeddingCache(cfg, inner, deps) if err != nil { // Caching plumbing problems must not break routing — log, // drop the cache layer, and return the uncached classifier. // The admin UI surfaces the warning via the classifier-build // error path used elsewhere. xlog.Warn("router: embedding cache disabled", "router_model", cfg.Name, "error", err) return inner, nil } return wrapped, nil } // assertClassifierDeclaresScore refuses to build the score classifier // unless classifier_model's config declares FLAG_SCORE. The actual // usecase-conflict check (score + chat/completion/embeddings on // llama-cpp) lives in ModelConfig.Validate() and fires at config load // and save time — by the time we get here, any model that reached the // loader is already conflict-free. This check just refuses to bind a // model that never declared itself for Score in the first place; that // model could be a misconfigured chat model the operator pointed at // by accident, and without FLAG_SCORE the validator never saw it. // // When lookup is nil (test wiring) the check is skipped and we fall // back to the C++ backend's runtime tripwire as the last line of // defence. func assertClassifierDeclaresScore(classifierModel string, lookup ModelConfigLookup) error { if lookup == nil { return nil } cfg := lookup(classifierModel) if cfg == nil { // Unknown model — Scorer() will produce a clearer "not // loadable" error a few lines down. return nil } if !cfg.HasUsecases(config.FLAG_SCORE) { return fmt.Errorf( "router classifier score: classifier_model %q does not declare the "+ "score usecase. Add `known_usecases: [score]` to its config so "+ "the loader can reject conflicting usecase combinations", classifierModel) } return nil } // validateRouterPolicies checks the shared invariants both classifiers // rely on (non-empty policies, every candidate label declared as a // policy, every candidate has a model + at least one label) and // returns the parsed []ScorePolicy. Both Score and Rerank classifiers // take the same policy shape. func validateRouterPolicies(classifierName string, rc config.RouterConfig) ([]router.ScorePolicy, error) { if rc.ClassifierModel == "" { return nil, fmt.Errorf("router classifier %s requires classifier_model", classifierName) } if len(rc.Policies) == 0 { return nil, fmt.Errorf("router classifier %s requires at least one policy", classifierName) } policies := make([]router.ScorePolicy, 0, len(rc.Policies)) for _, p := range rc.Policies { if p.Label == "" { return nil, fmt.Errorf("router classifier %s: policy with empty label", classifierName) } if p.Description == "" { return nil, fmt.Errorf("router classifier %s: policy %q has no description", classifierName, p.Label) } policies = append(policies, router.ScorePolicy{Label: p.Label, Description: p.Description}) } policyLabels := make(map[string]struct{}, len(policies)) for _, p := range policies { policyLabels[p.Label] = struct{}{} } for _, c := range rc.Candidates { if c.Model == "" { return nil, fmt.Errorf("router classifier %s: candidate has empty model field", classifierName) } if len(c.Labels) == 0 { return nil, fmt.Errorf("router classifier %s: candidate %q has no labels", classifierName, c.Model) } for _, l := range c.Labels { if _, ok := policyLabels[l]; !ok { return nil, fmt.Errorf("router classifier %s: candidate %q references unknown label %q (not in policies)", classifierName, c.Model, l) } } } return policies, nil } // newTemplateRenderer adapts the templates.Evaluator + the classifier // model's config into the router.PromptRenderer callback. The // resulting renderer pushes the routing system + user prompt through // the classifier model's full chat-template pipeline — per-role // formatting via TemplateConfig.ChatMessage, then the outer // TemplateConfig.Chat — so non-ChatML routing models render // correctly without router-package awareness of the template format. // // We must go through TemplateMessages, not EvaluateTemplateForPrompt // directly: the gallery's outer Chat templates are uniformly // `{{.Input -}}<|im_start|>assistant` (or the Llama-3 equivalent) // and reference {{.Input}} only — never {{.SystemPrompt}}. Passing // our routing system prompt through .SystemPrompt would silently // drop it because Go text/template ignores unreferenced fields. // TemplateMessages instead renders each role through ChatMessage and // joins them into the .Input the outer template DOES read. // // Returns nil (forcing the score classifier's chatMLRenderer // fallback) when either template piece is missing — partial // templating would still drop content. func newTemplateRenderer(eval *templates.Evaluator, classifierCfg *config.ModelConfig) router.PromptRenderer { if classifierCfg.TemplateConfig.Chat == "" || classifierCfg.TemplateConfig.ChatMessage == "" { return nil } cfgCopy := *classifierCfg return func(system, user string) (string, error) { messages := []schema.Message{ {Role: "system", StringContent: system}, {Role: "user", StringContent: user}, } rendered := eval.TemplateMessages(schema.OpenAIRequest{}, messages, &cfgCopy, nil, false) if rendered == "" { return "", fmt.Errorf("router: classifier %q chat template produced empty output", cfgCopy.Name) } return rendered, nil } } // pickAssistantTurnEnd returns the classifier model's assistant // turn-end token — the one to suffix candidates with so the model's // "I'm done" signal folds into the per-candidate joint log-prob. // // Strategy: prefer the stopword that *literally appears* in the // chat_message template, because that token is the assistant // turn-end by construction. ChatML's chat_message ends with // "<|im_end|>", Llama-3's ends with "<|eot_id|>", etc. — the // template is the source of truth. // // Fallback: the first non-empty stopword. That's right for // well-ordered configs (ChatML conventionally lists <|im_end|> // first) but wrong for some gallery Llama-3 templates that defensively // list <|im_end|> first even though the actual turn-end is <|eot_id|>. // The template-scan above catches those. // // When no stopwords are configured at all, return "" — caller falls // back to defaultStopToken (<|im_end|>) inside the score classifier. func pickAssistantTurnEnd(words []string, chatMessageTemplate string) string { if chatMessageTemplate != "" { for _, w := range words { if w != "" && strings.Contains(chatMessageTemplate, w) { return w } } } for _, w := range words { if w != "" { return w } } return "" } func wrapWithEmbeddingCache(cfg *config.ModelConfig, inner router.Classifier, deps ClassifierDeps) (router.Classifier, error) { ec := cfg.Router.EmbeddingCache if ec.EmbeddingModel == "" { return nil, fmt.Errorf("embedding_cache requires embedding_model") } if deps.Embedder == nil || deps.VectorStore == nil { return nil, fmt.Errorf("embedding cache factories not wired") } embedder := deps.Embedder(ec.EmbeddingModel) if embedder == nil { return nil, fmt.Errorf("embedding_model %q not loadable", ec.EmbeddingModel) } storeName := ec.StoreName if storeName == "" { storeName = "router-cache-" + cfg.Name } vstore := deps.VectorStore(storeName) if vstore == nil { return nil, fmt.Errorf("vector store %q not loadable", storeName) } return router.NewEmbeddingCacheClassifier(inner, embedder, vstore, ec.SimilarityThreshold, ec.ConfidenceThreshold), nil } func newDecisionID() string { var b [12]byte _, _ = rand.Read(b[:]) return "rd_" + hex.EncodeToString(b[:]) } // OpenAIProbe extracts a router.Probe from a parsed *schema.OpenAIRequest. // Concatenates message contents (string-form or text blocks of the // structured `[]any` content) so the classifier sees a single corpus // for length and content-shape rules. Image blocks are skipped — a // future multimodal classifier can take a different route. func OpenAIProbe(parsed any) (router.Probe, bool) { req, ok := parsed.(*schema.OpenAIRequest) if !ok || req == nil { return router.Probe{}, false } return OpenAIProbeFromRequest(req), true } // OpenAIProbeFromRequest is the typed counterpart of OpenAIProbe — same // extraction logic, but takes the request struct directly. Realtime and // other non-HTTP callers use it to feed a probe to router.Resolve // without going through an echo.Context first. func OpenAIProbeFromRequest(req *schema.OpenAIRequest) router.Probe { if req == nil { return router.Probe{} } var b strings.Builder for i := range req.Messages { switch ct := req.Messages[i].Content.(type) { case string: b.WriteString(ct) b.WriteByte('\n') case []any: for _, block := range ct { if bm, ok := block.(map[string]any); ok && bm["type"] == "text" { if t, ok := bm["text"].(string); ok { b.WriteString(t) b.WriteByte('\n') } } } } } return router.Probe{Prompt: b.String()} } // AnthropicProbe is the AnthropicRequest analogue of OpenAIProbe. func AnthropicProbe(parsed any) (router.Probe, bool) { req, ok := parsed.(*schema.AnthropicRequest) if !ok || req == nil { return router.Probe{}, false } var b strings.Builder for i := range req.Messages { switch ct := req.Messages[i].Content.(type) { case string: b.WriteString(ct) b.WriteByte('\n') case []any: for _, block := range ct { if bm, ok := block.(map[string]any); ok && bm["type"] == "text" { if t, ok := bm["text"].(string); ok { b.WriteString(t) b.WriteByte('\n') } } } } } return router.Probe{ Prompt: b.String(), }, true }