Content Understanding with the CN-CLIP Model

Installing CN-CLIP

```bash
conda create -y -n py39cnclip python=3.9
conda activate py39cnclip

pip install -r https://raw.githubusercontent.com/OFA-Sys/Chinese-CLIP/refs/heads/master/requirements.txt # proxy: https://ghproxy.cn/...
pip install cn_clip
```
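If GitHub is slow or unreachable, the mirror hinted at in the comment applies: prepend https://ghproxy.cn/ to the raw GitHub URL. The example image URL further down uses the same trick.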

Downloading the model

```python
from cn_clip.clip import load_from_name, available_models

print("Available models:", available_models())  # Available models: ['ViT-B-16', 'ViT-L-14', 'ViT-L-14-336', 'ViT-H-14', 'RN50']
model, preprocess = load_from_name("ViT-B-16", device="cpu", download_root='./')  # downloads the chosen checkpoint into the current directory ./
```
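ViT-B-16 is the smallest ViT variant in the list; the larger ViT-L/ViT-H checkpoints generally trade longer downloads and slower inference for better accuracy.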

Applying the model

Getting embeddings and computing similarity

```python
import torch
import requests
from PIL import Image

import cn_clip.clip as clip
from cn_clip.clip import load_from_name, available_models

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = load_from_name("ViT-B-16", device=device, download_root='./')  # if the checkpoint already exists under ./, it is loaded locally
model.eval()

image_url = 'https://ghproxy.cn/https://raw.githubusercontent.com/OFA-Sys/Chinese-CLIP/refs/heads/master/examples/pokemon.jpeg'
# Image.open cannot read a URL directly, so fetch it over HTTP first
image = preprocess(Image.open(requests.get(image_url, stream=True).raw)).unsqueeze(0).to(device)
text = clip.tokenize(["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    # Normalize the features; use the normalized image/text features for downstream tasks
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    logits_per_image, logits_per_text = model.get_similarity(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)  # [[1.268734e-03 5.436878e-02 6.795761e-04 9.436829e-01]]
```

```python
import os
import torch
from PIL import Image

dir_path = 'Downloads/test_batch_image_dl'
file_list = os.listdir(dir_path)
file_fullpath_list = [os.path.join(dir_path, x) for x in file_list]

def get_one_image_feature(image_url):
    image = preprocess(Image.open(image_url)).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image)
        image_features /= image_features.norm(dim=-1, keepdim=True)
    return image_features  # shape=[1, 512]

def get_batch_image_feature(image_url_list, batch_size=16):
    assert len(image_url_list) > 0, 'image_url_list must not be empty'

    image_features = []
    for i in range(0, len(image_url_list), batch_size):
        images = torch.concat([preprocess(Image.open(image_url)).unsqueeze(0)
                               for image_url in image_url_list[i:(i + batch_size)]]).to(device)
        with torch.no_grad():
            image_features.append(model.encode_image(images))
    image_features = torch.concat(image_features) if len(image_features) > 1 else image_features[0]
    image_features /= image_features.norm(dim=-1, keepdim=True)
    return image_features  # shape=[len(image_url_list), 512]

# smoke test: encode the first 10 images one at a time
for i in range(10):
    tmp_f = get_one_image_feature(file_fullpath_list[i])


def get_batch_text_feature(text_list):
    text = clip.tokenize(text_list).to(device)
    with torch.no_grad():  # no_grad for consistency with the image encoders
        text_features = model.encode_text(text)
        text_features /= text_features.norm(dim=-1, keepdim=True)
    return text_features

def get_text_feature(text):
    return get_batch_text_feature([text])
```
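With these helpers, text-to-image retrieval reduces to a matrix product: since both feature sets are L2-normalized, the dot product is the cosine similarity. A minimal sketch, assuming the helpers above are in scope; the query string and top-k value are illustrative:

```python
# Rank the local images against a Chinese text query by cosine similarity.
query = "一只黄色的卡通角色"  # hypothetical query; replace with your own text

image_features = get_batch_image_feature(file_fullpath_list)  # [N, 512]
text_features = get_batch_text_feature([query])               # [1, 512]

scores = (image_features @ text_features.T).squeeze(-1)       # [N]
topk = scores.topk(min(5, len(file_fullpath_list)))
for score, idx in zip(topk.values.tolist(), topk.indices.tolist()):
    print(f"{score:.4f}  {file_fullpath_list[idx]}")
```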

References