와... 드디어 S3에 올려진 모델에 json 뉴스title을 통과시켜서 어떤 카테고리에 저장될지 예측하여 다시 S3에 업로드 시키는 과정을 성공했다!
처음 클라우드를 써보는거라 너무 어려웠는데 성공해서 너무 기쁘당.. 히히
잘 기록해두고 다음에도 모델 사용할때 써야겠다!
import torch
from transformers import BertConfig, BertTokenizer, BertForSequenceClassification
import json
from mymodel import MyModel1
import boto3
def load_model(model_path, number_of_labels, tokenizer_path='bert-base-uncased'):
config = BertConfig.from_pretrained(tokenizer_path, num_labels=number_of_labels)
model = MyModel1(config) # 설정 파일 로드
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
model.eval() # 평가 모드로 설정
return model, tokenizer
def predict(model, tokenizer, input_text):
inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
with torch.no_grad():
outputs = model(**inputs)
probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
return probabilities.numpy()
def lambda_handler(event, context):
s3 = boto3.client('s3',
region_name='-',
aws_access_key_id='-',
aws_secret_access_key='-')
# S3 버킷과 파일 이름
bucket = event['bucket']
input_file_key = event['input_file']
output_file = event['output_file']
# S3에서 입력 파일 읽기
input_obj = s3.get_object(Bucket=bucket, Key=input_file_key)
input_data = json.loads(input_obj['Body'].read().decode('utf-8'))
model, tokenizer = load_model('./model.pth', 4)
# 출력 데이터 준비
outdata = input_data.copy()
for article in outdata['news']:
input_text = article['Title'] # 'Title' 키 사용
prediction = predict(model, tokenizer, input_text)
predicted_category = prediction.argmax(axis=1).item()
article['Category'] = predicted_category
# 결과를 JSON 문자열로 변환
output_json = json.dumps(outdata, ensure_ascii=False)
# S3에 결과 파일 쓰기
s3.put_object(Body=output_json, Bucket=bucket, Key=output_file)
return {
'statusCode': 200,
'body': json.dumps('File uploaded successfully')
}
# 로컬 테스트
if __name__ == "__main__":
test_event = {'bucket': 'aniop2023',
'input_file': 'manual_predicted_news_articles.json',
'output_file': 'manual_predicted_news_articles.json'}
test_context = None
result = lambda_handler(test_event, test_context)
print(result)
728x90
'Study > 소프트웨어공학&비즈니스애널리틱스 (최성철 교수님) 2023-2' 카테고리의 다른 글
클리핑된 뉴스 모델에 통과시켜서 성능확인하기 (0) | 2023.12.10 |
---|---|
뉴스 클리핑 코드 구현 (3) | 2023.12.09 |
드디어 lamda로 구운 나의 커스텀 모델 (2) | 2023.11.21 |
뉴스 클리핑 분류 KoBERT 모델 설명 (0) | 2023.11.11 |
SageMaker Studio에서 ML 분석하기 (0) | 2023.11.11 |