Yohai Rosen
commited on
Commit
·
2a0635e
1
Parent(s):
0f1045d
test
Browse files- config.json +1 -0
- sagemaker_setup.sh +20 -0
- scripts/sagemaker.py +75 -0
config.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"placeholder": "This is a placeholder config.json"}
|
sagemaker_setup.sh
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# Update package list and install ffmpeg
|
| 4 |
+
apt-get update && apt-get install -y ffmpeg
|
| 5 |
+
|
| 6 |
+
# # Ensure the model directory and config.json file exist
|
| 7 |
+
# MODEL_DIR="/opt/ml/model"
|
| 8 |
+
# CONFIG_FILE="${MODEL_DIR}/config.json"
|
| 9 |
+
|
| 10 |
+
# # Ensure the model directory exists
|
| 11 |
+
# mkdir -p ${MODEL_DIR}
|
| 12 |
+
|
| 13 |
+
# # Create a placeholder config.json if it does not exist
|
| 14 |
+
# if [ ! -f ${CONFIG_FILE} ]; then
|
| 15 |
+
# echo "Creating placeholder config.json in ${MODEL_DIR}"
|
| 16 |
+
# echo '{"placeholder": "This is a placeholder config.json"}' > ${CONFIG_FILE}
|
| 17 |
+
# fi
|
| 18 |
+
|
| 19 |
+
# echo "Initialization completed. Model directory contents:"
|
| 20 |
+
# ls -l ${MODEL_DIR}
|
scripts/sagemaker.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import boto3
|
| 4 |
+
import torch
|
| 5 |
+
import argparse
|
| 6 |
+
import time
|
| 7 |
+
from omegaconf import OmegaConf
|
| 8 |
+
|
| 9 |
+
from inference import inference_process # Ensure inference.py is in the same directory or update the import path
|
| 10 |
+
|
| 11 |
+
def download_from_s3(s3_path, local_path):
|
| 12 |
+
s3 = boto3.client('s3')
|
| 13 |
+
bucket, key = s3_path.replace("s3://", "").split("/", 1)
|
| 14 |
+
s3.download_file(bucket, key, local_path)
|
| 15 |
+
|
| 16 |
+
def upload_to_s3(local_path, s3_path):
|
| 17 |
+
s3 = boto3.client('s3')
|
| 18 |
+
bucket, key = s3_path.replace("s3://", "").split("/", 1)
|
| 19 |
+
s3.upload_file(local_path, bucket, key)
|
| 20 |
+
|
| 21 |
+
def model_fn(model_dir):
|
| 22 |
+
# config_path = os.path.join(model_dir, 'config.json')
|
| 23 |
+
|
| 24 |
+
# # Create a placeholder config.json if it does not exist
|
| 25 |
+
# if not os.path.exists(config_path):
|
| 26 |
+
# print(f"config.json not found in {model_dir}. Creating a placeholder config.json.")
|
| 27 |
+
# config_content = {
|
| 28 |
+
# "placeholder": "This is a placeholder config.json"
|
| 29 |
+
# }
|
| 30 |
+
# with open(config_path, 'w') as config_file:
|
| 31 |
+
# json.dump(config_content, config_file)
|
| 32 |
+
|
| 33 |
+
return model_dir
|
| 34 |
+
|
| 35 |
+
def input_fn(request_body, content_type='application/json'):
|
| 36 |
+
if content_type == 'application/json':
|
| 37 |
+
input_data = json.loads(request_body)
|
| 38 |
+
|
| 39 |
+
# Download source_image and driving_audio from S3 if necessary
|
| 40 |
+
source_image_path = input_data['source_image']
|
| 41 |
+
driving_audio_path = input_data['driving_audio']
|
| 42 |
+
|
| 43 |
+
local_source_image = "/opt/ml/input/data/source_image.jpg"
|
| 44 |
+
local_driving_audio = "/opt/ml/input/data/driving_audio.wav"
|
| 45 |
+
|
| 46 |
+
if source_image_path.startswith("s3://"):
|
| 47 |
+
download_from_s3(source_image_path, local_source_image)
|
| 48 |
+
input_data['source_image'] = local_source_image
|
| 49 |
+
if driving_audio_path.startswith("s3://"):
|
| 50 |
+
download_from_s3(driving_audio_path, local_driving_audio)
|
| 51 |
+
input_data['driving_audio'] = local_driving_audio
|
| 52 |
+
|
| 53 |
+
args = argparse.Namespace(**input_data.get('config', {}))
|
| 54 |
+
s3_output = input_data.get('output', None)
|
| 55 |
+
|
| 56 |
+
return args, s3_output
|
| 57 |
+
else:
|
| 58 |
+
raise ValueError(f"Unsupported content type: {content_type}")
|
| 59 |
+
|
| 60 |
+
def predict_fn(input_data, model):
|
| 61 |
+
args, s3_output = input_data
|
| 62 |
+
|
| 63 |
+
# Call the inference process
|
| 64 |
+
inference_process(args)
|
| 65 |
+
|
| 66 |
+
return '.cache/output.mp4', s3_output
|
| 67 |
+
|
| 68 |
+
def output_fn(prediction, content_type='application/json'):
|
| 69 |
+
local_output, s3_output = prediction
|
| 70 |
+
|
| 71 |
+
# Wait for the output file to be created and upload it to S3
|
| 72 |
+
while not os.path.exists(local_output):
|
| 73 |
+
time.sleep(1)
|
| 74 |
+
|
| 75 |
+
return json.dumps({'status': 'completed', 's3_output': s3_output})
|