情况:
在已有的基于tensorflow的代码中,加入了文本分析的sentence_transformers,代码在我自己电脑中是可以跑的通的。
但是上传到老师的服务器中,就无法跑了,也不报错,就是卡住。
代码
"""
文件:demo2
时间:2023-08-28
作者:R
描述:测试tensorflow和sentence_transformers
知识点:
"""
import tensorflow as tf
from sentence_transformers import SentenceTransformer
def demo1():
embeddings = tf.constant([
[1, 2, 3, 4],
[11, 12, 13, 14],
[21, 22, 23, 24],
])
print(embeddings.shape)
ids = tf.constant([
[0, 1],
[0, 2]
])
print(ids.shape)
def demo2():
text_model = SentenceTransformer('paraphrase-TinyBERT-L6-v2')
sentence1 = 'Square Root'
embedding1 = text_model.encode(sentence1, convert_to_tensor=True)
print(type(embedding1))
print(embedding1.shape)
print(embedding1)
sentence2 = 'Addition and Subtraction Positive Decimals'
embedding2 = text_model.encode(sentence2, convert_to_tensor=True)
print(embedding2.shape)
print(embedding2)
if __name__ == '__main__':
print("hello,world")
demo1()
demo2()
解决:
开始以为是包的版本问题,各种校对包的版本,发现自己的电脑和老师的服务器都是一致的。
后面注意到这个sentence_transformers是基于pytorch的,而在之前安装虚拟环境的时候,师兄说pytorch的环境和tensorflow的环境最好分开安装,可能会产生冲突,于是就在想是不是这个原因。
后面查询到果然是,tensorflow如果和pytorch一起出现的话,如果先跑tensorflow,那么就会把显卡占满。
所以只需要调整一下顺序,让基于pytorch的sentence_transformers先导入,就可以了。
import tensorflow as tf
from sentence_transformers import SentenceTransformer