When trying to generate adversarial patch on a GPU server I get this error:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument tensors in method wrapper_cat)
The code works fine on CPU.
Minimal example:
from art.estimators.object_detection.pytorch_yolo import PyTorchYolo
from art.attacks.evasion import AdversarialPatchPyTorch
from inria_utils import load_inria
from evaluation_metrics import evaluate_patch
import torch
from yolov5.utils.loss import ComputeLoss
import yolov5
def load_model():
class Yolo(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
self.model.hyp = {'box': 0.05,
'obj': 1.0,
'cls': 0.5,
'anchor_t': 4.0,
'cls_pw': 1.0,
'obj_pw': 1.0,
'fl_gamma': 0.0
}
self.compute_loss = ComputeLoss(self.model.model.model)
def forward(self, x, targets=None):
if self.training:
outputs = self.model.model.model(x)
loss, loss_items = self.compute_loss(outputs, targets)
loss_components_dict = {"loss_total": loss}
return loss_components_dict
else:
return self.model(x)
# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = yolov5.load('yolov5s.pt')
model = Yolo(model)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
return PyTorchYolo(model=model,
device_type= 'cuda' if torch.cuda.is_available() else 'cpu',
input_shape=(3, 640, 640),
clip_values=(0, 255),
attack_losses=("loss_total",))
def main():
detector = load_model()
x, _ = load_inria(subset="train", num_samples=8)#x is a numpy array
# I have tried x to tensor and set the device to cuda but ap.generate expects np array
target = detector.predict(x) #its a list
ap = AdversarialPatchPyTorch(
estimator=detector,
rotation_max=8,
scale_min=0.4,
scale_max=1,
learning_rate=1,
batch_size=16,
max_iter=5,
patch_shape=(3, 200, 200),
patch_type='square',
verbose=True,
optimizer='Adam')
ap.generate(x=x, y=target)
if __name__ == "__main__":
main()
I have tried reducing the problem but I cannot reduce it more than the above example.