Streamlitで機械翻訳 (machine translation) のデモ

以下のスクリプトを model_streamit.py として作成し、

from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
import streamlit as st

# `st.cache` has been deprecated since Streamlit 1.18. Its replacement for
# shared, unhashable objects such as ML models is `st.cache_resource`, which
# returns the same object on every call without hashing or copying it. That
# matches what `allow_output_mutation=True` did. Fall back to the legacy
# decorator so the demo still runs on older Streamlit releases.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def load_model_tokenizer():
    """Load the M2M100 (418M) model and a Japanese->English tokenizer.

    The result is cached by Streamlit, so the weights are loaded only the
    first time this is called in a server process. Later script reruns reuse
    the same objects.

    Returns:
        tuple: ``(model, tokenizer)``. The tokenizer is configured with
        ``src_lang="ja"`` and ``tgt_lang="en"``.
    """
    model_name = "facebook/m2m100_418M"
    model = M2M100ForConditionalGeneration.from_pretrained(model_name)
    tokenizer = M2M100Tokenizer.from_pretrained(model_name, src_lang="ja", tgt_lang="en")
    return model, tokenizer

# Japanese sentence pre-filled in the text area, so the demo produces output
# immediately on first load.
default_value = "最後の晩餐に出席した。"


def main():
    """Render the Japanese->English translation demo page.

    Reads Japanese text from a text area and a maximum output length from a
    sidebar slider. The text is translated with the model/tokenizer pair
    returned by ``load_model_tokenizer`` and the English result is written to
    the page. Streamlit reruns the whole script, and hence this function,
    whenever a widget value changes.
    """
    model, tokenizer = load_model_tokenizer()
    sent = st.text_area("テキストを入力し、Ctrl+Enterで解析結果を表示します。", default_value, height=275)
    max_length = st.sidebar.slider("Max Length", min_value=10, max_value=100, value=50)

    # An empty or whitespace-only box would still be encoded (special tokens
    # only) and passed to generate(), yielding a meaningless "translation".
    # Skip the model call instead.
    if not sent or not sent.strip():
        st.info("翻訳するテキストを入力してください。")
        return

    encoded_ja = tokenizer(sent, return_tensors="pt")
    generated_tokens = model.generate(
        **encoded_ja,
        # M2M100 selects the output language from the first generated token,
        # so force the English language id into that position.
        forced_bos_token_id=tokenizer.get_lang_id("en"),
        max_length=max_length,
    )
    # Exactly one input string was encoded, so take the first (only) sequence.
    translation = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

    st.write(translation)


if __name__ == "__main__":
    main()

次のコマンドで実行する。

streamlit run model_streamit.py