Generatieve modellen hebben soms wat tijd nodig om een resultaat terug te geven, dus het is interessant om gebruik te maken van token streaming om het resultaat direct in de UI te zien verschijnen. Hier wordt uitgelegd hoe je zo'n tekststreaming frontend voor je LLM kunt maken met Go, FastAPI en Javascript.
Ter herinnering, een token is een unieke entiteit die een klein woord, een deel van een woord of interpunctie kan zijn. Gemiddeld bestaat 1 token uit 4 tekens, en 100 tokens zijn ongeveer gelijk aan 75 woorden. Modellen voor de verwerking van natuurlijke taal moeten uw tekst in tokens omzetten om deze te kunnen verwerken.
Wanneer je een AI-model gebruikt dat tekst genereert (ook bekend als "generatief" model), kan de responstijd vrij hoog zijn, afhankelijk van je hardware en de grootte van je model. In het geval van een groot taalmodel (ook bekend als "LLM") zoals LLaMA 30B, ingezet op een NVIDIA A100 GPU in fp16, genereert het model bijvoorbeeld 100 tokens in ongeveer 3 seconden. Dus als je verwacht dat je generatieve model een groot stuk tekst van honderden of duizenden woorden genereert, dan zal de latentie hoog zijn en moet je misschien meer dan 10 seconden wachten om de volledige tekst te krijgen. meer dan 10 seconden wachten om het volledige antwoord te krijgen.
Zo lang wachten op een antwoord kan een probleem zijn vanuit het oogpunt van gebruikerservaring. De oplossing in dat geval is token streaming!
Token streaming gaat over het genereren van elk nieuw token on the fly in plaats van te wachten tot de hele respons klaar is. Dit is wat je kunt zien op de ChatGPT app, of op de NLP Cloud ChatDolphin assistent bijvoorbeeld. Woorden verschijnen zodra ze zijn gegenereerd door het model. Probeer hier de AI-assistent van ChatDolphin.
Token streaming met de ChatDolphin-assistent op NLP Cloud. Probeer het hier.
De eerste stap is om gebruik te maken van een inferentie-engine die tokenstreaming ondersteunt.
Hier zijn enkele opties die je kunt overwegen:
Hier is een voorbeeld waarbij de methode HuggingFace generate() wordt gebruikt:
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer(["An increasing sequence: one,"], return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer)
# Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
for new_text in streamer:
print(new_text)
In dit voorbeeld genereren we een uitvoer met het GPT-2 model en drukken we elk token af in de console zodra het binnenkomt.
Nu je een inferentie-engine hebt gekozen, moet je je model serveren en de gestreamde tokens retourneren.
Je model draait waarschijnlijk in een Python-omgeving, dus je hebt een Python-server nodig om de tokens terug te sturen en ze beschikbaar te maken via een HTTP API. FastAPI is een de facto keuze geworden voor dergelijke situaties.
Hier gebruiken we Uvicorn en FastAPI's StreamingResponse om elk token te serveren zodra het is gegenereerd. Hier is een voorbeeld:
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
app = FastAPI()
async def generate():
inputs = tokenizer(["An increasing sequence: one,"], return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer)
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
for new_text in streamer:
yield new_text
@app.get("/")
async def main():
return StreamingResponse(generate())
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
Je kunt je streaming server testen met het volgende cURL commando:
curl -N localhost:8000
We hebben nu een werkend AI-model dat op de juiste manier gestreamde tokens terugstuurt.
We zouden deze gestreamde tokens direct kunnen lezen vanuit een clientapplicatie in een browser. Maar dat gaan we niet doen, om 2 redenen.
Ten eerste is het belangrijk om het het AI-model los te koppelen van de rest van de stack omdat we het model niet elke keer opnieuw willen opstarten als we een kleine verandering in de API aanbrengen. Houd in gedachten dat moderne generatieve AI modellen erg zwaar zijn en vaak enkele minuten nodig hebben om opnieuw op te starten.
Een tweede reden is dat Python niet per se de beste keuze is als het aankomt op het bouwen van een hoge doorvoer concurrent applicatie zoals we gaan doen. Over deze keuze kan natuurlijk besproken worden en het kan ook een kwestie van smaak zijn!
Zoals hierboven vermeld, is het belangrijk om een gateway toe te voegen tussen je model en je eindklant, en Go is een goede programmeertaal voor zo'n toepassing. In productie wil je misschien ook een reverse proxy toevoegen tussen de Go-gateway en de uiteindelijke client, en een load balancer tussen je Go-gateway en je AI-model om om de belasting over meerdere replica's van je model te verdelen. Maar dat valt buiten de scope van ons artikel!
Onze Go-toepassing is ook verantwoordelijk voor het renderen van de uiteindelijke HTML-pagina.
Deze applicatie doet een verzoek aan de FastAPI app, ontvangt de gestreamde tokens van FastAPI en stuurt elk token door naar de frontend met behulp van Server Sent Events (SSE). SSE is eenvoudiger dan websockets omdat het unidirectioneel is. Het is een goede keuze als je een applicatie wilt bouwen die informatie naar een client stuurt, zonder te luisteren naar een mogelijk antwoord van de client.
Hier is de Go-code (de HTML/JS/CSS-sjabloon wordt in de volgende sectie getoond):
package main
import (
"bufio"
"fmt"
"html/template"
"io"
"log"
"net/http"
"strings"
)
var (
templates *template.Template
streamedTextCh chan string
)
func init() {
// Parse all templates in the templates folder.
templates = template.Must(template.ParseGlob("templates/*.html"))
streamedTextCh = make(chan string)
}
// generateText calls FastAPI and returns every token received on the fly through
// a dedicated channel (streamedTextCh).
// If the EOF character is received from FastAPI, it means that text generation is over.
func generateText(streamedTextCh chan<- string) {
var buf io.Reader = nil
req, err := http.NewRequest("GET", "http://127.0.0.1:8000", buf)
if err != nil {
log.Fatal(err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
log.Fatal(err)
}
defer resp.Body.Close()
reader := bufio.NewReader(resp.Body)
outerloop:
for {
chunk, err := reader.ReadBytes('\x00')
if err != nil {
if err == io.EOF {
break outerloop
}
log.Println(err)
break outerloop
}
output := string(chunk)
streamedTextCh <- output
}
}
// formatServerSentEvent creates a proper SSE compatible body.
// Server sent events need to follow a specific formatting that
// uses "event:" and "data:" prefixes.
func formatServerSentEvent(event, data string) (string, error) {
sb := strings.Builder{}
_, err := sb.WriteString(fmt.Sprintf("event: %s\n", event))
if err != nil {
return "", err
}
_, err = sb.WriteString(fmt.Sprintf("data: %v\n\n", data))
if err != nil {
return "", err
}
return sb.String(), nil
}
// generate is an infinite loop that waits for new tokens received
// from the streamedTextCh. Once a new token is received,
// it is automatically pushed to the frontend as a server sent event.
func generate(w http.ResponseWriter, r *http.Request) {
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "SSE not supported", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "text/event-stream")
for text := range streamedTextCh {
event, err := formatServerSentEvent("streamed-text", text)
if err != nil {
http.Error(w, "Cannot format SSE message", http.StatusInternalServerError)
return
}
_, err = fmt.Fprint(w, event)
if err != nil {
http.Error(w, "Cannot format SSE message", http.StatusInternalServerError)
return
}
flusher.Flush()
}
}
// start starts an asynchronous request to the AI engine.
func start(w http.ResponseWriter, r *http.Request) {
go generateText(streamedTextCh)
}
func home(w http.ResponseWriter, r *http.Request) {
if err := templates.ExecuteTemplate(w, "home.html", nil); err != nil {
log.Println(err.Error())
http.Error(w, "", http.StatusInternalServerError)
return
}
}
func main() {
http.HandleFunc("/generate", generate)
http.HandleFunc("/start", start).Methods("POST")
http.HandleFunc("/", home).Methods("GET")
log.Fatal(http.ListenAndServe(":8000", r))
}
Onze "/home" pagina rendert de HTML/CSS/JS pagina (wordt later getoond). De "/start"-pagina ontvangt een POST-verzoek van de JS-applicatie die een verzoek naar ons AI-model triggert. En onze "/generate"-pagina stuurt het resultaat terug naar de JS-app via servergestuurde gebeurtenissen.
Zodra de start() functie een POST verzoek ontvangt van de voorkant, maakt het automatisch een goroutine aan die een verzoek zal doen naar onze FastAPI app.
De functie generateText() roept FastAPI aan en retourneert elk token dat on the fly wordt ontvangen via een speciaal kanaal (streamedTextCh). Als het teken EOF wordt ontvangen van FastAPI, betekent dit dat het genereren van tekst is voltooid.
De functie generate() is een oneindige lus die wacht op nieuwe tokens die worden ontvangen van het streamedTextCh kanaal. Zodra een nieuw token is ontvangen, wordt het automatisch naar de voorkant gepushed als een server verzonden gebeurtenis. Door de server verzonden gebeurtenissen moeten een specifieke opmaak volgen die "event:" en "data:" als voorvoegsels gebruikt. voorvoegsels gebruikt, vandaar de functie formatServerSentEvent().
Om SSE compleet te maken, hebben we een Javascript client nodig die kan luisteren naar gebeurtenissen die door de server worden gestuurd door zich te abonneren op de "genereer" pagina. Zie de volgende sectie om te begrijpen hoe we dat kunnen bereiken.
Je moet nu een map "templates" maken en daarin een bestand "home.html" toevoegen.
Hier is de inhoud van "home.html":
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Our Streamed Tokens App</title>
</head>
<body>
<div id="response-section"></div>
<form method="POST">
<button onclick="start()">Start</button>
</form>
</body>
<script>
// Disable the default behavior of the HTML form.
document.querySelector('form').addEventListener('submit', function(e) {
e.preventDefault()
})
// Make a request to the /start to trigger the request to the AI model.
async function start() {
try {
const response = await fetch("/start", {
method: "POST",
})
} catch (error) {
console.error("Error when starting process:", error)
}
}
// Listen to SSE by subscribing to the /generate page, and
// put the result in the #response-section div.
const evtSource = new EventSource("generate")
evtSource.addEventListener("streamed-text", (event) => {
document.getElementById('response-section').innerHTML = event.data
})
</script>
</html>
Zoals je kunt zien, is luisteren naar SSE in de browser vrij eenvoudig.
Eerst moet je je abonneren op ons SSE eindpunt (de "/generate" pagina). Vervolgens moet je een event listener toevoegen die de gestreamde tokens leest zodra ze worden ontvangen.
Moderne browsers proberen automatisch opnieuw verbinding te maken de gebeurtenisbron opnieuw te verbinden in geval van verbindingsproblemen.
Je weet nu hoe je een moderne generatieve AI-toepassing kunt maken die dynamisch tekst streamt in de browser, à la ChatGPT!
Zoals je hebt gemerkt, is zo'n toepassing niet noodzakelijk eenvoudig omdat er verschillende lagen betrokken zijn. En natuurlijk is de bovenstaande code overgesimplificeerd omwille van het voorbeeld.
De grootste uitdaging bij token streaming is het omgaan met netwerkstoringen. De meeste van van deze storingen zullen gebeuren tussen de Go backend en de Javascript frontend. Je zult wat meer geavanceerde verbindingsstrategieën moeten onderzoeken en ervoor zorgen dat fouten goed worden goed gerapporteerd worden aan de UI.
Ik hoop dat je deze tutorial nuttig vond!
Vincent
Developer Adviseur bij NLP Cloud