pius-code commited on
Commit
22e08be
·
verified ·
1 Parent(s): 6ee8cf7

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +23 -10
main.py CHANGED
@@ -55,16 +55,29 @@ async def summarize_text(input: TextInput):
55
 
56
  @app.post("/translateFrench")
57
  async def translate(input: TextInput):
58
- # Step 1: Add task prefix to guide the T5 model
59
  prefixed_text = "translate English to French: " + input.text
 
60
  # Step 2: Tokenize the input
61
- input_ids = tokenizer(prefixed_text, return_tensors="pt").input_ids
62
- # Step 3: Get the input token length properly
63
- input_length = input_ids.shape[1] # shape is (batch_size, sequence_length)
64
- # Step 4: Set a reasonable output length (e.g. 20% longer than input)
65
- max_length = max(10, int(input_length * 1.2))
66
- # Step 5: Generate translated output
67
- output = model.generate(input_ids, max_length=max_length, num_beams=4, early_stopping=True)
68
- # Step 6: Decode the generated tokens
69
- translated_text = tokenizer.decode(output[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
70
  return {"translated_text": translated_text}
 
55
 
56
  @app.post("/translateFrench")
57
  async def translate(input: TextInput):
58
+ # Step 1: Prefix the task for the model
59
  prefixed_text = "translate English to French: " + input.text
60
+
61
  # Step 2: Tokenize the input
62
+ inputs = tokenizer(prefixed_text, return_tensors="pt", truncation=True)
63
+
64
+ # Step 3: Adjust generation parameters
65
+ input_length = inputs.input_ids.shape[1]
66
+ max_length = min(512, input_length * 2) # 2x input length but not more than 512
67
+ min_length = int(input_length * 1.1) # at least 10% longer than input
68
+
69
+ # Step 4: Generate translation
70
+ outputs = model.generate(
71
+ **inputs,
72
+ max_length=max_length,
73
+ min_length=min_length,
74
+ num_beams=5,
75
+ length_penalty=1.2,
76
+ early_stopping=True,
77
+ no_repeat_ngram_size=2
78
+ )
79
+
80
+ # Step 5: Decode result
81
+ translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
82
+
83
  return {"translated_text": translated_text}