stefan-it committed
Commit a78aa9b · verified · 1 Parent(s): 6ca249a

docs: adjust demo

Files changed (1)
  1. README.md +21 -5
README.md CHANGED
@@ -70,11 +70,27 @@ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=False, rev
 model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=False, dtype=torch.bfloat16, revision=revision).to(device)
 model.eval()
 
-
-prompt = "Die Altstadt von München "
-generator = pipeline('text-generation', model=model, tokenizer=tokenizer, device=device, max_new_tokens=max_new_tokens)
-outputs = generator(prompt)
-print(outputs)
+conversation = [
+    {"role": "user", "content": "What is the capital of France?"},
+]
+
+inputs = tokenizer.apply_chat_template(
+    conversation,
+    add_generation_prompt=True,
+    tokenize=True,
+    return_dict=True,
+    return_tensors="pt"
+).to(device)
+
+with torch.no_grad():
+    outputs = model.generate(
+        **inputs,  # unpack input_ids and attention_mask from the tokenized dict
+        max_new_tokens=max_new_tokens,
+    )
+
+# Decode only the generated tokens (excluding the input prompt)
+generated_tokens = outputs[0, inputs["input_ids"].shape[1]:]
+print(tokenizer.decode(generated_tokens, skip_special_tokens=True))
 ```
 
 ## License
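
The new demo relies on `model_id`, `revision`, `device`, and `max_new_tokens` being defined earlier in the README, outside this hunk. As a rough self-contained sketch of the same flow, with hypothetical placeholder values standing in for those names:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder values: the actual model_id, revision, and generation length
# are defined earlier in the README and are not shown in this hunk.
model_id = "stefan-it/some-model"  # hypothetical
revision = "main"                  # hypothetical
max_new_tokens = 128               # hypothetical
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=False, revision=revision)
model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=False, dtype=torch.bfloat16, revision=revision
).to(device)
model.eval()

# Build a chat-formatted prompt; with tokenize=True and return_dict=True,
# apply_chat_template returns input_ids and attention_mask as tensors.
conversation = [{"role": "user", "content": "What is the capital of France?"}]
inputs = tokenizer.apply_chat_template(
    conversation,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
).to(device)

with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)

# model.generate returns the prompt tokens followed by the new tokens,
# so slice off the prompt before decoding to print only the answer.
generated_tokens = outputs[0, inputs["input_ids"].shape[1]:]
print(tokenizer.decode(generated_tokens, skip_special_tokens=True))
```

Slicing `outputs[0, inputs["input_ids"].shape[1]:]` is what lets the demo print only the generated continuation; the old `pipeline('text-generation', ...)` call returned the prompt echoed back as part of its output.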