I'm using TensorFlow with DistilBERT to classify text by topic, and the tflite_flutter package to run the classifier in Flutter. The model is trained as follows:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input, Dense, Dropout
from tensorflow.keras import regularizers
from sklearn.model_selection import train_test_split
from transformers import DistilBertTokenizer, TFDistilBertModel

# Load the pretrained DistilBERT tokenizer and encoder
dbert_tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
dbert_model = TFDistilBertModel.from_pretrained('distilbert-base-uncased')

max_len = 32
input_ids = []
attention_masks = []

def read_data():
    test_csv = pd.read_csv('datasets/cleaned_test_data.csv')
    train_csv = pd.read_csv('datasets/clean_train.csv')
    test_csv = test_csv.drop(test_csv.index[0])
    return train_csv, test_csv

df_train, df_test = read_data()
df_balanced = df_train[df_train['class'] == 1].sample(2000)
for index in range(2, 11):
    df_balanced = pd.concat([df_balanced, df_train[df_train['class'] == index].sample(2000)])

x_train = df_balanced['text']
labels = df_balanced['class']

for sent in x_train:
    dbert_inps = dbert_tokenizer.encode_plus(sent, add_special_tokens=True, max_length=max_len, pad_to_max_length=True, return_attention_mask=True, truncation=True)
    input_ids.append(dbert_inps['input_ids'])
    attention_masks.append(dbert_inps['attention_mask'])

input_ids = np.asarray(input_ids)
attention_masks = np.array(attention_masks)
labels = np.array(labels)

train_inp, val_inp, train_label, val_label, train_mask, val_mask = train_test_split(input_ids, labels, attention_masks, test_size=0.2)
def create_model():
    inps = Input(shape=(max_len,), dtype='int64')
    masks = Input(shape=(max_len,), dtype='int64')
    dbert_layer = dbert_model(inps, attention_mask=masks)[0][:, 0, :]
    dense = Dense(512, activation='relu', kernel_regularizer=regularizers.l2(0.01))(dbert_layer)
    dropout = Dropout(0.5)(dense)
    pred = Dense(11, activation='softmax', kernel_regularizer=regularizers.l2(0.01))(dropout)
    model = tf.keras.Model(inputs=[inps, masks], outputs=pred)
    print(model.summary())
    return model
log_dir = 'dbert_model_new'
model_save_path = './dbert_model.h5'
callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath=model_save_path, save_weights_only=True, monitor='val_loss', mode='min', save_best_only=True),
             keras.callbacks.TensorBoard(log_dir=log_dir)]

loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5)

model = create_model()
model.compile(loss=loss, optimizer=optimizer, metrics=[metric])
model.fit([train_inp, train_mask], train_label, batch_size=16, epochs=5, validation_data=([val_inp, val_mask], val_label), callbacks=callbacks)
# Reload the best checkpoint into a fresh model and convert it to TFLite.
trained_model = create_model()
trained_model.compile(loss=loss, optimizer=optimizer, metrics=[metric])
trained_model.load_weights(model_save_path)

converter = tf.lite.TFLiteConverter.from_keras_model(trained_model)
tflite_model = converter.convert()
open("distilbert_slim_model.tflite", "wb").write(tflite_model)
The code above trains the model, and it works fine when run in Python. The model is then converted to TFLite and used in Flutter to predict a topic for a given input. Although my inputs satisfy the shape and type required by the input tensors, the output tensor always contains the same result regardless of the input. These are the tensors as reported by tflite_flutter:
The input tensors:
[Tensor{_tensor: Pointer: address=0x7c028b8522c0, name: serving_default_input_1:0, type: int64, shape: [1, 32], data: 256}, Tensor{_tensor: Pointer: address=0x7c028b852330, name: serving_default_input_2:0, type: int64, shape: [1, 32], data: 256}]
The output tensor:
Tensor{_tensor: Pointer: address=0x7c028b8658f0, name: StatefulPartitionedCall:0, type: float32, shape: [1, 11], data: 44}
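For comparison, the same tensor names, shapes, and dtypes can be inspected from the converted model on the Python side. A minimal sketch (assuming the distilbert_slim_model.tflite file written above):

import tensorflow as tf

# Load the converted model and print its input/output tensor details,
# which should match the dumps reported by tflite_flutter above.
interpreter = tf.lite.Interpreter(model_path="distilbert_slim_model.tflite")
interpreter.allocate_tensors()
for detail in interpreter.get_input_details():
    print(detail['name'], detail['dtype'], detail['shape'])
for detail in interpreter.get_output_details():
    print(detail['name'], detail['dtype'], detail['shape'])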
The Flutter code:
String classifyText({required String rawText}) {
  inputId = tokenizeInputText(rawText);
  Map category = {
    1: 'Society & Culture',
    2: 'Science & Mathematics',
    3: 'Health',
    4: 'Education & Reference',
    5: 'Computers & Internet',
    6: 'Sports',
    7: 'Business & Finance',
    8: 'Entertainment & Music',
    9: 'Family & Relationships',
    10: 'Politics & Government'
  };
  List<List<double>> output = [[]];
  for (var i = 0; i < 11; i++) {
    output[0].add(0.0);
  }
  _interpreter.run(inputId, output);
  final maximum = output[0].reduce(
      (curr, next) => (curr as double) > (next as double) ? curr : next);
  final string =
      '$rawText\n$inputId\noutput: $output\nhighest: $maximum\nindex: ${output[0].indexOf(maximum)}\ncategory: ${category[output[0].indexOf(maximum)]}';
  return string;
}
tokenizeInputText returns the same format and type as the TFLite input used in the Python code below. I have also implemented the inference in Python with the tensorflow module, and there it does give different outputs for different inputs. Here is the Python TFLite code:
import numpy as np
import tensorflow as tf
from transformers import DistilBertTokenizer

dbert_tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

interpreter = tf.lite.Interpreter(model_path="distilbert_slim_model.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

string = input()
input_list = []
mask_list = []
dbert_inps = dbert_tokenizer.encode_plus(string, add_special_tokens=True, max_length=256, pad_to_max_length=True, return_attention_mask=True, truncation=True)
input_list.append(dbert_inps['input_ids'])
mask_list.append(dbert_inps['attention_mask'])
input_id = np.array(input_list, dtype=np.int64)
mask = np.array(mask_list, dtype=np.int64)

input_shape = input_details[0]['shape']
interpreter.set_tensor(input_details[0]['index'], input_id)
interpreter.set_tensor(input_details[1]['index'], mask)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print(np.argmax(output_data))
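Since the on-device prediction depends on tokenizeInputText producing exactly the same IDs and masks as the HuggingFace tokenizer, a useful check is tokenizer parity: dump the Python-side input_ids and attention_mask for a few sentences and compare them with what the Dart tokenizer feeds to the interpreter. A minimal sketch, assuming the same pretrained tokenizer and the 32-token length used at training time (the sample sentences are only placeholders):

from transformers import DistilBertTokenizer

# Print the exact input_ids / attention_mask the model expects for a few
# sample sentences, so they can be compared against tokenizeInputText's output.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
samples = ["What team won the world cup in 2018?",
           "How do I invest in index funds?"]
for sent in samples:
    enc = tokenizer.encode_plus(sent, add_special_tokens=True, max_length=32,
                                padding='max_length', truncation=True,
                                return_attention_mask=True)
    print(sent)
    print("input_ids:     ", enc['input_ids'])
    print("attention_mask:", enc['attention_mask'])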
The code above gives a different output for different inputs, as expected. Any help would be appreciated. Thanks in advance.
I have tried many approaches with the tflite_flutter package, but the problem remains. I also checked whether something was wrong with the model itself by running it in Python, where it works perfectly fine and gives the desired results.