Feature/defect fixes #345

Merged · 14 commits · Mar 27, 2024
@@ -180,7 +180,7 @@ def send_job_status(variables):
auth=aws_auth_appsync,
timeout=10
)
logger.info('res :: {}',responseJobstatus)
#logger.info('res :: {}',responseJobstatus)

def get_presigned_url(bucket,key) -> str:
try:
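Aside, not part of the diff: the line being commented out above passes a brace-style placeholder to aws_lambda_powertools' Logger, which follows the stdlib logging API and interpolates %s-style positional arguments, so the braces would never be filled in. A minimal sketch of the two usual alternatives, with a hypothetical response variable standing in for responseJobstatus:

# Minimal sketch, not from the PR; `response` is a hypothetical stand-in value.
from aws_lambda_powertools import Logger

logger = Logger()
response = {"statusCode": 200}  # placeholder value for illustration

logger.info("res :: %s", response)   # stdlib-style lazy interpolation
logger.info(f"res :: {response}")    # eager f-string interpolation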
@@ -24,7 +24,6 @@
from langchain.chains import LLMChain
from .sagemaker_endpoint import MultiModal
from aws_lambda_powertools import Logger, Tracer, Metrics
from .StreamingCallbackHandler import StreamingCallbackHandler
from adapters import registry

from .helper import download_file, load_vector_db_opensearch,send_job_status, JobStatus,get_presigned_url,encode_image_to_base64
@@ -181,7 +180,7 @@ def process_visual_qa(input_params,status_variables,filename):

qa_model= input_params['qa_model']
qa_modelId=qa_model['modelId']

streaming = qa_model.get("streaming", False)
# default model provider is Bedrock and default modality is Text
modality=qa_model.get("modality", "Text")
model_provider=qa_model.get("provider",Provider.BEDROCK)
@@ -207,28 +206,20 @@ def process_visual_qa(input_params,status_variables,filename):
if(_qa_llm is not None):
local_file_path= download_file(bucket_name,filename)
base64_images=encode_image_to_base64(local_file_path,filename)
status_variables['answer']= generate_vision_answer_bedrock(_qa_llm,base64_images, qa_modelId,decoded_question)
if(status_variables['answer'] is None):
status_variables['answer'] = JobStatus.ERROR_PREDICTION.status
error = JobStatus.ERROR_PREDICTION.get_message()
status_variables['answer'] = error.decode("utf-8")
status_variables['jobstatus'] = JobStatus.ERROR_PREDICTION.status
else:
status_variables['jobstatus'] = JobStatus.DONE.status
streaming = input_params.get("streaming", False)

generate_vision_answer_bedrock(_qa_llm,base64_images, qa_modelId,decoded_question,status_variables,streaming)
else:
logger.error('Invalid Model , cannot load LLM , returning..')
status_variables['jobstatus'] = JobStatus.ERROR_LOAD_LLM.status
error = JobStatus.ERROR_LOAD_LLM.get_message()
status_variables['answer'] = error.decode("utf-8")
send_job_status(status_variables)
else:
logger.error('Invalid Model provider, cannot load LLM , returning..')
status_variables['jobstatus'] = JobStatus.ERROR_LOAD_LLM.status
error = JobStatus.ERROR_LOAD_LLM.get_message()
status_variables['answer'] = error.decode("utf-8")
send_job_status(status_variables)

send_job_status(status_variables)
return status_variables

def generate_vision_answer_sagemaker(_qa_llm,input_params,decoded_question,status_variables,filename):
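Because removed and added lines are interleaved above without diff markers, here is a rough sketch of how the Bedrock branch of process_visual_qa reads after this change, reconstructed from the added lines (error handling condensed; the helpers come from the module's imports shown earlier):

# Rough sketch of the tail of process_visual_qa's Bedrock branch, reconstructed from
# the added lines above; not a verbatim copy of the file.
if _qa_llm is not None:
    local_file_path = download_file(bucket_name, filename)
    base64_images = encode_image_to_base64(local_file_path, filename)
    streaming = input_params.get("streaming", False)
    # The helper now writes the answer and job status into status_variables itself
    # (token by token when streaming) instead of returning a single answer string.
    generate_vision_answer_bedrock(_qa_llm, base64_images, qa_modelId,
                                   decoded_question, status_variables, streaming)
else:
    logger.error('Invalid Model , cannot load LLM , returning..')
    status_variables['jobstatus'] = JobStatus.ERROR_LOAD_LLM.status
    status_variables['answer'] = JobStatus.ERROR_LOAD_LLM.get_message().decode("utf-8")

send_job_status(status_variables)
return status_variables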
@@ -269,18 +260,9 @@ def generate_vision_answer_sagemaker(_qa_llm,input_params,decoded_question,statu

return status_variables

def generate_vision_answer_bedrock(bedrock_client,base64_images,model_id,decoded_question):
system_prompt=""
# use system prompt for fine tuning the performamce
# system_prompt= """
# You have perfect vision and pay great attention to detail which
# makes you an expert at answering architecture diagram question.
# Answer question in <question></question> tags. Before answer,
# think step by step in <thinking> tags and analyze every part of the diagram.
# """
#Create a prompt with the question
prompt =f"<question>{decoded_question}</question>. Answer must be a numbered list in a small paragraph inside <answer></answer> tag."

def generate_vision_answer_bedrock(bedrock_client,base64_images,model_id,
decoded_question,status_variables,streaming):

claude_config = {
'max_tokens': 1000,
'temperature': 0,
@@ -302,34 +284,57 @@ def generate_vision_answer_bedrock(bedrock_client,base64_images,model_id,decoded
},
{
"type": "text",
"text": prompt
"text": decoded_question

}
]
}

body=json.dumps({'messages': [messages],**claude_config, "system": system_prompt})
body=json.dumps({'messages': [messages],**claude_config})

try:
response = bedrock_client.invoke_model(
body=body, modelId=model_id, accept="application/json",
contentType="application/json"
)
if streaming:
response = bedrock_client.invoke_model_with_response_stream(
body=body, modelId=model_id, accept="application/json",
contentType="application/json"
)
for event in response.get("body"):
chunk = json.loads(event["chunk"]["bytes"])

if chunk['type'] == 'message_delta':
status_variables['answer']=''
status_variables['jobstatus'] = JobStatus.STREAMING_ENDED.status
send_job_status(status_variables)

if chunk['type'] == 'content_block_delta':
if chunk['delta']['type'] == 'text_delta':
logger.info(chunk['delta']['text'], end="")
chuncked_text=chunk['delta']['text']
llm_answer_bytes = json.dumps(chuncked_text).encode("utf-8")
base64_bytes = base64.b64encode(llm_answer_bytes)
llm_answer_base64_string = base64_bytes.decode("utf-8")
status_variables['answer']=llm_answer_base64_string
status_variables['jobstatus'] = JobStatus.STREAMING_NEW_TOKEN.status
send_job_status(status_variables)


else:
response = bedrock_client.invoke_model(
body=body, modelId=model_id, accept="application/json",
contentType="application/json"
)
response_body = json.loads(response.get('body').read())
logger.info(f'answer is: {response_body}')
output_list = response_body.get("content", [])
for output in output_list:
llm_answer_bytes = json.dumps(output["text"]).encode("utf-8")
base64_bytes = base64.b64encode(llm_answer_bytes)
llm_answer_base64_string = base64_bytes.decode("utf-8")
status_variables['jobstatus'] = JobStatus.DONE.status
status_variables['answer']=llm_answer_base64_string
send_job_status(status_variables)

except Exception as err:
logger.exception(f'Error occurred , Reason :{err}')
return None

response = json.loads(response['body'].read().decode('utf-8'))

formated_response= response['content'][0]['text']
answer = re.findall(r'<answer>(.*?)</answer>', formated_response, re.DOTALL)
formatted_answer=answer[0]
llm_answer_bytes = formatted_answer.encode("utf-8")
print(f' formatted_answer {formatted_answer}')
base64_bytes = base64.b64encode(llm_answer_bytes)
print(f' base64_bytes')
llm_answer_base64_string = base64_bytes.decode("utf-8")

print(f' llm_answer_base64_string {llm_answer_base64_string}')

return llm_answer_base64_string
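
Since the old and new lines are interleaved above, a consolidated sketch of the rewritten generate_vision_answer_bedrock may help. It is reconstructed from the added lines in this PR; the image media type and the anthropic_version field are assumptions (those lines are collapsed in the view above), and send_job_status, JobStatus and logger are the module's existing helpers.

# Sketch reconstructed from the added lines in this PR; not a verbatim copy of the file.
# send_job_status, JobStatus and logger are the module's existing helpers and imports.
import base64
import json

def generate_vision_answer_bedrock(bedrock_client, base64_images, model_id,
                                   decoded_question, status_variables, streaming):
    claude_config = {
        'max_tokens': 1000,
        'temperature': 0,
        # Assumption: the Bedrock Messages API for Claude 3 expects this version field;
        # the corresponding line is collapsed in the diff above.
        'anthropic_version': 'bedrock-2023-05-31',
    }
    messages = {
        "role": "user",
        "content": [
            {
                "type": "image",
                # Assumption: JPEG input; the media-type line is collapsed in the diff above.
                "source": {"type": "base64", "media_type": "image/jpeg", "data": base64_images},
            },
            {"type": "text", "text": decoded_question},
        ],
    }
    body = json.dumps({'messages': [messages], **claude_config})

    try:
        if streaming:
            # Stream Claude Messages API events and forward each text delta to AppSync.
            response = bedrock_client.invoke_model_with_response_stream(
                body=body, modelId=model_id, accept="application/json",
                contentType="application/json")
            for event in response.get("body"):
                chunk = json.loads(event["chunk"]["bytes"])
                if chunk['type'] == 'content_block_delta' and chunk['delta']['type'] == 'text_delta':
                    token_b64 = base64.b64encode(
                        json.dumps(chunk['delta']['text']).encode("utf-8")).decode("utf-8")
                    status_variables['answer'] = token_b64
                    status_variables['jobstatus'] = JobStatus.STREAMING_NEW_TOKEN.status
                    send_job_status(status_variables)
                elif chunk['type'] == 'message_delta':
                    # End of the streamed message: signal completion with an empty answer.
                    status_variables['answer'] = ''
                    status_variables['jobstatus'] = JobStatus.STREAMING_ENDED.status
                    send_job_status(status_variables)
        else:
            # Single-shot invocation: base64-encode the full answer and send it once.
            response = bedrock_client.invoke_model(
                body=body, modelId=model_id, accept="application/json",
                contentType="application/json")
            response_body = json.loads(response.get('body').read())
            for output in response_body.get("content", []):
                answer_b64 = base64.b64encode(
                    json.dumps(output["text"]).encode("utf-8")).decode("utf-8")
                status_variables['answer'] = answer_b64
                status_variables['jobstatus'] = JobStatus.DONE.status
                send_job_status(status_variables)
    except Exception as err:
        logger.exception(f'Error occurred , Reason :{err}')
        return None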