| """Module to generate OpenELM output given a model and an input prompt.""" |
| import os |
| import logging |
| import time |
| import argparse |
| from typing import Optional, Union |
| import torch |
|
|
| from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
def generate(
    prompt: str,
    model: Union[str, AutoModelForCausalLM],
    hf_access_token: Optional[str] = None,
    tokenizer: Union[str, AutoTokenizer] = 'meta-llama/Llama-2-7b-hf',
    device: Optional[str] = None,
    max_length: int = 1024,
    assistant_model: Optional[Union[str, AutoModelForCausalLM]] = None,
    generate_kwargs: Optional[dict] = None,
) -> Tuple[str, float]:
| """ Generates output given a prompt. |
| |
| Args: |
| prompt: The string prompt. |
| model: The LLM Model. If a string is passed, it should be the path to |
| the hf converted checkpoint. |
| hf_access_token: Hugging face access token. |
| tokenizer: Tokenizer instance. If model is set as a string path, |
| the tokenizer will be loaded from the checkpoint. |
| device: String representation of device to run the model on. If None |
| and cuda available it would be set to cuda:0 else cpu. |
| max_length: Maximum length of tokens, input prompt + generated tokens. |
| assistant_model: If set, this model will be used for |
| speculative generation. If a string is passed, it should be the |
| path to the hf converted checkpoint. |
| generate_kwargs: Extra kwargs passed to the hf generate function. |
| |
| Returns: |
| output_text: output generated as a string. |
| generation_time: generation time in seconds. |
| |
| Raises: |
| ValueError: If device is set to CUDA but no CUDA device is detected. |
| ValueError: If tokenizer is not set. |
| ValueError: If hf_access_token is not specified. |
| """ |
    if not device:
        if torch.cuda.is_available() and torch.cuda.device_count():
            device = 'cuda:0'
            logging.warning(
                'Inference device is not set, using cuda:0, %s',
                torch.cuda.get_device_name(0)
            )
        else:
            device = 'cpu'
            logging.warning(
                'No CUDA device detected, using cpu; expect slower speeds.'
            )

    if 'cuda' in device and not torch.cuda.is_available():
        raise ValueError('CUDA device requested but no CUDA device detected.')

    if not tokenizer:
        raise ValueError('Tokenizer is not set in the generate function.')

    if not hf_access_token:
        raise ValueError(
            'Hugging Face access token needs to be specified. Please refer '
            'to https://huggingface.co/docs/hub/security-tokens to obtain '
            'one.'
        )

    # Load the model from a converted checkpoint path if given as a string.
    if isinstance(model, str):
        checkpoint_path = model
        model = AutoModelForCausalLM.from_pretrained(
            checkpoint_path,
            trust_remote_code=True
        )
    model.to(device).eval()
    if isinstance(tokenizer, str):
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer,
            token=hf_access_token,
        )

    # Optionally load a draft model for speculative (assisted) generation,
    # where the draft model proposes tokens that the main model verifies.
    draft_model = None
    if assistant_model:
        draft_model = assistant_model
        if isinstance(assistant_model, str):
            draft_model = AutoModelForCausalLM.from_pretrained(
                assistant_model,
                trust_remote_code=True
            )
        draft_model.to(device).eval()

    # Tokenize the prompt and add a batch dimension.
    tokenized_prompt = tokenizer(prompt)
    tokenized_prompt = torch.tensor(
        tokenized_prompt['input_ids'],
        device=device
    )
    tokenized_prompt = tokenized_prompt.unsqueeze(0)

    # Generate and measure wall-clock generation time.
    stime = time.time()
    output_ids = model.generate(
        tokenized_prompt,
        max_length=max_length,
        pad_token_id=0,
        assistant_model=draft_model,
        **(generate_kwargs if generate_kwargs else {}),
    )
    generation_time = time.time() - stime

    output_text = tokenizer.decode(
        output_ids[0].tolist(),
        skip_special_tokens=True
    )

    return output_text, generation_time
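

# A minimal usage sketch (commented out, never executed on import). The
# checkpoint path, token, and kwargs below are illustrative placeholders,
# not values prescribed by this module:
#
#   output_text, generation_time = generate(
#       prompt='Once upon a time there was',
#       model='apple/OpenELM-270M',   # path/ID of an HF converted checkpoint
#       hf_access_token='hf_...',     # your own Hugging Face token
#       generate_kwargs={'repetition_penalty': 1.2},
#   )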


def openelm_generate_parser():
    """Argument parser for the OpenELM generate module."""

    class KwargsParser(argparse.Action):
        """Parser action class to parse kwargs of form key=value."""
        def __call__(self, parser, namespace, values, option_string=None):
            setattr(namespace, self.dest, dict())
            for val in values:
                if '=' not in val:
                    raise ValueError(
                        'Argument parsing error, kwargs are expected in '
                        'the form of key=value.'
                    )
                # Split on the first '=' only, so values may themselves
                # contain '=' characters.
                kwarg_k, kwarg_v = val.split('=', 1)
                # Coerce to int, then float; otherwise keep the raw string.
                try:
                    converted_v = int(kwarg_v)
                except ValueError:
                    try:
                        converted_v = float(kwarg_v)
                    except ValueError:
                        converted_v = kwarg_v
                getattr(namespace, self.dest)[kwarg_k] = converted_v
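
    # For example, `--generate_kwargs repetition_penalty=1.2 do_sample=true`
    # parses to {'repetition_penalty': 1.2, 'do_sample': 'true'}; note that
    # non-numeric values such as booleans stay strings under this scheme.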

    parser = argparse.ArgumentParser('OpenELM Generate Module')
    parser.add_argument(
        '--model',
        dest='model',
        help='Path to the HF converted model.',
        required=True,
        type=str,
    )
    parser.add_argument(
        '--hf_access_token',
        dest='hf_access_token',
        help='Hugging Face access token, starting with "hf_".',
        type=str,
    )
    parser.add_argument(
        '--prompt',
        dest='prompt',
        help='Prompt for the LLM call.',
        default='',
        type=str,
    )
    parser.add_argument(
        '--device',
        dest='device',
        help='Device used for inference.',
        type=str,
    )
    parser.add_argument(
        '--max_length',
        dest='max_length',
        help='Maximum length of tokens, input prompt + generated tokens.',
        default=256,
        type=int,
    )
    parser.add_argument(
        '--assistant_model',
        dest='assistant_model',
        help=(
            'If set, this is used as a draft model '
            'for assisted speculative generation.'
        ),
        type=str,
    )
    parser.add_argument(
        '--generate_kwargs',
        dest='generate_kwargs',
        help='Additional kwargs passed to the HF generate function.',
        type=str,
        nargs='*',
        action=KwargsParser,
    )
    return parser.parse_args()
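

# Example invocation (illustrative; the script name, model path, and token
# are placeholders, not values fixed by this module):
#
#   python generate_openelm.py \
#       --model apple/OpenELM-270M \
#       --hf_access_token hf_... \
#       --prompt 'Once upon a time there was' \
#       --generate_kwargs repetition_penalty=1.2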


if __name__ == '__main__':
    args = openelm_generate_parser()
    prompt = args.prompt

    output_text, generation_time = generate(
        prompt=prompt,
        model=args.model,
        device=args.device,
        max_length=args.max_length,
        assistant_model=args.assistant_model,
        generate_kwargs=args.generate_kwargs,
        hf_access_token=args.hf_access_token,
    )

    print_txt = (
        f'\r\n{"=" * os.get_terminal_size().columns}\r\n'
        '\033[1m Prompt + Generated Output\033[0m\r\n'
        f'{"-" * os.get_terminal_size().columns}\r\n'
        f'{output_text}\r\n'
        f'{"-" * os.get_terminal_size().columns}\r\n'
        '\r\nGeneration took'
        f'\033[1m\033[92m {round(generation_time, 2)} \033[0m'
        'seconds.\r\n'
    )
    print(print_txt)