"""FastAPI service that proxies questions to an AWS Bedrock agent and streams the reply."""

import codecs
import os
import time

import boto3
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field

app = FastAPI()

# NOTE(review): this unconditionally overwrites any AWS_REGION already present
# in the process environment, which makes the 'us-east-1' fallback in
# get_bedrock_client() unreachable. Confirm whether
# os.environ.setdefault('AWS_REGION', 'us-east-1') was intended.
os.environ['AWS_REGION'] = 'us-east-1'


# Define request body
class BedrockRequest(BaseModel):
    """JSON body accepted by the POST /bedrock-agent endpoint."""

    # Forwarded to invoke_agent as ``inputText``.
    question: str
    # Forwarded to invoke_agent as ``sessionId``.
    requestSessionId: str = 'user'
    # Forwarded to invoke_agent as ``agentId``.
    agentId: str = "1"
    # Forwarded to invoke_agent as ``agentAliasId``.
    agentAliasId: str = "1"


# AWS Bedrock client setup
def get_bedrock_client():
    """Create a new ``bedrock-agent-runtime`` boto3 client.

    The region and credentials are read from the ``AWS_REGION``,
    ``AWS_ACCESS_KEY_ID``, ``AWS_SECRET_ACCESS_KEY`` and ``AWS_SESSION_TOKEN``
    environment variables at call time; unset credential variables are passed
    through to boto3 as ``None``.
    """
    return boto3.client(
        'bedrock-agent-runtime',
        region_name=os.getenv('AWS_REGION', 'us-east-1'),
        aws_access_key_id=os.getenv('AWS_ACCESS_KEY_ID'),
        aws_secret_access_key=os.getenv('AWS_SECRET_ACCESS_KEY'),
        aws_session_token=os.getenv('AWS_SESSION_TOKEN'),
    )


@app.get("/health")
def health_check():
    """Return a fixed payload so callers can check the service is up."""
    return {"message": "OK"}


@app.post("/bedrock-agent")
def call_bedrock_agent(payload: BedrockRequest):
    """Invoke the Bedrock agent and stream its completion back as plain text.

    Args:
        payload: Question, session id and agent identifiers for the call.

    Returns:
        A ``StreamingResponse`` (``text/plain``) yielding the agent's text as
        it arrives.

    Raises:
        HTTPException: status 500, with the error text as detail, when the
            ``invoke_agent`` call itself fails.
    """
    client = get_bedrock_client()
    try:
        response_stream = client.invoke_agent(
            sessionId=payload.requestSessionId,
            agentId=payload.agentId,
            agentAliasId=payload.agentAliasId,
            inputText=payload.question,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    def event_stream():
        """Yield decoded text for every ``chunk`` event of the completion.

        This generator body only runs while the response is being iterated,
        so exceptions raised here are not converted into an HTTPException by
        the ``try`` block above.
        """
        # An incremental decoder buffers a trailing partial multi-byte
        # sequence until the next chunk completes it, instead of raising
        # UnicodeDecodeError when a character straddles two chunks.
        decoder = codecs.getincrementaldecoder("utf-8")()
        for event in response_stream["completion"]:
            if "chunk" not in event:
                continue
            chunk = event["chunk"]
            chunk_bytes = chunk.get("bytes")
            if chunk_bytes:
                text = decoder.decode(chunk_bytes)
            else:
                text = chunk.get("text", "")
            # Skip empty pieces rather than emitting zero-length body chunks.
            if text:
                yield text
        # Final flush: still raises UnicodeDecodeError if the stream ended in
        # the middle of a character, so truncated data is not silently lost.
        decoder.decode(b"", final=True)

    return StreamingResponse(event_stream(), media_type="text/plain")


@app.get("/test-stream")
def test_stream():
    """Stream five numbered lines, one per second, to exercise streaming."""

    def generator():
        for i in range(5):
            yield f"Chunk {i}\n"
            time.sleep(1)

    return StreamingResponse(generator(), media_type="text/plain")