Come sviluppare un'interfaccia di flusso di token per il vostro LLM con Go, FastAPI e JS

I modelli generativi a volte impiegano un po' di tempo per restituire un risultato, quindi è interessante sfruttare lo streaming di token per vedere i risultati apparire al volo nell'interfaccia utente. Ecco come realizzare un frontend di streaming di testo per il vostro LLM con Go, FastAPI e Javascript.

Sviluppatore su PC

Che cos'è lo streaming di token?

Come promemoria, un token è un'entità unica che può essere una piccola parola, parte di una parola o punteggiatura. In media, 1 token è composto da 4 caratteri, e 100 token equivalgono all'incirca a 75 parole. I modelli di elaborazione del linguaggio naturale devono trasformare il testo in token per poterlo elaborare.

Quando si utilizza un modello AI di generazione del testo (noto anche come modello "generativo"), il tempo di risposta può essere piuttosto elevato, a seconda dell'hardware e delle dimensioni del modello. Ad esempio, nel caso di un modello linguistico di grandi dimensioni (noto anche come "LLM") come LLaMA 30B, distribuito su una GPU NVIDIA A100 in fp16, il modello genera 100 token in circa 3 secondi. Quindi, se si prevede che il modello generativo generi un testo di grandi dimensioni, composto da centinaia o migliaia di parole, la latenza sarà elevata e sarà necessario attendere forse più di 10 secondi per ottenere il risultato completo. più di 10 secondi per ottenere la risposta completa.

Attendere così a lungo per ottenere una risposta può essere un problema dal punto di vista dell'esperienza dell'utente. La soluzione in questo caso è lo streaming di token!

Lo streaming dei token consiste nel generare ogni nuovo token al volo, invece di aspettare che l'intera risposta sia pronta. Questo è ciò che è quello che si può vedere nell'applicazione ChatGPT o, ad esempio, nell'assistente NLP Cloud ChatDolphin. Le parole appaiono non appena vengono generate dal modello. Provate l'assistente ChatDolphin AI qui.

Streaming di token con ChatDolphin su NLP Cloud Streaming di token con l'assistente ChatDolphin su NLP Cloud. Provate qui.

Selezione di un motore di inferenza che supporti il flusso di token

Il primo passo sarà quello di utilizzare un motore di inferenza che supporti lo streaming dei token.

Ecco alcune opzioni che potreste prendere in considerazione:

Ecco un esempio che utilizza il metodo HuggingFace generate():

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer(["An increasing sequence: one,"], return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer)

# Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
for new_text in streamer:
    print(new_text)

In questo esempio, generiamo un output con il modello GPT-2 e stampiamo ogni token nella console non appena arriva.

Trasmissione in streaming della risposta con FastAPI

Ora che si è scelto un motore di inferenza, è necessario servire il modello e restituire i token in streaming.

Il modello verrà probabilmente eseguito in un ambiente Python, quindi sarà necessario un server Python per restituire i token e renderli disponibili tramite un'API HTTP. FastAPI è diventata una scelta di fatto per queste situazioni.

Qui usiamo Uvicorn e lo StreamingResponse di FastAPI per servire ogni token non appena viene generato. Ecco un esempio:

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

app = FastAPI()

