2025-05-30 02:47:30 +08:00

65 lines
1.9 KiB
Python

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
import boto3
import os
from fastapi.responses import StreamingResponse
import time
app = FastAPI()
# Only provide a fallback region. An AWS_REGION already present in the
# environment (e.g. set by the deployment) must take precedence; forcibly
# overwriting it would make the client factory's
# os.getenv('AWS_REGION', ...) lookup unable to see any other region.
os.environ.setdefault('AWS_REGION', 'us-east-1')
class BedrockRequest(BaseModel):
    """JSON request body accepted by ``POST /bedrock-agent``.

    The route handler forwards ``question`` as ``inputText`` and the
    remaining fields as ``sessionId``, ``agentId`` and ``agentAliasId`` of
    the Bedrock ``invoke_agent`` call.
    """

    # The user's prompt; required.
    question: str
    # Bedrock session identifier.
    # NOTE(review): every caller that omits this shares the one 'user'
    # session (and therefore its conversation context) — confirm intended.
    requestSessionId: str = Field(default='user')
    # Agent / alias identifiers. The "1" defaults look like placeholders —
    # TODO confirm real IDs are always supplied by callers.
    agentId: str = Field(default="1")
    agentAliasId: str = Field(default="1")
def get_bedrock_client():
    """Create a ``bedrock-agent-runtime`` client configured from the environment.

    The region comes from ``AWS_REGION`` (default ``us-east-1``). Credentials
    come from ``AWS_ACCESS_KEY_ID``, ``AWS_SECRET_ACCESS_KEY`` and
    ``AWS_SESSION_TOKEN``; any that are unset are passed as ``None``, in
    which case boto3 resolves credentials through its default chain.

    Returns:
        A new boto3 ``bedrock-agent-runtime`` client.
    """
    # boto3.client(...) builds clients from boto3's shared default Session,
    # and Session objects are not thread-safe. This factory is called once
    # per request from a sync endpoint, i.e. from FastAPI's worker threads,
    # so give each call its own Session. The resulting client is itself
    # thread-safe.
    session = boto3.Session()
    return session.client(
        'bedrock-agent-runtime',
        region_name=os.getenv('AWS_REGION', 'us-east-1'),
        aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'),
        aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'),
        aws_session_token=os.getenv('AWS_SESSION_TOKEN'),
    )
@app.get("/health")
def health_check():
    """Return a constant OK payload; no downstream dependencies are checked."""
    status = dict(message="OK")
    return status
def _iter_chunk_text(events):
    """Yield the text carried by each ``chunk`` event in *events*.

    ``chunk["bytes"]`` payloads go through one incremental UTF-8 decoder
    shared across the whole stream. A multi-byte character split between
    two chunks is therefore reassembled instead of raising
    ``UnicodeDecodeError`` mid-stream. A chunk without (or with empty)
    ``bytes`` yields its ``text`` field, or ``""`` if absent. Events that
    are not chunks are skipped.
    """
    import codecs

    decoder = codecs.getincrementaldecoder("utf-8")()
    for event in events:
        if "chunk" not in event:
            continue
        chunk = event["chunk"]
        chunk_bytes = chunk.get("bytes")
        if chunk_bytes:
            text = decoder.decode(chunk_bytes)
            # Empty when the chunk ended inside a multi-byte character; the
            # buffered bytes are emitted with the next chunk.
            if text:
                yield text
        else:
            yield chunk.get("text", "")
    # Flush anything still buffered (raises if the stream ended mid-character).
    tail = decoder.decode(b"", final=True)
    if tail:
        yield tail


@app.post("/bedrock-agent")
def call_bedrock_agent(payload: BedrockRequest):
    """Send *payload* to a Bedrock agent and stream its reply as plain text.

    Request fields map onto ``invoke_agent`` as follows:
    ``requestSessionId`` -> ``sessionId``, ``agentId`` -> ``agentId``,
    ``agentAliasId`` -> ``agentAliasId``, and ``question`` -> ``inputText``.

    Raises:
        HTTPException: status 500, with the error text as detail, if
            creating the client, calling ``invoke_agent`` or reading the
            first event of the completion stream fails. Errors raised later
            in the stream cannot change the status, because the streaming
            response has already sent its headers by then.
    """
    import itertools

    try:
        client = get_bedrock_client()
        response = client.invoke_agent(
            sessionId=payload.requestSessionId,
            agentId=payload.agentId,
            agentAliasId=payload.agentAliasId,
            inputText=payload.question,
        )
        events = iter(response["completion"])
        # The completion stream is lazy, so errors delivered inside it would
        # otherwise surface only after a 200 had been sent. Pulling the first
        # event here turns those errors into a proper 500.
        first_event = next(events, None)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    head = [] if first_event is None else [first_event]
    return StreamingResponse(
        _iter_chunk_text(itertools.chain(head, events)),
        media_type="text/plain",
    )
@app.get("/test-stream")
def test_stream():
    """Stream five numbered lines, pausing one second after each.

    Useful for manually checking that chunked plain-text streaming reaches
    the client incrementally.
    """
    def generator():
        count = 0
        while count < 5:
            yield f"Chunk {count}\n"
            time.sleep(1)
            count += 1

    body = generator()
    return StreamingResponse(body, media_type="text/plain")