Why does my ResNet-34 model take so much memory?


I trained a ResNet-34 model and deployed it on a Flask server for online image classification. However, it takes about 1.5 GB of memory when I start the server and load the model, and jumps to about 3.0 GB when I predict an image via an HTTP request, staying at that level no matter how many more images I predict. Even more strangely, I used a Flask app to serve YOLOv5s online and it also takes about 3.2 GB. I can't understand why they take approximately the same amount of memory, since ResNet-34 has far fewer parameters than YOLOv5s. Is my ResNet-34 Flask server code wrong? How can I reduce the memory usage?

Here is my code:

# imports used by the snippet below
import os
import gc
import json
from urllib.parse import unquote

import flask
from flask import request, jsonify
import torch
from torchvision import transforms
from PIL import Image

from model import resnet34  # my ResNet-34 definition (adjust the import to your project)

# initialize the flask app
app = flask.Flask(__name__)
app.model = None
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

data_transform = transforms.Compose(
    [transforms.Resize(256),
     transforms.CenterCrop(224),
     transforms.ToTensor(),
     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# read class_indict
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)
with open(json_path, "r") as f:
    class_indict = json.load(f)

def load_model():
    """Load the pre-trained model, you can use your model just as easily.
    """
    model = resnet34(num_classes=7).to(device)

    weights_path = './resnet34.pth'
    assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()
    return model


def load_image(path):
    assert os.path.exists(path), "file: '{}' does not exist.".format(path)
    img = Image.open(path).convert("RGB")
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)
    return img


@app.route("/resnet", methods=['GET', 'POST'])
def predict():
    if app.model is None:
        # lazy-load the model on the first request
        app.model = load_model()
    path = unquote(request.args.get("path", ""))
    if not path:  # args.get with a "" default never returns None
        return jsonify({'error': 'Path not provided'}), 400
    image = load_image(path)
    with torch.no_grad():
        # predict class
        output = torch.squeeze(app.model(image.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()
    torch.cuda.empty_cache()  # return cached GPU blocks to the driver (no-op on CPU)
    del image
    gc.collect()

    return jsonify({
        'label': str(predict_cla),
        'predicted_class': class_indict[str(predict_cla)],
        'probability': float(predict[predict_cla].numpy())
    })


if __name__ == '__main__':
    print("Loading PyTorch model and Flask starting server ...")
    print("Please wait until server has fully started")
    # start the classification service and wait for request
    app.run(port=5012)

I would appreciate any advice.


1 Answer


When you move a tensor from the GPU to NumPy you should call .detach() first, otherwise the autograd graph is kept alive and the memory leaks:

predict_cla = torch.argmax(predict).detach().cpu().numpy()
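
Beyond that, the baseline footprint you are seeing is mostly the CUDA context plus PyTorch's own allocator, kernels and thread pools rather than the weights themselves, which is likely why two very different models end up at a similar total. Below is a minimal sketch of a leaner handler, assuming PyTorch >= 1.9 (for torch.inference_mode) and reusing the load_model/load_image helpers from your question; treat it as a starting point, not a drop-in fix:

import torch

# cap the intra-op thread pool; large pools inflate resident memory on CPU
torch.set_num_threads(1)

# load once at startup instead of lazily inside the request handler
app.model = load_model()

@app.route("/resnet", methods=['GET', 'POST'])
def predict():
    path = unquote(request.args.get("path", ""))
    if not path:
        return jsonify({'error': 'Path not provided'}), 400
    image = load_image(path)
    # inference_mode() is a stricter no_grad(): no autograd state is
    # recorded at all, so there is nothing left to detach afterwards
    with torch.inference_mode():
        output = torch.squeeze(app.model(image.to(device))).cpu()
    probs = torch.softmax(output, dim=0)
    cla = int(torch.argmax(probs))
    return jsonify({
        'label': str(cla),
        'predicted_class': class_indict[str(cla)],
        'probability': float(probs[cla])
    })

Loading the model once at startup also keeps the first request from paying the load cost. Note that PyTorch's caching allocator deliberately holds on to freed blocks, so a plateau after the first prediction is expected behavior rather than a leak.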