Added a RemoteModelWrapper class #809
base: master
Changes from 1 commit
```python
"""
RemoteModelWrapper class
------------------------
"""

import requests
import torch
import numpy as np
import transformers


class RemoteModelWrapper():
```
Contributor: This should inherit from the `ModelWrapper` abstract class if you're looking to use this as a `ModelWrapper`.
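A minimal sketch of the reviewer's suggestion. The real base class is `textattack.models.wrappers.ModelWrapper`; a local stand-in is defined here so the snippet runs without TextAttack installed, and the returned scores are placeholders.

```python
from abc import ABC, abstractmethod


# Stand-in for textattack.models.wrappers.ModelWrapper, so this sketch is
# self-contained; in the PR, the real class would be imported instead.
class ModelWrapper(ABC):
    @abstractmethod
    def __call__(self, text_input_list):
        ...


class RemoteModelWrapper(ModelWrapper):
    def __init__(self, api_url):
        self.api_url = api_url

    def __call__(self, text_input_list):
        # A real implementation would query self.api_url; fixed scores
        # are returned here purely for illustration.
        return [[0.5, 0.5] for _ in text_input_list]


wrapper = RemoteModelWrapper("https://x.com/predict")
print(wrapper(["good movie", "bad movie"]))  # [[0.5, 0.5], [0.5, 0.5]]
```

Subclassing the abstract base also lets the framework reject wrappers that forget to implement `__call__`.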
```python
    """This model wrapper queries a remote model with a list of text inputs.

    It sends the input to a remote endpoint provided in api_url.
    """
```
```python
    def __init__(self, api_url):
        self.api_url = api_url
        self.model = transformers.AutoModelForSequenceClassification.from_pretrained(
            "textattack/bert-base-uncased-imdb"
        )
```
Contributor: Why do you set this? The `model` variable isn't used elsewhere in this class.
```python
    def __call__(self, text_input_list):
        predictions = []
        for text in text_input_list:
            params = dict()
            params["text"] = text
            # POST with the text sent as a query parameter
            response = requests.post(self.api_url, params=params, timeout=10)
```
Contributor: Is this kind of request format guaranteed to work for all endpoints? For example, OpenAI requires a specific kind of payload. I might suggest that when you initialize the wrapper you accept a lambda as a param to massage the data into a viable payload format.
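A hedged sketch of that suggestion: accept a payload-building function at init time so the wrapper is not tied to one request format. The parameter name `build_payload` and the endpoint URLs are hypothetical, not part of TextAttack or the PR.

```python
# Hypothetical extension: a caller-supplied payload builder with a default
# matching the PR's current behaviour ({"text": <input>}).
class RemoteModelWrapper:
    def __init__(self, api_url, build_payload=None):
        self.api_url = api_url
        self.build_payload = build_payload or (lambda text: {"text": text})


plain = RemoteModelWrapper("https://x.com/predict")
print(plain.build_payload("a review"))  # {'text': 'a review'}

# An OpenAI-style endpoint could supply its own builder:
chat = RemoteModelWrapper(
    "https://api.example.com/v1/classify",
    build_payload=lambda text: {"model": "clf-1", "input": text},
)
print(chat.build_payload("a review"))  # {'model': 'clf-1', 'input': 'a review'}
```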
```python
            if response.status_code != 200:
                print(f"Response content: {response.text}")
```
Contributor: Recommend using the package's logger instead of making print statements, especially when you mean to throw an error. This print statement might not even be necessary since you throw the error below anyway.
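A sketch of that suggestion using Python's standard `logging` module so it runs standalone; inside the package, TextAttack's own shared logger would be the idiomatic choice. The helper name `check_response` is hypothetical.

```python
import logging

logger = logging.getLogger("remote_model_wrapper")


def check_response(status_code, body):
    # Log the response body at error level instead of printing, then raise.
    if status_code != 200:
        logger.error("Response content: %s", body)
        raise ValueError(f"API call failed with status {status_code}")


try:
    check_response(500, "internal error")
except ValueError as err:
    print(err)  # API call failed with status 500
```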
```python
                raise ValueError(f"API call failed with status {response.status_code}")
            result = response.json()
            # Assuming the API returns probabilities for positive and negative
            predictions.append([result["negative"], result["positive"]])
```
Contributor: I see you're making these assumptions, but I'm not sure this is so common as to be widely applicable enough to make this a good wrapper function. To alleviate this, you could add another lambda to massage the output.
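A hedged sketch of that second lambda: a caller-supplied function that maps an arbitrary JSON response to the `[negative, positive]` pair the attack expects. The parameter name `parse_response` and the endpoint URL are hypothetical.

```python
def default_parse(result):
    # The PR's current assumption, kept as the default
    return [result["negative"], result["positive"]]


class RemoteModelWrapper:
    def __init__(self, api_url, parse_response=default_parse):
        self.api_url = api_url
        self.parse_response = parse_response


# An endpoint returning {"scores": [neg, pos]} just supplies its own parser:
wrapper = RemoteModelWrapper(
    "https://x.com/predict",
    parse_response=lambda r: [r["scores"][0], r["scores"][1]],
)
print(wrapper.parse_response({"scores": [0.1, 0.9]}))  # [0.1, 0.9]
```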
```python
        return torch.tensor(predictions)
```
Contributor: Any reason to cast this as a tensor?
```python
'''
Example usage:

# Define the remote model API endpoint
api_url = "https://x.com/predict"

model_wrapper = RemoteModelWrapper(api_url)

# Build the attack
attack = textattack.attack_recipes.TextFoolerJin2019.build(model_wrapper)

# Define dataset and attack arguments
dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test")

attack_args = textattack.AttackArgs(
    num_examples=100,
    log_to_csv="/textfooler.csv",
    checkpoint_interval=5,
    checkpoint_dir="checkpoints",
    disable_stdout=True
)

# Run the attack
attacker = textattack.Attacker(attack, dataset, attack_args)
attacker.attack_dataset()
'''
```
Contributor: NIT: run `make format` or the `black` formatter, as mentioned in the contribution guidelines for this repo.