async def generate():
    inputs = tokenizer(["An increasing sequence: one,"], return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    for new_text in streamer:
        yield new_text

@app.get("/")
async def main():
    return StreamingResponse(generate())

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

È possibile testare il server di streaming con il seguente comando cURL:

curl -N localhost:8000

Ora abbiamo un modello di intelligenza artificiale funzionante che restituisce correttamente i token in streaming.

Potremmo leggere direttamente questi token in streaming da un'applicazione client in un browser. Ma non lo faremo, per due motivi.

Innanzitutto, è importante disaccoppiare il modello di intelligenza artificiale dal resto dello stack, perché non vogliamo riavviare il modello ogni volta che ogni volta che si apporta una piccola modifica all'API. Si tenga presente che i moderni modelli di intelligenza artificiale generativa sono molto pesanti e spesso richiedono diversi minuti per essere riavviati.

Una seconda ragione è che Python non è necessariamente la scelta migliore quando si tratta di costruire un'applicazione concorrente ad alto throughput come quella che stiamo per fare. Questa scelta può essere discussa e potrebbe anche essere una questione di gusti!

Inoltro dei token attraverso un gateway Go

Come già detto, è importante aggiungere un gateway tra il modello e il cliente finale, e Go è un buon linguaggio di programmazione per un'applicazione di questo tipo. In produzione, si potrebbe anche voler aggiungere un reverse tra il gateway Go e il client finale e un bilanciatore di carico tra il gateway Go e il modello AI, in modo da distribuire il carico su diverse repliche del modello. per distribuire il carico su diverse repliche del modello. Ma questo non rientra nello scopo del nostro articolo!

La nostra applicazione Go si occuperà anche del rendering della pagina HTML finale.

Questa applicazione effettua una richiesta all'applicazione FastAPI, riceve i token in streaming da FastAPI e inoltra ciascun token al frontend utilizzando gli eventi inviati dal server (SSE). token al frontend utilizzando Server Sent Events (SSE). SSE è più semplice di websocket perché è unidirezionale. È è una buona scelta quando si vuole creare un'applicazione che invia informazioni a un client, senza ascoltare una potenziale risposta del client. risposta del cliente.

Ecco il codice Go (il modello HTML/JS/CSS sarà mostrato nella prossima sezione):

package main

import (
    "bufio"
    "fmt"
    "html/template"
    "io"
    "log"
    "net/http"
    "strings"
)

var (
    templates      *template.Template
    streamedTextCh chan string
)

func init() {
    // Parse all templates in the templates folder.
    templates = template.Must(template.ParseGlob("templates/*.html"))

    streamedTextCh = make(chan string)
}

// generateText calls FastAPI and returns every token received on the fly through
// a dedicated channel (streamedTextCh).
// If the EOF character is received from FastAPI, it means that text generation is over.
func generateText(streamedTextCh chan<- string) {
    var buf io.Reader = nil

    req, err := http.NewRequest("GET", "http://127.0.0.1:8000", buf)
    if err != nil {
        log.Fatal(err)
    }

    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        log.Fatal(err)
    }
    defer resp.Body.Close()

    reader := bufio.NewReader(resp.Body)

outerloop:
    for {
        chunk, err := reader.ReadBytes('\x00')
        if err != nil {
            if err == io.EOF {
                break outerloop
            }
            log.Println(err)
            break outerloop
        }

        output := string(chunk)

        streamedTextCh <- output
    }
}

// formatServerSentEvent creates a proper SSE compatible body.
// Server sent events need to follow a specific formatting that
// uses "event:" and "data:" prefixes.
func formatServerSentEvent(event, data string) (string, error) {
    sb := strings.Builder{}

    _, err := sb.WriteString(fmt.Sprintf("event: %s\n", event))
    if err != nil {
        return "", err
    }
    _, err = sb.WriteString(fmt.Sprintf("data: %v\n\n", data))
    if err != nil {
        return "", err
    }

    return sb.String(), nil
}

// generate is an infinite loop that waits for new tokens received 
// from the streamedTextCh. Once a new token is received,
// it is automatically pushed to the frontend as a server sent event. 
func generate(w http.ResponseWriter, r *http.Request) {
    flusher, ok := w.(http.Flusher)
    if !ok {
        http.Error(w, "SSE not supported", http.StatusInternalServerError)
        return
    }

    w.Header().Set("Content-Type", "text/event-stream")

    for text := range streamedTextCh {
        event, err := formatServerSentEvent("streamed-text", text)
        if err != nil {
            http.Error(w, "Cannot format SSE message", http.StatusInternalServerError)
            return
        }

        _, err = fmt.Fprint(w, event)
        if err != nil {
            http.Error(w, "Cannot format SSE message", http.StatusInternalServerError)
            return
        }

        flusher.Flush()
    }
}

// start starts an asynchronous request to the AI engine.
func start(w http.ResponseWriter, r *http.Request) {
    go generateText(streamedTextCh)
}

