Spaces:
Runtime error
Runtime error
| import streamlit as st | |
@st.cache_resource(show_spinner=False)
def get_pipe():
    """Load the KoAlpaca-355M model and its tokenizer.

    Streamlit re-executes the whole script on every widget interaction, so
    the loaded objects are cached with ``st.cache_resource``; without it the
    checkpoint would be re-loaded on every button click.  The built-in cache
    spinner is disabled because the caller already shows its own.

    Returns:
        A ``(model, tokenizer)`` tuple -- note the order.
    """
    # Imported lazily so the import cost is only paid when the model is
    # actually loaded.
    from transformers import AutoTokenizer, AutoModelForCausalLM

    model_name = "heegyu/koalpaca-355m"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # When a prompt is too long, drop tokens from its end.
    tokenizer.truncation_side = "right"
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model, tokenizer
def get_response(tokenizer, model, context):
    """Generate the model's reply to a single user message.

    The message is wrapped in the prompt template ``<usr>{context}``,
    a newline, then ``<sys>``; the prompt is tokenized (truncated to 512
    tokens) and a continuation is sampled from the model.

    Args:
        tokenizer: Callable tokenizer that returns a mapping containing
            ``"input_ids"`` and provides ``decode``.
        model: Object providing ``generate(**inputs, **generation_args)``.
        context: The user's message as plain text.

    Returns:
        The generated continuation only (prompt removed), with any
        ``</s>`` markers stripped.
    """
    prompt = f"<usr>{context}\n<sys>"
    inputs = tokenizer(
        prompt,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    # Number of prompt tokens actually fed to the model (after truncation).
    prompt_len = inputs["input_ids"].shape[-1]

    generation_args = dict(
        # Budget for *generated* tokens.  A total-length cap of 256 would
        # leave no room to generate once the prompt itself reaches 256
        # tokens, while the tokenizer above allows prompts up to 512.
        # NOTE(review): assumes the model accepts 512 + 256 positions --
        # confirm against the checkpoint's config.
        max_new_tokens=256,
        min_length=64,
        # NOTE(review): presumably the id of "</s>" for this tokenizer --
        # confirm.
        eos_token_id=2,
        do_sample=True,
        top_p=1.0,
        # early_stopping is omitted: it only affects beam search, which is
        # not used here.
    )
    outputs = model.generate(**inputs, **generation_args)

    # Debug output: the prompt and the full decoded sequence.
    print(prompt)
    print(tokenizer.decode(outputs[0]))

    # Strip the prompt at token level.  Slicing the decoded text by
    # len(prompt) is wrong whenever the decoded sequence does not begin
    # with exactly the prompt string (e.g. after truncation).
    response = tokenizer.decode(outputs[0][prompt_len:])
    return response.replace("</s>", "")
# ---- Page layout -----------------------------------------------------------
st.title("KoAlpaca-355M")

# Load the model and tokenizer behind a spinner; get_pipe() returns them in
# (model, tokenizer) order.
with st.spinner("loading model..."):
    model, tokenizer = get_pipe()

# Question box (pre-filled with a sample question) and the submit button.
input_ = st.text_area("질문해보세요", value="미국과 중국의 갈등의 원인이 뭐야?")
ok = st.button("물어보기")

# Answer only when the button was pressed on this run and the box is non-empty.
if ok and input_:
    with st.spinner("잠시만요"):
        response = get_response(tokenizer, model, input_)
    st.text("대답")
    st.success(response)