How to reclaim memory allocated in a loop: python 2.7

856 Views Asked by At

I am trying to process a large (~100 GB) MD simulation trajectory. The following snippet is one of the methods of my analysis code. I want to process the trajectory in chunks whose size fits in the computer's available memory.

Profiling with memory_profiler, I found that the memory allocated on line 136 is not freed, even after the object is deleted on line 143. I also tried replacing line 136 with an equivalent list comprehension, but gained nothing from it. I cannot spot or think of any reason for this behavior, and I hope expert insight will help me resolve it. Thanks.

I am running this code on Ubuntu 16.04 / CentOS 7 with Python 2.7.12 built with GCC 5.4.0.

I understand that the output below cannot be reproduced without the entire code. I am sorry that I cannot share it, because it is scattered across several script files.

Line #    Mem usage    Increment   Line Contents
================================================
    33     51.3 MiB      0.0 MiB       @profile
    34                                 def convert_using_pytraj(self, trajIn):
    35     51.3 MiB      0.0 MiB           bonds_list = []
    36     51.3 MiB      0.0 MiB           angles_list = []
    37     51.3 MiB      0.0 MiB           torsions_list = []
    38                                     
    39     51.3 MiB      0.0 MiB           for k in sorted(self.tree.nodes.keys()):
    40     51.3 MiB      0.0 MiB               if self.tree.nodes[k].a2 > 0:
    41     51.3 MiB      0.0 MiB                   bonds_list.append([k-1, self.tree.nodes[k].a2-1])
    42     51.3 MiB      0.0 MiB               if self.tree.nodes[k].a2 > 0 and self.tree.nodes[k].a3 > 0:
    43     51.3 MiB      0.0 MiB                   angles_list.append([k-1, self.tree.nodes[k].a2 -1, self.tree.nodes[k].a3 -1])
    44     51.3 MiB      0.0 MiB               if self.tree.nodes[k].a2 > 0 and self.tree.nodes[k].a3 > 0 and self.tree.nodes[k].a4 > 0:
    45     51.3 MiB      0.0 MiB                   torsions_list.append([k-1, self.tree.nodes[k].a2 -1, self.tree.nodes[k].a3 -1, self.tree.nodes[k].a4 -1])
    46                                     
    47     51.3 MiB      0.0 MiB           n_atom = len(self.inputs['atoms'])
    48     51.3 MiB      0.0 MiB           pseudo_bonds = None 
    49     51.3 MiB      0.0 MiB           if len(self.inputs['pseudo']) % 2 == 0:
    50     51.3 MiB      0.0 MiB               v_tmp = []
    51     51.3 MiB      0.0 MiB               for i in range(0, len(self.inputs['pseudo']), 2):
    52                                             v_tmp.append((self.inputs['pseudo'][i], self.inputs['pseudo'][i+1]))
    53     51.3 MiB      0.0 MiB               if len(v_tmp) > 0:
    54                                             pseudo_bonds = list(v_tmp)
    55     51.4 MiB      0.1 MiB           logger.debug('bond_indices: %s\nangle_indices: %s\n dih_indices%s' % (bonds_list, angles_list, torsions_list))
    56     51.4 MiB      0.0 MiB           logger.debug('pseudo_bonds: %s' % str(pseudo_bonds))
    57     51.4 MiB      0.0 MiB           logger.debug(str((n_atom, n_atom-1, n_atom-2, n_atom-3, self.inputs['roots'])))
    58     51.4 MiB      0.0 MiB           trjs = []
    59     51.4 MiB      0.0 MiB           slices = []
    60                                     
    61     51.4 MiB      0.0 MiB           for tr1 in trajIn:
    62     51.4 MiB      0.0 MiB               trjs.append(tr1[0])
    63     51.4 MiB      0.0 MiB               if len(tr1) == 2:
    64                                             slices.append(tuple([0, tr1[1], 1]))
    65     51.4 MiB      0.0 MiB               elif len(tr1) == 3:
    66                                             slices.append(tuple([tr1[1]-1, tr1[2], 1]))
    67     51.4 MiB      0.0 MiB               elif len(tr1) == 4:
    68     51.4 MiB      0.0 MiB                   slices.append(tuple([tr1[1]-1, tr1[2], tr1[3]]))
    69     51.4 MiB      0.0 MiB           if slices:
    70     51.4 MiB      0.0 MiB               if len(slices) == len(trjs):
    71     54.3 MiB      2.8 MiB                   traj = pt.iterload(trjs, self.inputs['topoFile'], frame_slice=slices)
    72                                         else:
    73                                             raise Exception("Either all trajin should have slices or none")
    74                                     else:
    75                                         traj = pt.iterload(trjs, self.inputs['topoFile'])
    76     54.3 MiB      0.0 MiB           traj_frames_eff = traj.n_frames
    77     54.3 MiB      0.0 MiB           if not self.inputs['blockProcess']:
    78                                         block_size = traj.n_frames
    79                                     else:
    80     54.3 MiB      0.0 MiB               block_size = self.inputs['blockSize']
    81     54.3 MiB      0.0 MiB           logger.debug(str(("Total number of frames: %d" % traj.n_frames)))
    82     90.7 MiB     36.5 MiB           for block_id, block_start in enumerate(range(0, traj_frames_eff, block_size)):
    83     89.9 MiB     -0.8 MiB               if block_start + block_size <= traj_frames_eff:
    84     70.9 MiB    -19.0 MiB                   block_end = block_start + block_size
    85                                         else:
    86     89.9 MiB     19.0 MiB                   block_end = traj_frames_eff
    87     89.9 MiB      0.0 MiB               if block_id != 0:
    88     89.9 MiB      0.0 MiB                   if slices:
    89     89.9 MiB      0.0 MiB                       if len(slices) == len(trjs):
    90     89.9 MiB      0.0 MiB                           traj = pt.iterload(trjs, self.inputs['topoFile'], frame_slice=slices)
    91                                                 else:
    92                                                     raise Exception("Either all trajin should have slices or none")
    93                                             else:
    94                                                 traj = pt.iterload(trjs, self.inputs['topoFile'])       
    95     89.9 MiB      0.0 MiB               if block_end - block_start > 0:
    96     89.9 MiB      0.0 MiB                   logger.debug('Processing %s block %i Frames %i to %i' % (str(trajIn), block_id+1, block_start, block_end))
    97     89.9 MiB      0.0 MiB                   trj_working = traj[range(block_start, block_end)]
    98     90.2 MiB      0.3 MiB                   bonds_val = pt.distance(trj_working, bonds_list, dtype='ndarray')
    99     90.5 MiB      0.2 MiB                   angles_val = pt.angle(trj_working, angles_list, dtype='ndarray')
   100     90.7 MiB      0.2 MiB                   dihedrals_val = pt.dihedral(trj_working, torsions_list, dtype='ndarray')
   101     90.7 MiB      0.0 MiB                   trj_working = None
   102     90.7 MiB      0.0 MiB                   traj = None
   103     90.7 MiB      0.0 MiB                   if not self.inputs['useDegree']:
   104     90.7 MiB      0.0 MiB                       deg2rad = PI / 180.0
   105     90.7 MiB      0.0 MiB                       PI2 = 2 * PI
   106     90.7 MiB      0.0 MiB                       angles_val = angles_val * deg2rad
   107     90.7 MiB      0.0 MiB                       dihedrals_val = dihedrals_val * deg2rad
   108                                                 # move range (-PI, PI) -> (0.0, 2*PI) by adding 2*PI
   109     90.7 MiB      0.0 MiB                       for i in range(angles_val.shape[0]):
   110     90.7 MiB      0.0 MiB                           for j in range(angles_val.shape[1]):
   111     90.7 MiB      0.0 MiB                               if angles_val[i, j] < 0.0:
   112                                                             angles_val[i, j] += PI2
   113     90.7 MiB      0.0 MiB                       for i in range(dihedrals_val.shape[0]):
   114     90.7 MiB      0.0 MiB                           for j in range(dihedrals_val.shape[1]):
   115     90.7 MiB      0.0 MiB                               if dihedrals_val[i, j] < 0.0:
   116     90.7 MiB      0.0 MiB                                   dihedrals_val[i, j] += PI2
   117     90.7 MiB      0.0 MiB                   if self.inputs['usePhase']:
   118                                                 # Substract value of phase angle if phase
   119                                                 # if modified torsion becomes negative add 2*PI [rad] or 180 [deg]
   120                                                 # if modified torsion is positive substract and in deg substract 180
   121     90.7 MiB      0.0 MiB                       modFactor = 360.0 if self.inputs['useDegree'] else PI2
   122     90.7 MiB      0.0 MiB                       for k in range(0, dihedrals_val.shape[0]):                        
   123     90.7 MiB      0.0 MiB                           if (k != self.phase_defn[k + 1] - 1):
   124     90.7 MiB      0.0 MiB                               for fNo in range(dihedrals_val.shape[1]):
   125     90.7 MiB      0.0 MiB                                   dihedrals_val[k, fNo] -= dihedrals_val[self.phase_defn[k + 1] - 1, fNo]
   126     90.7 MiB      0.0 MiB                                   if (dihedrals_val[k, fNo] < 0.0):
   127     90.7 MiB      0.0 MiB                                       dihedrals_val[k, fNo] += modFactor
   128     90.7 MiB      0.0 MiB                                   if (self.inputs['useDegree']):
   129                                                                 dihedrals_val[k, fNo] -= 180.0
   130     90.7 MiB      0.0 MiB                   logger.debug(str((bonds_val.shape, type(bonds_val), bonds_val)))
   131     90.7 MiB      0.0 MiB                   if block_id==0:
   132     71.9 MiB    -18.8 MiB                       logger.debug("trying to create nc file for bonds")
   133     71.9 MiB      0.0 MiB                       trjBAT = trajIO.trjNetCdfBAT(self.trajOutFile, n_atom, n_atom-1,n_atom-2,n_atom-3, self.inputs['roots'], pseudo_bonds=pseudo_bonds)
   134     72.3 MiB      0.4 MiB                       trjBAT.create_dataset("0O.lig.internal from md2accent")
   135                             
   136     90.7 MiB     18.5 MiB                   frames_indices = np.arange(block_start+1, block_end+1, dtype=np.int32)
   137     90.7 MiB      0.0 MiB                   logger.debug("Writing BAT trajectory...\n")
   138     90.7 MiB      0.0 MiB                   t_bonds_val = np.transpose(bonds_val)
   139     90.7 MiB      0.0 MiB                   t_angles_val = np.transpose(angles_val)
   140     90.7 MiB      0.0 MiB                   t_dihedrals_val = np.transpose(dihedrals_val)
   141                                             #frm_o, frm_n = trjBAT.append_frames(frames_indices, t_bonds_val, t_angles_val, t_dihedrals_val)
   142                                             #logger.debug(str("%d frames appended to file successfully\n" % (frm_n - frm_o)))
   143     90.7 MiB      0.0 MiB                   del frames_indices
   144     90.7 MiB      0.0 MiB                   del t_bonds_val
   145     90.7 MiB      0.0 MiB                   del t_angles_val
   146     90.7 MiB      0.0 MiB                   del t_dihedrals_val
   147     90.7 MiB      0.0 MiB                   del bonds_val
   148     90.7 MiB      0.0 MiB                   del angles_val
   149     90.7 MiB      0.0 MiB                   del dihedrals_val
   150     90.7 MiB      0.0 MiB                   gc.collect()
   151                                         else:
   152                                             logger.critical("Exception: There are no frames to process..")
   153                                             raise Exception("There are no frames to process..")
   154     90.7 MiB      0.0 MiB               traj = None
   155     90.7 MiB      0.0 MiB           logger.debug("Processing input trajectory successful...")
   156     90.7 MiB      0.0 MiB           return(True)
