59 lines
1.7 KiB
Python
59 lines
1.7 KiB
Python
import codecs
import functools
import os
import time

import boto3
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
|
|
|
|
# ASGI application; the route handlers in this module register on it via @app decorators.
app = FastAPI()
|
|
|
|
# Define request body
class BedrockRequest(BaseModel):
    """Request body for ``POST /bedrock-agent``.

    The fields are forwarded to ``bedrock-agent-runtime`` ``invoke_agent``
    (``question`` as ``inputText``, ``requestSessionId`` as ``sessionId``).
    """

    # Prompt text sent to the agent.
    question: str
    # Session identifier passed to the agent.
    # NOTE(review): every caller that omits this field gets the same literal
    # session id 'user', so unrelated users may share conversation state on the
    # agent side — clients should probably send a unique id; confirm.
    requestSessionId: str = 'user'
    # Agent / alias to invoke. The built-in defaults are unchanged, but a
    # deployment can now override them via the BEDROCK_AGENT_ID and
    # BEDROCK_AGENT_ALIAS_ID environment variables (read once, at import time)
    # instead of editing the source.
    agentId: str = os.getenv("BEDROCK_AGENT_ID", "ROJCGWHSC0")
    agentAliasId: str = os.getenv("BEDROCK_AGENT_ALIAS_ID", "TQ8VDTVQII")
|
|
|
|
# AWS Bedrock client setup
@functools.lru_cache(maxsize=1)
def get_bedrock_client():
    """Return a shared ``bedrock-agent-runtime`` boto3 client.

    The client is created on the first call and reused afterwards. Building a
    boto3 client is comparatively expensive, and boto3 low-level clients are
    thread-safe, so a single instance can serve every request handled by the
    server's worker threads. If creation raises (e.g. no region configured),
    nothing is cached and the next call retries.

    Region and credentials come from boto3's default resolution chain, since
    no explicit arguments are passed.
    """
    return boto3.client('bedrock-agent-runtime')
|
|
|
|
@app.get("/health")
def health_check():
    """Liveness probe reporting that the service is reachable.

    Returns:
        The constant body ``{"message": "OK"}``.
    """
    body = dict(message="OK")
    return body
|
|
|
|
def _completion_text(events):
    """Yield the text carried by the ``chunk`` events of an agent completion stream.

    Events without a ``chunk`` key are ignored. A chunk's ``bytes`` payload is
    decoded as UTF-8 with an incremental decoder, so a multi-byte character
    split across two chunks is reassembled instead of raising
    ``UnicodeDecodeError`` in the middle of the response; undecodable bytes
    become U+FFFD. A chunk without ``bytes`` falls back to its ``text`` field.
    Empty strings are not yielded.
    """
    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
    for event in events:
        if "chunk" not in event:
            continue
        chunk = event["chunk"]
        raw = chunk.get("bytes")
        text = decoder.decode(raw) if raw else chunk.get("text", "")
        if text:
            yield text
    # Flush anything still buffered from an incomplete trailing byte sequence.
    tail = decoder.decode(b"", final=True)
    if tail:
        yield tail


@app.post("/bedrock-agent")
def call_bedrock_agent(payload: BedrockRequest):
    """Invoke a Bedrock agent with ``payload.question`` and stream its answer as plain text.

    Raises:
        HTTPException: status 500 with the error text as ``detail`` if the
            agent call fails, or if the completion stream fails before its
            first piece of text arrives. Failures after streaming has started
            cannot change the already-sent status code; they propagate and end
            the response early.
    """
    client = get_bedrock_client()
    completion = None
    try:
        response = client.invoke_agent(
            sessionId=payload.requestSessionId,
            agentId=payload.agentId,
            agentAliasId=payload.agentAliasId,
            inputText=payload.question,
        )
        completion = response["completion"]
        texts = _completion_text(completion)
        # The completion stream is lazy: errors delivered inside the event
        # stream are raised while iterating it, not by invoke_agent. Reading
        # the first piece here, while the status code can still be chosen,
        # turns such errors into a 500 instead of a 200 with a cut-off body.
        first = next(texts, None)
    except Exception as e:
        if completion is not None:
            completion.close()
        raise HTTPException(status_code=500, detail=str(e)) from e

    def event_stream():
        try:
            if first is not None:
                yield first
            yield from texts
        finally:
            # Release the underlying HTTP stream when the generator finishes,
            # fails, or is closed early.
            completion.close()

    return StreamingResponse(event_stream(), media_type="text/plain")
|
|
|
|
@app.get("/test-stream")
def test_stream():
    """Stream five numbered lines, pausing one second after each.

    Judging by the route name, this is a manual check that responses reach the
    client incrementally rather than being buffered.
    """
    def _ticks():
        n = 0
        while n < 5:
            yield f"Chunk {n}\n"
            time.sleep(1)
            n += 1

    return StreamingResponse(_ticks(), media_type="text/plain")
|