pius-code commited on
Commit
6ee8cf7
·
1 Parent(s): ddd056a

refactor summarize_text and translate endpoint for improved clarity and output length management

Browse files
Files changed (1) hide show
  1. main.py +12 -9
main.py CHANGED
@@ -49,19 +49,22 @@ async def summarize_text(input: TextInput):
49
  )
50
 
51
  return {
52
- "summary": summary[0]["summary_text"],
53
- "parameters_used": {
54
- "input_word_count": word_count,
55
- "max_length": max_length,
56
- "min_length": min_length
57
- }
58
  }
59
 
60
 
61
  @app.post("/translateFrench")
62
  async def translate(input: TextInput):
63
- input.text = "translate English to French: " + input.text
64
- input_ids = tokenizer(input.text, return_tensors="pt").input_ids
65
- output = model.generate(input_ids, max_length=50, num_beams=4, early_stopping=True)
 
 
 
 
 
 
 
 
66
  translated_text = tokenizer.decode(output[0], skip_special_tokens=True)
67
  return {"translated_text": translated_text}
 
49
  )
50
 
51
  return {
52
+ "summary": summary[0]["summary_text"]
 
 
 
 
 
53
  }
54
 
55
 
56
  @app.post("/translateFrench")
57
  async def translate(input: TextInput):
58
+ # Step 1: Add task prefix to guide the T5 model
59
+ prefixed_text = "translate English to French: " + input.text
60
+ # Step 2: Tokenize the input
61
+ input_ids = tokenizer(prefixed_text, return_tensors="pt").input_ids
62
+ # Step 3: Get the input token length properly
63
+ input_length = input_ids.shape[1] # shape is (batch_size, sequence_length)
64
+ # Step 4: Set a reasonable output length (e.g. 20% longer than input)
65
+ max_length = max(10, int(input_length * 1.2))
66
+ # Step 5: Generate translated output
67
+ output = model.generate(input_ids, max_length=max_length, num_beams=4, early_stopping=True)
68
+ # Step 6: Decode the generated tokens
69
  translated_text = tokenizer.decode(output[0], skip_special_tokens=True)
70
  return {"translated_text": translated_text}