2

There are 2 best solutions below

3
On

In my experience, garbage collection with external libraries written in C or similar can be quite hard (numpy in your case).

However, consider using the gc module included in the standard library. (gc = garbage collection)

Then you could try the following:

import gc
...
del frames_indices
gc.collect()

For further debugging look at these:

0
On

To resolve the issue, I condensed the code while keeping its core structure, as shown below, so that it is easier to follow:

# In[1]:

import numpy as np
import scipy as sp
import pytraj as pt
import pandas as pnd
from memory_profiler import profile


# In[2]:

def read_tree_as_dataframe(filename, pdbfile, sel="@H=", header=True, skipTopNlines=0):
    """
    Read a BAT tree file into a pandas DataFrame and annotate every row with
    bond/angle/torsion indices and selected-atom counts.

    Parameters
    ----------
    filename : str
        Whitespace-separated tree file to read.
    pdbfile : str
        PDB file matching the tree. It is loaded with ``pytraj.load`` so that
        ``sel`` can be resolved to 0-based atom indices via ``top.select``.
    sel : str
        Atom-selection mask passed to ``pdb.top.select``. The ``*_Sel``
        columns count the atoms matching it.
    header : bool
        True: the first non-skipped line holds the column names.
        False: columns are named ``V1``, ``V2``, ... and that line is data.
        NOTE(review): the rest of the function needs the ``a1``..``a4``
        columns, so ``header=False`` currently ends in a ``KeyError``.
    skipTopNlines : int
        Number of leading lines to ignore.

    Returns
    -------
    pandas.DataFrame
        The tree columns (``tor_type`` kept as text, ``is_bb`` converted to
        bool, every other field converted to int), followed by:

        - ``bnd_idx`` / ``ang_idx`` / ``tor_idx``: 1-based running index of
          the bond / angle / torsion defined on that row, or 0 if none;
        - ``bnd_Sel`` / ``ang_Sel`` / ``tor_Sel``: number of selected atoms
          taking part in that row's bond / angle / torsion, or 0 if none.

    Notes
    -----
    A value of -1 in ``a2`` / ``a3`` / ``a4`` means the row defines no
    bond / angle / torsion. Tree atom numbers are 1-based.
    """
    import numpy
    import pandas as pd
    import pytraj as pt

    def _parse(head, text):
        # tor_type stays textual, is_bb becomes a bool, every other field an int.
        text = text.strip()
        if head == 'tor_type':
            return text
        if head == 'is_bb':
            return text.upper() == 'TRUE'
        return int(text)

    data = {}
    heads = []
    with open(filename, "r") as f:
        for ri, row in enumerate(f):
            if ri < skipTopNlines:
                continue
            fields = row.split()
            if ri == skipTopNlines:
                if header:
                    for k in fields:
                        data[k.strip()] = []
                        heads.append(k.strip())
                else:
                    # No header line: synthesize names and keep this row as data,
                    # converted exactly like every following row.
                    for i, v in enumerate(fields):
                        name = 'V' + str(i + 1)
                        data[name] = [_parse(name, v)]
                        heads.append(name)
            else:
                # Index into heads (rather than zip) so an over-long row still
                # raises IndexError instead of being silently truncated.
                for i, c in enumerate(fields):
                    data[heads[i]].append(_parse(heads[i], c))

    for k in list(data):
        data[k] = numpy.asarray(data[k])

    # A torsion implies an angle and an angle implies a bond, so each mask refines the previous one.
    has_bond = data['a2'] != -1
    has_angle = has_bond & (data['a3'] != -1)
    has_torsion = has_angle & (data['a4'] != -1)

    def _running_index(mask):
        # 1-based count of rows satisfying `mask` so far; 0 on rows that do not.
        return numpy.where(mask, numpy.cumsum(mask), 0).astype(int)

    data['bnd_idx'] = _running_index(has_bond)
    data['ang_idx'] = _running_index(has_angle)
    data['tor_idx'] = _running_index(has_torsion)

    pdb = pt.load(pdbfile)
    selected = numpy.asarray(pdb.top.select(sel))

    def _n_selected(mask, cols):
        # Tree atom numbers are 1-based; selection indices are 0-based, hence "- 1".
        hits = sum(numpy.isin(data[c] - 1, selected).astype(int) for c in cols)
        return numpy.where(mask, hits, 0).astype(int)

    data['bnd_Sel'] = _n_selected(has_bond, ['a1', 'a2'])
    # NOTE(review): the angle count ignores the middle atom a2, as the original
    # code did -- confirm that this is intended.
    data['ang_Sel'] = _n_selected(has_angle, ['a1', 'a3'])
    # The original counted a3 twice and never a4; all four torsion atoms are counted now.
    data['tor_Sel'] = _n_selected(has_torsion, ['a1', 'a2', 'a3', 'a4'])

    df_tree = pd.DataFrame.from_dict(data)
    return(df_tree)


