MedGemma không được thiết kế để sử dụng mà không có quá trình xác thực, điều chỉnh và/hoặc sửa đổi có ý nghĩa phù hợp bởi các nhà phát triển cho trường hợp sử dụng cụ thể của họ.
Đầu ra do MedGemma tạo ra không nhằm mục đích trực tiếp thông báo cho việc chẩn đoán lâm sàng, quyết định quản lý bệnh nhân, khuyến nghị điều trị hoặc bất kỳ ứng dụng thực hành lâm sàng trực tiếp nào khác.
Các điểm chuẩn hiệu suất làm nổi bật khả năng cơ bản trên các điểm chuẩn liên quan, nhưng ngay cả đối với các lĩnh vực hình ảnh và văn bản chiếm phần lớn dữ liệu đào tạo, đầu ra mô hình không chính xác vẫn có thể xảy ra.
Tất cả đầu ra từ MedGemma nên được coi là sơ bộ và cần được xác minh độc lập, đối chiếu lâm sàng và điều tra thêm thông qua các phương pháp nghiên cứu và phát triển đã được thiết lập.
Trí tuệ nhân tạo (AI) đang cách mạng hóa ngành chăm sóc sức khỏe, nhưng làm thế nào để bạn lấy một mô hình AI mạnh mẽ, đa mục đích và dạy nó những kỹ năng chuyên môn của một bác sĩ giải phẫu bệnh?
Hành trình từ nguyên mẫu đến sản xuất thường bắt đầu trong một notebook, và đó chính xác là nơi chúng ta sẽ bắt đầu.
Trong hướng dẫn này, chúng ta sẽ thực hiện bước đầu tiên quan trọng.
Chúng ta sẽ đi qua toàn bộ quy trìnhtinh chỉnhbiến thể Gemma 3MedGemma.
MedGemma là họ mô hình mã nguồn mở của Google dành cho cộng đồng y tế, để phân loại hình ảnh mô bệnh học ung thư vú.
Chúng tôi đang sử dụng mô hình MedGemma độ chính xác đầy đủ vì đó là thứ bạn sẽ cần để đạt hiệu suất tối đa cho nhiều tác vụ lâm sàng.
Nếu bạn lo ngại về chi phí tính toán, bạn có thể lượng tử hóa và tinh chỉnh bằng cách sử dụngnotebook tinh chỉnh được cấu hình sẵn của MedGemmathay thế.
Để hoàn thành bước đầu tiên, chúng tôi sẽ sử dụngNotebook Tinh chỉnh.
Notebook cung cấp cho bạn tất cả mã và giải thích từng bước của quy trình, vì vậy nó là môi trường hoàn hảo để thử nghiệm.
Tôi cũng sẽ chia sẻ những hiểu biết quan trọng mà tôi học được trên đường đi, bao gồm một lựa chọn quan trọng về kiểu dữ liệu đã tạo nên sự khác biệt.
Sau khi chúng ta hoàn thiện mô hình trong giai đoạn tạo nguyên mẫu này, chúng ta sẽ sẵn sàng cho bước tiếp theo.
Trong một bài đăng sắp tới, chúng tôi sẽ chỉ cho bạn cách lấy quy trình làm việc chính xác này và chuyển nó sang môi trường sẵn sàng cho sản xuất, có thể mở rộng bằng cách sử dụngCloud Run jobs.
Nội dung chính
Chuẩn bị sân khấu:
Mục tiêu, mô hình và dữ liệu của chúng ta
Trước khi chúng ta đi đến phần mã, hãy chuẩn bị sân khấu.
Mục tiêu của chúng ta là phân loại hình ảnh kính hiển vi của mô vú thành một trong tám loại:
bốn lành tính (không ung thư) và bốn ác tính (ung thư).
Loại phân loại này đại diện cho một trong nhiều nhiệm vụ quan trọng mà các bác sĩ giải phẫu bệnh thực hiện để đưa ra chẩn đoán chính xác, và chúng ta có một bộ công cụ tuyệt vời cho công việc này.
Chúng ta sẽ sử dụngMedGemma, một họ mô hình mã nguồn mở mạnh mẽ từ Google được xây dựng dựa trên cùng nghiên cứu và công nghệ cung cấp năng lượng cho các mô hình Gemini của chúng tôi.
Điều làm cho MedGemma đặc biệt là nó không chỉ là một mô hình chung:
nó đã được điều chỉnh đặc biệt cho lĩnh vực y tế.
Thành phần thị giác của MedGemma,MedSigLIP, đã được đào tạo trước trên một lượng lớn hình ảnh y tế ẩn danh, bao gồm chính xác loại phiến đồ mô bệnh học mà chúng ta đang sử dụng.
Nếu bạn không cần sức mạnh dự đoán của MedGemma, bạn có thể sử dụng riêng MedSigLIP như một lựa chọn tiết kiệm chi phí hơn cho các tác vụ dự đoán như phân loại hình ảnh.
Có nhiềunotebook hướng dẫn MedSigLIPmà bạn có thể sử dụng để tinh chỉnh.
Thành phần ngôn ngữ của MedGemma cũng được đào tạo trên một tập hợp đa dạng các văn bản y tế, làm cho phiên bảngoogle/medgemma-4b-itmà chúng ta đang sử dụng trở nên hoàn hảo để tuân theo các lời nhắc dựa trên văn bản của chúng ta.
Google cung cấp MedGemma như một nền tảng vững chắc, nhưng nó yêu cầu tinh chỉnh cho các trường hợp sử dụng cụ thể — đó chính xác là điều chúng ta sắp làm.
Để đào tạo mô hình của chúng ta, chúng ta sẽ sử dụng bộ dữ liệuPhân loại Hình ảnh Mô bệnh học Ung thư Vú (BreakHis).
Bộ dữ liệu BreakHis là một bộ sưu tập công khai hàng nghìn hình ảnh kính hiển vi của mô khối u vú được thu thập từ 82 bệnh nhân sử dụng các hệ số phóng đại khác nhau (40X, 100X, 200X và 400X).
Bộ dữ liệu có sẵn công khai cho nghiên cứu phi thương mại và được mô tả chi tiết trong bài báo:
F.
A.
Spanhol, L.
S.
Oliveira, C.
Petitjean, and L.
Heudel,A dataset for breast cancer histopathological image classification.1
Xử lý một mô hình4 tỷtham số đòi hỏi một GPU đủ mạnh, vì vậy tôi đã sử dụngNVIDIA A100với40 GBVRAMtrênVertex AI Workbench.
GPU này có sức mạnh cần thiết, và nó cũng cóNVIDIA Tensor Coreshoạt động xuất sắc với các định dạng dữ liệu hiện đại, mà chúng ta sẽ tận dụng để đào tạo nhanh hơn.
Trong một bài đăng sắp tới, chúng tôi sẽ giải thích cách tính VRAM cần thiết cho việc tinh chỉnh của bạn.
Thảm họa float16 của tôi:
Một bài học quan trọng về tính ổn định
Nỗ lực đầu tiên của tôi để tải mô hình đã sử dụng kiểu dữ liệu float16 phổ biến để tiết kiệm bộ nhớ.
Nó thất bại thảm hại.
Đầu ra của mô hình hoàn toàn vô nghĩa, và một kiểm tra gỡ lỗi nhanh chóng tiết lộ rằng mọi giá trị nội bộ đã sụp đổ thànhNaN (Không phải Số).
Thủ phạm là mộttràn số họckinh điển.
Để hiểu tại sao, bạn cần biết sự khác biệt quan trọng giữa các định dạng 16-bit này:
- float16(FP16):
Có một phạm vi số học rất nhỏ.
Nó không thể biểu diễn bất kỳ số nào lớn hơn 65,504.
Trong hàng triệu phép tính trong một bộ biến đổi (transformer), các giá trị trung gian có thể dễ dàng vượt quá giới hạn này, gây ra tràn số tạo ra NaN.
Khi một NaN xuất hiện, nó làm nhiễm bẩn mọi phép tính tiếp theo.
Định dạng này, được phát triển tại Google Brain, thực hiện một sự đánh đổi quan trọng.
Nó hy sinh một chút độ chính xác để duy trìcùng một phạm vi số học khổng lồnhư định dạng float32 32-bit đầy đủ.
Phạm vi lớn của bfloat16 ngăn chặn tràn số, giúp quá trình huấn luyện ổn định.
Giải pháp sửa chữa chỉ là một thay đổi một dòng mã đơn giản, nhưng nó dựa trên khái niệm quan trọng này.
Mã thành công:
# Giải pháp đơn giản, ổn địnhmodel_kwargs=dict(torch_dtype=torch.bfloat16,# Sử dụng bfloat16 cho phạm vi số học rộng của nódevice_map="auto",attn_implementation="sdpa",)model=AutoModelForImageTextToText.from_pretrained(MODEL_ID,**model_kwargs)
Bài học rút ra:Để tinh chỉnh các mô hình lớn, luôn ưu tiênbfloat16vì tính ổn định của nó.
Đó là một thay đổi nhỏ giúp bạn tránh khỏi một thế giới đau đầu liên quan đến NaN.
Hướng dẫn chi tiết mã:
Một chỉ dẫn từng bước
Bây giờ, hãy đi vào phần mã.
Tôi sẽ chia nhỏNotebook Tinh chỉnhcủa mình thành các bước rõ ràng, hợp lý.
Bước 1:
Thiết lập và cài đặt
Đầu tiên, bạn cần cài đặt các thư viện cần thiết từ hệ sinh thái Hugging Face và đăng nhập vào tài khoản của bạn để tải xuống mô hình.
# Cài đặt các gói yêu cầu!pipinstall--upgrade--quiettransformersdatasetsevaluatepefttrlscikit-learnimportosimportreimporttorchimportgcfromdatasetsimportload_dataset,ClassLabelfrompeftimportLoraConfig,PeftModelfromtransformersimportAutoModelForImageTextToText,AutoProcessorfromtrlimportSFTTrainer,SFTConfigimportevaluate
Xác thực Hugging Face và cách tiếp cận được khuyến nghị để xử lý bí mật của bạn
⚠️ Lưu ý bảo mật quan trọng:Bạn không bao giờ nên mã hóa cứng các bí mật như khóa API hoặc mã thông báo trực tiếp vào mã hoặc sổ ghi chép của mình, đặc biệt là trong môi trường sản xuất.
Thực hành này không an toàn và tạo ra rủi ro bảo mật đáng kể.
Trong Vertex AI Workbench, cách tiếp cận an toàn nhất và cấp doanh nghiệp để xử lý bí mật (như mã thông báo Hugging Face của bạn) là sử dụngTrình quản lý bí mậtcủa Google Cloud.
Nếu bạn chỉ đang thử nghiệm và chưa muốn thiết lập Trình quản lý bí mật, bạn có thể sử dụng tiện ích đăng nhập tương tác.
Tiện ích này lưu mã thông báo tạm thời trong hệ thống tệp của phiên bản.
# Xác thực Hugging Face sử dụng tiện ích đăng nhập tương tác: from huggingface_hub import notebook_login notebook_login()
Trong bài đăng sắp tới, nơi chúng tôi chuyển quy trình này sang Cloud Run Jobs, chúng tôi sẽ chỉ cho bạn cách đúng đắn và an toàn để xử lý mã thông báo này bằng cách sử dụng Trình quản lý bí mật.
Bước 2:
Tải và chuẩn bị tập dữ liệu
Tiếp theo, chúng tôi tải xuốngtập dữ liệu BreakHistừ Kaggle bằng thư việnkagglehub.
Tập dữ liệu này bao gồm một tệpFolds.csv, nêu rõ cách dữ liệu được chia cho các thử nghiệm.
Nghiên cứu ban đầu đã sử dụng xác thực chéo 5-fold, nhưng để giữ thời gian huấn luyện ở mức có thể quản lý được cho phần trình diễn này, chúng tôi sẽ tập trung vào Fold 1 và chỉ sử dụng hình ảnh với độ phóng đại 100X.
Bạn có thể khám phá việc sử dụng các fold và độ phóng đại khác cho các thử nghiệm mở rộng hơn.
!pipinstall-qkagglehubimportkagglehubimportosimportpandasaspdfromPILimportImagefromdatasetsimportDataset,ImageasHFImage,Features,ClassLabel# Tải xuống siêu dữ liệu tập dữ liệupath=kagglehub.dataset_download("ambarish/breakhis")print("Đường dẫn đến các tệp tập dữ liệu:",path)folds=pd.read_csv('{}/Folds.csv'.format(path))# Lọc lấy độ phóng đại 100X từ phần gấp đầu tiênfolds_100x=folds[folds['mag']==100]folds_100x=folds_100x[folds_100x['fold']==1]# Lấy các phần tách huấn luyện/kiểm trafolds_100x_test=folds_100x[folds_100x.grp=='test']folds_100x_train=folds_100x[folds_100x.grp=='train']# Định nghĩa đường dẫn cơ sở cho hình ảnhBASE_PATH="/home/jupyter/.cache/kagglehub/datasets/ambarish/breakhis/versions/4/BreaKHis_v1"
Bước 2.1:
Cân bằng tập dữ liệu
Các phần tách huấn luyện và kiểm tra ban đầu cho độ phóng đại 100X cho thấy sự mất cân bằng giữa các lớp lành tính và ác tính.
Để giải quyết vấn đề này, chúng ta sẽ lấy mẫu dưới lớp chiếm đa số trong cả tập huấn luyện và tập kiểm tra để tạo ra các tập dữ liệu cân bằng với phân phối 50/50.
# --- 1. Tạo Tập HUẤN LUYỆN Cân Bằng --- train_benign_df = folds_100x_train[folds_100x_train['filename'].str.contains('benign')] train_malignant_df = folds_100x_train[folds_100x_train['filename'].str.contains('malignant')] min_train_count = min(len(train_benign_df), len(train_malignant_df)) balanced_train_benign = train_benign_df.sample(n=min_train_count, random_state=42) balanced_train_malignant = train_malignant_df.sample(n=min_train_count, random_state=42) balanced_train_df = pd.concat([balanced_train_benign, balanced_train_malignant]) # --- 2. Tạo Tập KIỂM TRA Cân Bằng --- test_benign_df = folds_100x_test[folds_100x_test['filename'].str.contains('benign')] test_malignant_df = folds_100x_test[folds_100x_test['filename'].str.contains('malignant')] min_test_count = min(len(test_benign_df), len(test_malignant_df)) balanced_test_benign = test_benign_df.sample(n=min_test_count, random_state=42) balanced_test_malignant = test_malignant_df.sample(n=min_test_count, random_state=42) balanced_test_df = pd.concat([balanced_test_benign, balanced_test_malignant]) # --- 3. Lấy Danh sách Tên Tệp Cuối Cùng --- train_filenames = balanced_train_df['filename'].values test_filenames = balanced_test_df['filename'].values print(f"Tập Huấn Luyện Cân Bằng: {len(train_filenames)} tệp") print(f"Tập Kiểm Tra Cân Bằng: {len(test_filenames)} tệp")
Bước 2.2:
Tạo một tập dữ liệu Hugging Face
Chúng ta đang chuyển đổi dữ liệu của mình sang định dạngdatasetscủa Hugging Face vì đây là cách dễ nhất để làm việc vớiSFTTrainertừ thư viện Transformers của họ.
Định dạng này được tối ưu hóa để xử lý các tập dữ liệu lớn, đặc biệt là hình ảnh, vì nó có thể tải chúng một cách hiệu quả khi cần.
Và nó cung cấp cho chúng ta các công cụ tiện lợi để tiền xử lý, như áp dụng hàm định dạng của chúng ta cho tất cả các mẫu.
CLASS_NAMES = [ 'benign_adenosis', 'benign_fibroadenoma', 'benign_phyllodes_tumor', 'benign_tubular_adenoma', 'malignant_ductal_carcinoma', 'malignant_lobular_carcinoma', 'malignant_mucinous_carcinoma', 'malignant_papillary_carcinoma' ] def get_label_from_filename(filename): filename = filename.replace('\\', '/').lower() if '/adenosis/' in filename: return 0 if '/fibroadenoma/' in filename: return 1 if '/phyllodes_tumor/' in filename: return 2 if '/tubular_adenoma/' in filename: return 3 if '/ductal_carcinoma/' in filename: return 4 if '/lobular_carcinoma/' in filename: return 5 if '/mucinous_carcinoma/' in filename: return 6 if '/papillary_carcinoma/' in filename: return 7 return -1 train_data_dict = { 'image': [os.path.join(BASE_PATH, f) for f in train_filenames], 'label': [get_label_from_filename(f) for f in train_filenames] } test_data_dict = { 'image': [os.path.join(BASE_PATH, f) for f in test_filenames], 'label': [get_label_from_filename(f) for f in test_filenames] } features = Features({ 'image': HFImage(), 'label': ClassLabel(names=CLASS_NAMES) }) train_dataset = Dataset.from_dict(train_data_dict, features=features).cast_column("image", HFImage()) eval_dataset = Dataset.from_dict(test_data_dict, features=features).cast_column("image", HFImage()) print(train_dataset) print(eval_dataset)
Bước 3:
Kỹ thuật tạo lời nhắc
Bước này là nơi chúng ta nói với mô hình những gì chúng ta muốn nó làm.
Chúng ta tạo ra một lời nhắc rõ ràng, có cấu trúc để hướng dẫn mô hình phân tích một hình ảnh và chỉ trả về con số tương ứng với một lớp.
Lời nhắc này làm cho đầu ra trở nên đơn giản và dễ dàng phân tích.
Sau đó, chúng ta ánh xạ định dạng này trên toàn bộ tập dữ liệu của chúng ta.
# Định nghĩa lời nhắc hướng dẫnPROMPT="""Phân tích hình ảnh mô bệnh học vú này và phân loại nó.
Các lớp (0-7):
0:
u tuyến lành tính 1:
u xơ tuyến lành tính 2:
u lá lành tính 3:
u ống tuyến lành tính 4:
ung thư biểu mô ống ác tính 5:
ung thư biểu mô tiểu thùy ác tính 6:
ung thư biểu mô nhầy ác tính 7:
ung thư biểu mô nhú ác tính Trả lời chỉ bằng số (0-7):"""defformat_data(example):"""Định dạng các mẫu thành tin nhắn kiểu trò chuyện mà MedGemma mong đợi."""example["messages"]=[{"role":"user","content":[{"type":"image"},{"type":"text","text":PROMPT},],},{"role":"assistant","content":[{"type":"text","text":str(example["label"])},],},]returnexample# Áp dụng định dạngformatted_train=train_dataset.map(format_data,batched=False)formatted_eval=eval_dataset.map(format_data,batched=False)print("✓ Dữ liệu đã được định dạng với lời nhắc hướng dẫn")
Bước 4:
Tải mô hình và bộ xử lý
Ở đây, chúng tôi tải mô hình MedGemma và bộ xử lý liên kết của nó.
Bộ xử lý là một công cụ tiện lợi để chuẩn bị cả hình ảnh và văn bản cho mô hình.
Chúng tôi cũng sẽ thực hiện hai lựa chọn tham số chính để tối ưu hiệu suất:
torch_dtype=torch.bfloat16:
Như đã đề cập trước đó, định dạng này đảm bảo tính ổn định số học.attn_implementation="sdpa":Cơ chế chú ý tích vô hướng có tỷ lệlà một cơ chế chú ý được tối ưu hóa cao có sẵn trong PyTorch 2.0.
Hãy coi cơ chế này như việc bảo mô hình sử dụng một động cơ siêu nhanh, được tích hợp sẵn cho phép tính quan trọng nhất của nó.
Nó tăng tốc độ huấn luyện và suy luận, và thậm chí có thể tự động sử dụng các backend nâng cao hơn như FlashAttention nếu phần cứng của bạn hỗ trợ.
MODEL_ID="google/medgemma-4b-it"# Cấu hình mô hìnhmodel_kwargs=dict(torch_dtype=torch.bfloat16,device_map="auto",attn_implementation="sdpa",)model=AutoModelForImageTextToText.from_pretrained(MODEL_ID,**model_kwargs)processor=AutoProcessor.from_pretrained(MODEL_ID)# Đảm bảo căn lề phải cho việc huấn luyệnprocessor.tokenizer.padding_side="right"
Bước 5:
Đánh giá mô hình cơ sở
Trước khi đầu tư thời gian và tài nguyên tính toán để tinh chỉnh, hãy xem mô hình được đào tạo trước tự nó hoạt động như thế nào.
Bước này cung cấp cho chúng ta một đường cơ sở để đo lường sự cải thiện của chúng ta.
# Các hàm hỗ trợ để chạy đánh giáaccuracy_metric=evaluate.load("accuracy")f1_metric=evaluate.load("f1")defcompute_metrics(predictions,references):return{**accuracy_metric.compute(predictions=predictions,references=references),**f1_metric.compute(predictions=predictions,references=references,average="weighted")}defpostprocess_prediction(text):"""Trích xuất chỉ số từ đầu ra văn bản của mô hình."""digit_match=re.search(r'\b([0-7])\b',text.strip())returnint(digit_match.group(1))ifdigit_matchelse-1defbatch_predict(model,processor,prompts,images,batch_size=8,max_new_tokens=40):"""Một hàm để chạy suy luận theo lô."""predictions=[]foriinrange(0,len(prompts),batch_size):batch_texts=prompts[i:i+batch_size]batch_images=[[img]forimginimages[i:i+batch_size]]inputs=processor(text=batch_texts,images=images,padding=True,return_tensors="pt").to("cuda",torch.bfloat16)prompt_lengths=inputs["attention_mask"].sum(dim=1)withtorch.inference_mode():outputs=model.generate(**inputs,max_new_tokens=max_new_tokens,do_sample=False,pad_token_id=processor.tokenizer.pad_token_id)forseq,lengthinzip(outputs,prompt_lengths):generated=processor.decode(seq[length:],skip_special_tokens=True)predictions.append(postprocess_prediction(generated))returnpredictions# Chuẩn bị dữ liệu để đánh giáeval_prompts=[processor.apply_chat_template([msg[0]],add_generation_prompt=True,tokenize=False)formsginformatted_eval["messages"]]eval_images=formatted_eval["image"]eval_labels=formatted_eval["label"]# Chạy đánh giá cơ sởprint("Đang chạy đánh giá cơ sở...")baseline_preds=batch_predict(model,processor,eval_prompts,eval_images)baseline_metrics=compute_metrics(baseline_preds,eval_labels)print(f"\n{'KẾT QUẢ CƠ SỞ':-^80}")print(f"Độ chính xác:{baseline_metrics['accuracy']:.1%}")print(f"Điểm F1:{baseline_metrics['f1']:.3f}")print("-"*80)
Hiệu suất của mô hình cơ sở được đánh giá trên cả phân loại 8 lớp và phân loại nhị phân (lành tính/ác tính):
- Độ chính xác 8 lớp:
32.6% - Điểm F1 8 lớp (có trọng số):
0.241 - Độ chính xác nhị phân:
59.6% - Điểm F1 nhị phân (ác tính):
0.639
Kết quả này cho thấy mô hình hoạt động tốt hơn ngẫu nhiên (12.5%), nhưng vẫn còn nhiều khoảng trống để cải thiện, đặc biệt là trong phân loại chi tiết 8 lớp.
Một chuyển hướng nhanh:
Học ít mẫu so với tinh chỉnh
Trước khi bắt đầu huấn luyện, đáng để hỏi:
liệu tinh chỉnh có phải là cách duy nhất?
Một kỹ thuật phổ biến khác làhọc ít mẫu.
Học ít mẫu giống như đưa cho một học sinh thông minh vài ví dụ về một bài toán mới ngay trước bài kiểm tra.
Bạn không dạy lại đại số cho họ, bạn chỉ cho họ thấy mẫu cụ thể bạn muốn họ làm theo bằng cách cung cấp các ví dụ trực tiếp trong lời nhắc.
Đây là một kỹ thuật mạnh mẽ, đặc biệt khi bạn sử dụng một mô hình đóng qua API mà bạn không thể truy cập các trọng số bên trong.
Vậy tại sao chúng tôi lại chọn tinh chỉnh?
- Chúng tôi có thể lưu trữ mô hình:
Vì MedGemma là một mô hình mở, chúng tôi có quyền truy cập trực tiếp vào kiến trúc của nó.
Quyền truy cập này cho phép chúng tôi thực hiện tinh chỉnh để tạo ra một phiên bản mới, được cập nhật vĩnh viễn của mô hình. - Chúng tôi có một bộ dữ liệu tốt:
Tinh chỉnh cho phép mô hình học các mẫu sâu, cơ bản trong hàng trăm hình ảnh huấn luyện của chúng tôi hiệu quả hơn nhiều so với chỉ cho nó xem vài ví dụ trong lời nhắc.
Nói ngắn gọn, tinh chỉnh tạo ra mộtmô hình chuyên gia thực sựcho nhiệm vụ của chúng tôi, và đó chính xác là điều chúng tôi muốn.
Bước 6:
Cấu hình và chạy tinh chỉnh với LoRA
Đây là phần chính!
Chúng tôi sẽ sử dụngThích ứng Hạng Thấp (LoRA), vốn nhanh hơn và tiết kiệm bộ nhớ hơn nhiều so với tinh chỉnh truyền thống.
LoRA hoạt động bằng cách đóng băng các trọng số mô hình gốc và chỉ huấn luyện một tập nhỏ các trọng số bộ chuyển đổi mới.
Dưới đây là phân tích các lựa chọn tham số của chúng tôi:
r=8:
Hạng của LoRA.
Hạng thấp hơn nghĩa là ít tham số có thể huấn luyện hơn, điều này nhanh hơn nhưng ít biểu đạt hơn.
Hạng cao hơn có nhiều khả năng hơn, nhưng có nguy cơ quá khớp trên một tập dữ liệu nhỏ.
Hạng 8 là điểm khởi đầu tuyệt vời để cân bằng hiệu suất và hiệu quả.lora_alpha=16:
Một hệ số tỷ lệ cho các trọng số LoRA.
Một quy tắc ngón tay cái phổ biến là đặt nó bằng hai lần hạng (2 × r).lora_dropout=0.1:
Một kỹ thuật chính quy hóa.
Nó vô hiệu hóa ngẫu nhiên một số nơ-ron LoRA trong quá trình huấn luyện để ngăn mô hình trở nên quá chuyên biệt và không thể tổng quát hóa.
# LoRA Configurationpeft_config=LoraConfig(r=8,lora_alpha=16,lora_dropout=0.1,bias="none",target_modules="all-linear",task_type="CAUSAL_LM",)# Custom data collator to handle images and textdefcollate_fn(examples):texts,images=[],[]forexampleinexamples:images.append([example["image"]])texts.append(processor.apply_chat_template(example["messages"],add_generation_prompt=False,tokenize=False).strip())batch=processor(text=texts,images=images,return_tensors="pt",padding=True)labels=batch["input_ids"].clone()labels[labels==processor.tokenizer.pad_token_id]=-100image_token_id=processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.special_tokens_map["boi_token"])labels[labels==image_token_id]=-100labels[labels==262144]=-100batch["labels"]=labelsreturnbatch# Training argumentstraining_args=SFTConfig(output_dir="medgemma-breastcancer-finetuned",num_train_epochs=5,per_device_train_batch_size=1,per_device_eval_batch_size=1,gradient_accumulation_steps=8,gradient_checkpointing=True,optim="paged_adamw_8bit",learning_rate=5e-4,lr_scheduler_type="cosine",warmup_ratio=0.03,# Warm up LR for first 3% of trainingmax_grad_norm=0.3,# Clip gradients to prevent instabilitybf16=True,# Use bfloat16 precisionlogging_steps=10,save_strategy="steps",save_steps=100,eval_strategy="epoch",push_to_hub=False,report_to="none",gradient_checkpointing_kwargs={"use_reentrant":False},dataset_kwargs={"skip_prepare_dataset":True},remove_unused_columns=False,label_names=["labels"],)# Initialize and run the trainertrainer=SFTTrainer(model=model,args=training_args,train_dataset=formatted_train,eval_dataset=formatted_eval,peft_config=peft_config,processing_class=processor,data_collator=collate_fn,)print("Starting training...")trainer.train()trainer.save_model()
Quá trình huấn luyện mất khoảng80 phúttrên GPU A100 với VRAM 40 GB.
Kết quả trông rất hứa hẹn, với tổn thất kiểm định giảm dần đều.
Mẹo quan trọng (tiết kiệm thời gian!):
Nếu quá trình huấn luyện của bạn bị gián đoạn vì bất kỳ lý do gì (như sự cố kết nối hoặc vượt quá giới hạn tài nguyên), bạn có thể tiếp tục quá trình huấn luyện từ một checkpoint đã lưu bằng cách sử dụng đối sốresume_from_checkpointtrongtrainer.train().
Checkpoint có thể tiết kiệm thời gian quý giá vì chúng được lưu ở mỗi khoảngsave_stepsnhư đã định nghĩa trongTrainingArguments.
Bước 7:
Kết luận cuối cùng
– đánh giá mô hình đã được tinh chỉnh của chúng ta
Sau khi huấn luyện, đã đến lúc cho sự thật.
Chúng ta sẽ tải trọng số bộ điều hợp LoRA mới, hợp nhất chúng với mô hình cơ sở, và sau đó chạy cùng một đánh giá mà chúng ta đã chạy cho đường cơ sở.
# Xóa bộ nhớ và tải mô hình cuối cùngdelmodeltorch.cuda.empty_cache()gc.collect()# Tải lại mô hình cơ sởbase_model=AutoModelForImageTextToText.from_pretrained(MODEL_ID,torch_dtype=torch.bfloat16,device_map="auto",attn_implementation="sdpa")# Tải bộ điều chỉnh LoRA và hợp nhất chúng thành một mô hình duy nhấtfinetuned_model=PeftModel.from_pretrained(base_model,training_args.output_dir)finetuned_model=finetuned_model.merge_and_unload()# Cấu hình cho quá trình sinh văn bảnfinetuned_model.generation_config.max_new_tokens=50finetuned_model.generation_config.pad_token_id=processor_finetuned.tokenizer.pad_token_idfinetuned_model.config.pad_token_id=processor_finetuned.tokenizer.pad_token_id# Tải bộ xử lý và chạy đánh giáprocessor_finetuned=AutoProcessor.from_pretrained(training_args.output_dir)finetuned_preds=batch_predict(finetuned_model,processor_finetuned,eval_prompts,eval_images,batch_size=4)finetuned_metrics=compute_metrics(finetuned_preds,eval_labels)
Kết quả cuối cùng
Vậy, việc tinh chỉnh đã tác động thế nào đến hiệu suất?
Hãy xem các con số về độ chính xác 8 lớp và F1 tổng quát.
--- 8-Class Classification (0-7) --- Model Accuracy F1 (Weighted) ----------------------------------------------- Baseline 32.6% 0.241 Fine-tuned 87.2% 0.865 ----------------------------------------------- --- Binary (Benign/Malignant) Classification --- Model Accuracy F1 (Malignant) ----------------------------------------------- Baseline 59.6% 0.639 Fine-tuned 99.0% 0.991 -----------------------------------------------
Kết quả thật tuyệt vời!
Sau khi tinh chỉnh, chúng ta thấy một sự cải thiện đáng kể:
- 8 Lớp:
Độ chính xác tăng từ 32.6% lên 87.2% (+54.6%) và F1 từ 0.241 lên 0.865. - Nhị phân:
Độ chính xác tăng từ 59.6% lên 99.0% (+39.4%) và F1 từ 0.639 lên 0.991.
Dự án này cho thấy sức mạnh đáng kinh ngạc của việc tinh chỉnh cácmô hình nền tảnghiện đại.
Chúng ta đã lấy một AI đa năng đã được đào tạo trước trên dữ liệu y tế liên quan, cung cấp cho nó một tập dữ liệu chuyên biệt nhỏ và dạy nó một kỹ năng mới với hiệu quả đáng kể.
Hành trình từ một mô hình chung chung đến một bộ phân loại chuyên biệt ngày càng trở nên dễ tiếp cận hơn, mở ra những khả năng thú vị cho AI trong y học và hơn thế nữa.
Tất cả thông tin đều có sẵn trongNotebook Tinh chỉnh.
Bạn có thể chạy nó với một phiên bản GPU trênVertex AI Workbench.
Muốn đưa nó vào sản xuất?
Đừng quên đón đọc bài viết sắp tới, sẽ hướng dẫn bạn cách chuyển quá trình tinh chỉnh và đánh giá sangCloud Run jobs.
Hy vọng hướng dẫn này hữu ích.
Chúc bạn lập trình vui vẻ!
1IEEE Transactions on Biomedical Engineering, vol.
63, no.
7, pp.
1455-1462, 2016






![[Tự học C++] Số dấu phẩy động(float, double,…) trong C++](https://cafedev.vn/wp-content/uploads/2019/12/cafedevn_c_develoment-100x70.jpg)

