# corderBackend / main.py
# Author: pius-code — "Update main.py" (commit 22e08be, verified)
# FastAPI service exposing text summarization and English-to-French translation.
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer,T5ForConditionalGeneration
# T5 tokenizer/model used by the /translateFrench endpoint.
# NOTE(review): loaded eagerly at import time — downloads weights on first run,
# so process startup blocks on the network/disk cache.
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-base")
app = FastAPI()
# BART summarization pipeline used by the /summarize endpoint.
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
class TextInput(BaseModel):
    """Request body shared by both POST endpoints."""
    # The raw input: a document for /summarize, an English sentence/passage
    # for /translateFrench.
    text: str
@app.get("/")
async def root():
    """Landing endpoint: return a static welcome payload."""
    payload = {"message": "Welcome to the Text Summarization API!"}
    return payload
def _summary_lengths(word_count: int) -> tuple:
    """Return (max_length, min_length) generation bounds for an input size.

    Bounds scale with the input in three bands; max_length is capped at
    BART's ~1024-token limit, and min_length is clamped so it can never
    exceed max_length.  (The original code let min_length = word_count // 8
    overshoot the 1024 cap for inputs of ~8200+ words, producing an
    invalid min_length > max_length configuration.)
    """
    if word_count < 50:
        max_length = max(10, word_count // 2)   # half the input, minimum 10
        min_length = max(3, word_count // 4)    # a quarter, minimum 3
    elif word_count < 200:
        max_length = max(50, word_count // 3)
        min_length = max(15, word_count // 6)
    else:
        max_length = max(100, word_count // 4)
        min_length = max(30, word_count // 8)
    # BART cannot generate beyond its token limit; keep min <= max after capping.
    max_length = min(max_length, 1024)
    min_length = min(min_length, max_length)
    return max_length, min_length


@app.post("/summarize")
async def summarize_text(input: TextInput):
    """Summarize `input.text` with BART, sizing the summary to the input.

    Returns {"summary": <summary_text>}.
    """
    # Approximate input size in words (a tokenizer count would be more precise).
    word_count = len(input.text.split())
    max_length, min_length = _summary_lengths(word_count)
    # do_sample=True with num_beams=4 selects beam-sample decoding;
    # temperature=0.7 softens the sampling distribution.
    summary = summarizer(
        input.text,
        max_length=max_length,
        min_length=min_length,
        do_sample=True,
        temperature=0.7,
        num_beams=4,
    )
    return {"summary": summary[0]["summary_text"]}
@app.post("/translateFrench")
async def translate(input: TextInput):
    """Translate English `input.text` to French using T5.

    Returns {"translated_text": <french_text>}.
    """
    # Step 1: T5 is text-to-text — the task is selected via a text prefix.
    prefixed_text = "translate English to French: " + input.text
    # Step 2: Tokenize; truncation=True caps the input at the model's limit.
    inputs = tokenizer(prefixed_text, return_tensors="pt", truncation=True)
    # Step 3: Size the output relative to the input.
    input_length = inputs.input_ids.shape[1]
    # Allow up to 2x the input length, but never more than 512 output tokens.
    max_length = min(512, input_length * 2)
    # Target at least ~10% more tokens than the input, clamped to max_length.
    # (The original let min_length = int(input_length * 1.1) exceed the 512
    # cap for inputs longer than ~465 tokens, an invalid generate() config.)
    min_length = min(int(input_length * 1.1), max_length)
    # Step 4: Beam-search generation with mild length bias and bigram blocking.
    outputs = model.generate(
        **inputs,
        max_length=max_length,
        min_length=min_length,
        num_beams=5,
        length_penalty=1.2,
        early_stopping=True,
        no_repeat_ngram_size=2,
    )
    # Step 5: Decode the best beam, dropping special tokens (<pad>, </s>).
    translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"translated_text": translated_text}