# In[3]:

# Build the annotated BAT tree DataFrame from the INR tree file and its matching PDB.
df_tree = read_tree_as_dataframe('INR/INR.DFS.tree', 'INR/INR.pdb')

def get_bonds(tree):
    """
    Collect the 0-based atom-index pair of every bond defined in the tree.

    Rows whose ``a2`` equals -1 define no bond and are skipped. For the
    remaining rows, ``[a1 - 1, a2 - 1]`` is emitted, in row order. The list is
    printed (as a debugging aid) before it is returned.
    """
    atom_pairs = tree[['a1', 'a2']]
    pairs_0based = [
        [rec['a1'] - 1, rec['a2'] - 1]
        for _, rec in atom_pairs.iterrows()
        if rec['a2'] != -1
    ]
    print(pairs_0based)
    return pairs_0based


# In[6]:

# 0-based [a1, a2] atom-index pairs for every row of the tree that defines a bond.
bnd_list = get_bonds(df_tree)


# In[7]:

# Open the NetCDF trajectory with its .prmtop topology via pt.iterload,
# restricted to the frame slice (0, 3000000). NOTE(review): iterload is
# presumably out-of-core (frames read on demand) -- confirm in the pytraj docs.
traj = pt.iterload(['INR/0O.lig_b.cr9-20.mdcrd.nc'], 'INR/INR.lig.gas.leap.prmtop', frame_slice=[(0, 3000000)])


