# 2025-07-05 21:24:33 +08:00

# 114 lines
# 3.4 KiB
# Python

import asyncio
import codecs
import os
import threading
import time

import boto3
from botocore.exceptions import BotoCoreError, ClientError
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
app = FastAPI()

# Bedrock agent ID and alias for each bot. "1" is the same value that
# fetch_agent_parameters() falls back to on an SSM error; the real values are
# filled in from SSM by that background task.
bots_config = {
    bot_name: {"agent_id": "1", "agent_alias": "1"}
    for bot_name in ("dronebot", "trainerbot", "checkinbot")
}

# SSM client used to read each bot's /<bot>/agentid and /<bot>/agentalias.
ssm = boto3.client("ssm", region_name="us-east-1")
async def _read_ssm_parameter(name, fallback):
    """Return the decrypted value of SSM parameter *name*, or *fallback* on error.

    ``ssm.get_parameter`` is a blocking call, so it runs in the default
    thread-pool executor instead of stalling the event loop.
    """
    loop = asyncio.get_running_loop()
    try:
        response = await loop.run_in_executor(
            None, lambda: ssm.get_parameter(Name=name, WithDecryption=True)
        )
    except (ClientError, BotoCoreError) as e:
        # BotoCoreError covers connection/credential failures, which are not
        # ClientErrors; letting them escape would end the refresh loop for good.
        print(f"Error fetching {name}: {e}")
        return fallback
    return response["Parameter"]["Value"]


async def fetch_agent_parameters():
    """Refresh every bot's Bedrock agent ID and alias from SSM, forever.

    For each bot in ``bots_config``, the task reads ``/<bot>/agentalias`` and
    ``/<bot>/agentid``. It then logs the resulting config and sleeps 60 seconds
    between passes.

    If a parameter cannot be read, the bot keeps its current value. That is the
    ``"1"`` placeholder only until the first successful read, so a transient
    SSM outage no longer replaces a working agent with an invalid one.
    """
    while True:
        for bot_key in list(bots_config):
            current = bots_config[bot_key]
            agent_alias = await _read_ssm_parameter(
                f"/{bot_key}/agentalias", current["agent_alias"]
            )
            agent_id = await _read_ssm_parameter(
                f"/{bot_key}/agentid", current["agent_id"]
            )
            # Swap in a fresh dict so a concurrent reader gets either the old
            # ID/alias pair or the new one, never a mix of the two.
            bots_config[bot_key] = {"agent_id": agent_id, "agent_alias": agent_alias}
        print("Updated from SSM —")
        for bot_key, config in bots_config.items():
            print(f"{bot_key.upper()} → ID: {config['agent_id']}, Alias: {config['agent_alias']}")
        await asyncio.sleep(60)
@app.on_event("startup")
async def startup_event():
    """Start the background task that keeps ``bots_config`` in sync with SSM.

    The event loop only keeps weak references to tasks. The task is therefore
    stored on ``app.state``, which keeps it from being garbage-collected while
    it is still running.
    """
    app.state.ssm_refresh_task = asyncio.create_task(fetch_agent_parameters())
# Request body accepted by POST /bedrock-agent.
class BedrockRequest(BaseModel):
    # The user's question; forwarded to the agent as inputText.
    question: str
    # Session prefix; the endpoint appends the agent ID and alias to it.
    requestSessionId: str = Field(default='user')
    # Key into bots_config choosing which agent answers.
    bot: str = Field(default='dronebot')
# AWS Bedrock client setup: one shared client, created lazily on first use.
_bedrock_client = None
_bedrock_client_lock = threading.Lock()


def get_bedrock_client():
    """Return the shared ``bedrock-agent-runtime`` client for us-east-1.

    The client is built once and then reused, for two reasons:

    - Constructing a botocore client on every request is wasteful.
    - ``boto3.client()`` goes through the default session, which is not safe
      to use from several threads at once. Sync FastAPI endpoints run in a
      thread pool, so per-request creation could race.

    A lock guards the one-time creation. Low-level boto3 clients themselves
    are thread-safe and can be shared.
    """
    global _bedrock_client
    with _bedrock_client_lock:
        if _bedrock_client is None:
            _bedrock_client = boto3.client(
                'bedrock-agent-runtime',
                region_name='us-east-1'
            )
        return _bedrock_client
@app.get("/health")
def health_check():
    """Return a fixed OK payload so callers can check the service responds."""
    status = {"message": "OK"}
    return status
def _iter_completion_text(completion):
    """Yield the text carried by ``chunk`` events of an agent completion stream.

    Args:
        completion: The ``response["completion"]`` iterable from
            ``invoke_agent``. Each item is a dict keyed by event type; items
            without a ``chunk`` key are skipped.

    Yields:
        Non-empty text fragments in stream order.

    ``chunk["bytes"]`` is decoded with an incremental UTF-8 decoder. A
    multi-byte character split across two chunks is therefore reassembled
    rather than raising ``UnicodeDecodeError`` halfway through the response.
    Undecodable bytes become U+FFFD. A chunk without bytes falls back to its
    ``text`` field.
    """
    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
    for event in completion:
        if "chunk" not in event:
            continue
        chunk_bytes = event["chunk"].get("bytes")
        if chunk_bytes:
            text = decoder.decode(chunk_bytes)
        else:
            text = event["chunk"].get("text", "")
        if text:
            yield text
    # Flush any bytes still buffered by the decoder at end of stream.
    tail = decoder.decode(b"", final=True)
    if tail:
        yield tail


@app.post("/bedrock-agent")
def call_bedrock_agent(payload: BedrockRequest):
    """Send ``payload.question`` to the selected Bedrock agent and stream the reply.

    Raises:
        HTTPException: 400 if ``payload.bot`` is not a key of ``bots_config``;
            500 if ``invoke_agent`` fails.
    """
    config = bots_config.get(payload.bot)
    if config is None:
        # Previously an unknown bot raised KeyError and surfaced as an opaque 500.
        raise HTTPException(status_code=400, detail=f"Unknown bot: {payload.bot!r}")
    # Read the ID and alias from the same config entry so they belong together.
    bot_alias = config['agent_alias']
    bot_id = config['agent_id']
    client = get_bedrock_client()
    print(payload)
    try:
        response_stream = client.invoke_agent(
            # Appending ID and alias gives a new Bedrock session whenever the
            # agent/alias changes (intent presumed; confirm with the owners).
            sessionId=payload.requestSessionId + bot_id + bot_alias,
            agentId=bot_id,
            agentAliasId=bot_alias,
            inputText=payload.question,
        )
        completion = response_stream["completion"]
    except Exception as e:
        print(e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    # NOTE: errors raised while the stream is consumed occur after the response
    # has started, so they cannot be turned into an HTTP error status here.
    return StreamingResponse(_iter_completion_text(completion), media_type="text/plain")
@app.get("/test-stream")
def test_stream():
    """Stream five numbered lines, pausing one second after each, as a streaming check."""
    def emit_chunks():
        index = 0
        while index < 5:
            yield f"Chunk {index}\n"
            time.sleep(1)
            index += 1

    return StreamingResponse(emit_chunks(), media_type="text/plain")