SQLite: how to persist a custom data type into a new table

I am trying to build a database that stores NumPy arrays, using sqlite3 with Python. After creating an initial table, I want to perform some operations and save the result as a new table. This all works, except that the custom data type I registered isn't carried over into the new table. A minimal example:

I define two functions to convert NumPy arrays to and from binary:

import io
import sqlite3
import numpy as np

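def adapt_array(arr: np.ndarray) -> memoryview:
    # Serialize the array in NumPy's .npy format and wrap the bytes as a BLOB
    out = io.BytesIO()
    np.save(out, arr)
    out.seek(0)
    return sqlite3.Binary(out.read())

def convert_array(data: bytes) -> np.ndarray:
    # sqlite3 hands converters the raw bytes; rebuild the array from them
    return np.load(io.BytesIO(data))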
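As a quick sanity check (assuming only NumPy and the standard library), the two functions can be tested without any database. If this round trip holds, the problem lies in how sqlite3 decides which converter to apply, not in the functions themselves:

```python
import io
import sqlite3

import numpy as np


def adapt_array(arr: np.ndarray) -> memoryview:
    # Serialize the array in NumPy's .npy format and wrap it as a BLOB.
    out = io.BytesIO()
    np.save(out, arr)
    return sqlite3.Binary(out.getvalue())


def convert_array(data: bytes) -> np.ndarray:
    # Rebuild the array from the .npy bytes.
    return np.load(io.BytesIO(data))


# Round trip with no database involved.
original = np.random.randn(3, 4)
restored = convert_array(bytes(adapt_array(original)))
print(np.array_equal(original, restored))  # True
```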

I register the adapter and the converter, then connect with PARSE_DECLTYPES so that columns declared as "array" are converted back:

sqlite3.register_adapter(np.ndarray, adapt_array)
sqlite3.register_converter("array", convert_array)
conn = sqlite3.connect("test.db", detect_types=sqlite3.PARSE_DECLTYPES)
cursor = conn.cursor()

and then create an initial table:

embedding = np.random.randn(10, 64)
cursor.execute('create table test_1 (idx integer primary key, embedding array);')
for i, X in enumerate(embedding):
    cursor.execute('insert into test_1 (idx, embedding) values (?,?)', (i, X))

I then create a new table from this first table:

cursor.execute("create table test_2 as select idx, embedding from test_1;")

but now when I do the following:

cursor.execute("select * from test_1")
data_1 = cursor.fetchall()

cursor.execute("select * from test_2")
data_2 = cursor.fetchall()

data_1 has the embedding field returned as a NumPy array, as expected, whilst data_2 has it returned as raw bytes. So for some reason the array type is not persisted into the new table.
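A likely explanation, based on SQLite's documented rules (worth checking against your SQLite version): with PARSE_DECLTYPES, sqlite3 picks a converter from each column's declared type, and CREATE TABLE ... AS SELECT does not copy declared types. It derives them from each expression's affinity instead, and a declared type of "array" matches none of SQLite's INT/TEXT/BLOB/REAL patterns, so it gets NUMERIC affinity and the new column is declared "NUM". The sketch below, using an in-memory database in place of test.db, prints the declared types SQLite stored:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("create table test_1 (idx integer primary key, embedding array)")
conn.execute("create table test_2 as select idx, embedding from test_1")

# Each PRAGMA table_info row is (cid, name, type, notnull, dflt_value, pk);
# with PARSE_DECLTYPES, sqlite3 looks up a converter by the "type" field.
declared = {
    table: [row[2] for row in conn.execute(f"pragma table_info({table})")]
    for table in ("test_1", "test_2")
}
print(declared["test_2"])  # ['INT', 'NUM']
```

Since no converter named "NUM" is registered, the embedding column of test_2 comes back as plain bytes.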

I have tried:

cursor.execute("create table test_2 as select idx, cast(embedding as array) as embedding from test_1;")

but this doesn't work: it sets every embedding value to 0. Does anyone know why this happens, and how to get around it?
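The zeros seem to follow from the same affinity rule: "array" is not a type SQLite knows, so CAST(embedding AS array) behaves like CAST(embedding AS NUMERIC). SQLite turns the BLOB into text and then into a number, and since np.save output starts with the bytes b"\x93NUMPY" there is no numeric prefix, so the result is 0. Even without the data loss, the new column would still be declared "NUM". A minimal reproduction, assuming a short stand-in blob rather than a real np.save payload:

```python
import sqlite3

conn = sqlite3.connect(":memory:")

# Stand-in for an np.save payload: only the leading magic bytes matter here.
blob = b"\x93NUMPY"

# "array" gets NUMERIC affinity, so this is effectively CAST(? AS NUMERIC).
(value,) = conn.execute("select cast(? as array)", (blob,)).fetchone()
print(value)  # 0
```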

EDIT: current workaround (ALTER TABLE ... DROP COLUMN needs SQLite 3.35.0 or later):

cursor.execute("alter table test_2 add column new_embedding array")
cursor.execute("update test_2 set new_embedding = embedding")
cursor.execute("alter table test_2 drop column embedding")

but I hate it...
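Two alternatives that may avoid the ALTER TABLE steps, sketched against an in-memory database (test_3 is a hypothetical extra table used only for the second option). The first declares test_2's schema explicitly and fills it with INSERT INTO ... SELECT, so the column keeps the declared type "array". The second keeps CREATE TABLE ... AS SELECT and instead names the converter in the query through PARSE_COLNAMES, using a column alias of the form "embedding [array]":

```python
import io
import sqlite3

import numpy as np


def adapt_array(arr: np.ndarray) -> memoryview:
    out = io.BytesIO()
    np.save(out, arr)
    return sqlite3.Binary(out.getvalue())


def convert_array(data: bytes) -> np.ndarray:
    return np.load(io.BytesIO(data))


sqlite3.register_adapter(np.ndarray, adapt_array)
sqlite3.register_converter("array", convert_array)

# PARSE_COLNAMES is only needed for option 2 below.
conn = sqlite3.connect(
    ":memory:", detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES
)
conn.execute("create table test_1 (idx integer primary key, embedding array)")
conn.executemany(
    "insert into test_1 (idx, embedding) values (?, ?)",
    [(i, x) for i, x in enumerate(np.random.randn(10, 64))],
)

# Option 1: declare the schema yourself, then copy the rows across.
conn.execute("create table test_2 (idx integer primary key, embedding array)")
conn.execute("insert into test_2 (idx, embedding) select idx, embedding from test_1")
data_2 = conn.execute("select * from test_2").fetchall()

# Option 2: keep CREATE TABLE ... AS SELECT, name the converter per query.
conn.execute("create table test_3 as select idx, embedding from test_1")
data_3 = conn.execute(
    'select idx, embedding as "embedding [array]" from test_3'
).fetchall()

print(type(data_2[0][1]).__name__, type(data_3[0][1]).__name__)  # ndarray ndarray
```

The first option fixes the schema once, so every later query converts automatically; the second leaves the stored type as "NUM" and relies on the alias being repeated in each query that reads the column.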