# In[8]:
import gc

@profile
def dummy(chunk_sz, bnd_list):
    """
    Profiled placeholder that only announces itself and returns None.

    Both arguments are accepted but unused. While hunting the leak, a
    variant of this function returned
    ``np.reshape(np.zeros(chunk_sz * len(bnd_list)), (32, 60000))``
    as an allocation to compare against ``pt.distance``.
    """
    banner = "inside dummy.."
    print(banner)

@profile
def process_traj(traj, bnd_list, chunk_sz=60000):
    """
    Compute bond distances chunk by chunk so that memory_profiler can show
    whether the memory used for each chunk is released.

    Parameters
    ----------
    traj : pytraj trajectory
        Object providing ``iterchunk`` and ``n_frames`` (e.g. the result of
        ``pt.iterload``).
    bnd_list : list of [int, int]
        0-based atom-index pairs passed to ``pt.distance``.
    chunk_sz : int
        Number of frames per chunk.

    Prints the number of bonds, each chunk's frame range and the shape of
    its distance array. Returns None.
    """
    print(len(bnd_list))
    # NOTE(review): stop=-1 is presumably pytraj's "iterate to the last frame"; confirm.
    for chunk_id, chunk in enumerate(traj.iterchunk(chunk_sz, start=0, stop=-1)):
        chunk_start = chunk_id * chunk_sz
        # The last chunk may be shorter than chunk_sz.
        chunk_end = min(chunk_start + chunk_sz, traj.n_frames)
        print("block start, end: ", chunk_start, chunk_end)
        if chunk_end - chunk_start > 0:
            # To check whether pt.distance is the source of the leak, swap this call for a
            # plain allocation such as np.reshape(np.zeros(chunk_sz * len(bnd_list)), (32, 60000)).
            bonds_val = pt.distance(chunk, bnd_list, dtype='ndarray')
            dummy(chunk_sz, bnd_list)
            print(bonds_val.shape)
            del bonds_val
        # Force a collection so any memory still held after each chunk is real, not pending garbage.
        gc.collect()




# In[10]:
# Run the chunked bond-distance computation with the default chunk size (60000 frames).
process_traj(traj, bnd_list)


#get_ipython().magic(u'mprun -f dummy ')

and asked the lead author of pytraj, the library I use, for help. Thanks to him, our discussion (see here) enabled me to spot a bug in Cython, which pytraj uses. Converting a memoryview to a cython.array caused a memory leak with Python 2.7.13, but not with Python 3, as pointed out by Matt Graham here.

Thanks to linusg as well. I cannot upvote his answer yet, because I don't have 15 reputation points here.