```python
import torch

A_ = torch.sigmoid(torch.matmul(x, x.t()))
```
`x` holds the features of 700,000 nodes: its shape is 700,000 × 8, where 8 is the number of features extracted from each node. The result `A_` is therefore a dense 700,000 × 700,000 matrix, and computing it requires several TB of memory. How can I reduce the memory overhead?
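A quick back-of-the-envelope check of the memory figure above, assuming the output is a dense N × N matrix with N = 700,000. The helper `dense_matrix_bytes` is hypothetical and only does the arithmetic; it ignores allocator overhead. Note that `torch.sigmoid(torch.matmul(...))` briefly holds both the matmul result and the sigmoid output, so peak usage during the call is roughly twice these numbers.

```python
def dense_matrix_bytes(n_rows: int, n_cols: int, bytes_per_elem: int) -> int:
    """Bytes needed to store one dense n_rows x n_cols matrix (no overhead)."""
    return n_rows * n_cols * bytes_per_elem

N = 700_000  # number of nodes, i.e. rows of x
for name, size in [("float32", 4), ("float16", 2)]:
    tb = dense_matrix_bytes(N, N, size) / 1e12  # decimal terabytes
    print(f"{name}: {tb:.2f} TB")
```

Even at half precision, the output alone is close to 1 TB, which is why halving the precision is not enough on its own.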
I've tried halving the precision and computing the product in chunks, but I still can't get memory usage down to an acceptable level.
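For reference, a minimal sketch of the row-chunked computation described above, assuming PyTorch. The generator name `sigmoid_gram_chunks` and the `chunk_rows` parameter are hypothetical. The default dtype is float32 because half-precision matmul and sigmoid are not supported on CPU in every PyTorch version; on a GPU you could pass `dtype=torch.float16`.

```python
import torch

def sigmoid_gram_chunks(x: torch.Tensor, chunk_rows: int = 4096,
                        dtype: torch.dtype = torch.float32):
    """Yield (start, block) with block == sigmoid(x[start:end] @ x.T).

    Only one (chunk_rows x N) block is materialized at a time.
    """
    xt = x.to(dtype).t()  # (features, N), reused for every chunk
    n = x.shape[0]
    for start in range(0, n, chunk_rows):
        end = min(start + chunk_rows, n)
        # Each block has shape (end - start, N).
        block = torch.sigmoid(x[start:end].to(dtype) @ xt)
        yield start, block
```

Each block costs about `chunk_rows * N * bytes_per_elem` (for example, 4096 × 700,000 × 2 bytes ≈ 5.7 GB in float16), but the blocks together still add up to the full N × N matrix. Chunking therefore only lowers peak memory if each block is consumed (reduced, thresholded, or written out) before the next one is computed, rather than being concatenated back into `A_`.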