func home(w http.ResponseWriter, r *http.Request) {
    if err := templates.ExecuteTemplate(w, "home.html", nil); err != nil {
        log.Println(err.Error())
        http.Error(w, "", http.StatusInternalServerError)
        return
    }
}

func main() {
    http.HandleFunc("/generate", generate)
    http.HandleFunc("/start", start).Methods("POST")
    http.HandleFunc("/", home).Methods("GET")

    log.Fatal(http.ListenAndServe(":8000", r))
}                

La nostra pagina "/home" esegue il rendering della pagina HTML/CSS/JS (mostrato più avanti). La pagina "/start" riceve una richiesta POST dall'applicazione JS che attiva una richiesta al nostro modello AI. JS che attiva una richiesta al nostro modello AI. La pagina "/generate" restituisce il risultato all'applicazione JS attraverso gli eventi inviati dal server.

Una volta che la funzione start() riceve una richiesta POST dal frontend, crea automaticamente una goroutine che effettuerà una richiesta alla nostra applicazione FastAPI.

La funzione generateText() chiama FastAPI e restituisce ogni token ricevuto al volo attraverso un canale dedicato (streamedTextCh). Se FastAPI riceve il carattere EOF, significa che la generazione del testo è terminata.

La funzione generate() è un ciclo infinito che attende i nuovi token ricevuti dal canale streamedTextCh. Una volta ricevuto un nuovo token, viene automaticamente inviato al frontend come evento inviato dal server. Gli eventi inviati dal server devono seguire una formattazione specifica che utilizza i prefissi "event:" e "data:". da cui la funzione formatServerSentEvent().

Affinché SSE sia completo, abbiamo bisogno di un client Javascript che sia in grado di ascoltare gli eventi inviati dal server sottoscrivendo la pagina "generate". Si veda la prossima sezione per capire come ottenere questo risultato.

Ricevere i token con Javascript nel browser

Ora è necessario creare una cartella "templates" e aggiungere un file "home.html" al suo interno.

Ecco il contenuto di "home.html":

<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <title>Our Streamed Tokens App</title>
</head>
<body>
    <div id="response-section"></div>    
    <form method="POST">
        <button onclick="start()">Start</button>
    </form>
</body>
<script>
    // Disable the default behavior of the HTML form.
    document.querySelector('form').addEventListener('submit', function(e) {
        e.preventDefault()
    })

    // Make a request to the /start to trigger the request to the AI model.
    async function start() {
        try {
            const response = await fetch("/start", {
            method: "POST",
            })
        } catch (error) {
            console.error("Error when starting process:", error)
        }
    }

    // Listen to SSE by subscribing to the /generate page, and
    // put the result in the #response-section div.
    const evtSource = new EventSource("generate")
    evtSource.addEventListener("streamed-text", (event) => {
        document.getElementById('response-section').innerHTML = event.data
    })
</script>
</html>

Come si può vedere, l'ascolto di SSE nel browser è abbastanza semplice.

Per prima cosa è necessario iscriversi al nostro endpoint SSE (la pagina "/generate"). poi è necessario aggiungere un ascoltatore di eventi che leggerà i token in streaming non appena vengono ricevuti.

I browser moderni tentano automaticamente di riconnettersi l'origine dell'evento in caso di problemi di connessione.

Conclusione

Ora sapete come creare una moderna applicazione di IA generativa che trasmetta dinamicamente dinamicamente il testo nel browser, alla maniera di ChatGPT!

Come avete notato, un'applicazione di questo tipo non è necessariamente semplice, poiché sono coinvolti diversi livelli. livelli coinvolti. E naturalmente il codice di cui sopra è eccessivamente semplificato per il bene dell'esempio. esempio.

La sfida principale dello streaming di token riguarda la gestione dei guasti di rete. La maggior parte di questi guasti si verificheranno tra il backend Go e il frontend Javascript. È necessario di esplorare alcune strategie di riconnessione più avanzate e assicurarsi che gli errori vengano correttamente segnalati all'interfaccia utente.

Spero che questo tutorial vi sia stato utile!

Vincent
Consulente per gli sviluppatori presso NLP Cloud