Spaces:
Paused
Paused
| import torch | |
| import torch.nn as nn | |
| from starvector.model.models.starvector_base import StarVectorBase | |
| from transformers import AutoProcessor | |
class StarVectorStarCoder(StarVectorBase):
    """StarVector variant whose SVG language model is StarCoder (V1).

    Supplies the V1-specific ``_get_*`` helpers: building the SVG transformer,
    looking up input embeddings, and preparing SVG target text. These
    presumably override hooks declared on ``StarVectorBase`` — confirm there.
    """

    def __init__(self, config, **kwargs):
        """Initialise the base model, then attach a processor.

        Args:
            config: Model configuration forwarded to ``StarVectorBase``; its
                ``_name_or_path`` attribute is also used to load the processor.
            **kwargs: Extra keyword arguments forwarded to ``StarVectorBase``.
        """
        super().__init__(config, **kwargs)
        # ``_name_or_path`` is presumably the checkpoint name/path the config
        # was loaded from, so the processor comes from the same source — confirm.
        pretrained_source = config._name_or_path
        self.processor = AutoProcessor.from_pretrained(pretrained_source)

    def _get_svg_transformer(self, config, **kwargs):
        """Build the SVG language model as a StarCoder (V1) ``StarCoderModel``.

        Args:
            config: Configuration passed to ``StarCoderModel``.
            **kwargs: Extra keyword arguments passed to ``StarCoderModel``.

        Returns:
            A newly constructed ``StarCoderModel``.
        """
        # Imported inside the method rather than at module level; presumably
        # to defer loading the LLM module until it is needed — confirm.
        from starvector.model.llm.starcoder import StarCoderModel

        svg_model = StarCoderModel(config, **kwargs)
        return svg_model

    def _get_embeddings(self, input_ids):
        """Return V1-specific embeddings for ``input_ids``.

        Calls the ``wte`` module found at
        ``self.svg_transformer.transformer.transformer`` (presumably the
        StarCoder token-embedding table — confirm against StarCoderModel).
        """
        inner_model = self.svg_transformer.transformer.transformer
        return inner_model.wte(input_ids)

    def _get_svg_text(self, svg_list):
        """Prepare V1 SVG text by appending the tokenizer's EOS token.

        Args:
            svg_list: Iterable of (presumably) SVG strings.

        Returns:
            A new list in which every element is the corresponding input
            followed by ``self.svg_transformer.tokenizer.eos_token``.
        """
        # NOTE(review): if the tokenizer defines no EOS token (eos_token is
        # None), ``str + None`` raises TypeError here.
        return [
            svg_text + self.svg_transformer.tokenizer.eos_token
            for svg_text in svg_list
        ]